1use core::ops::{Add, Sub};
2
3use super::frame::GpsPosition;
4use crate::Headers;
5
/// A field predictor: the value it predicts is added to a field's raw value by
/// [`Predictor::apply`].
///
/// For `Zero` through `MinMotor` the discriminant equals the numeric string accepted by
/// [`Predictor::from_num_str`] (`"0"` to `"11"`). [`Predictor::HomeLon`] (`256`) has no such
/// string and is never produced by that parser.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u16)]
pub(crate) enum Predictor {
    /// Predicts `0`.
    Zero = 0,
    /// Predicts the previous value of the field.
    Previous = 1,
    /// Predicts `2 * last - last_last` (see `straight_line`).
    StraightLine = 2,
    /// Predicts the midpoint of the two previous values (see `average`).
    Average2 = 3,
    /// Predicts the `min_throttle` header value.
    MinThrottle = 4,
    /// Predicts the motor 0 value taken from the frame currently being decoded.
    Motor0 = 5,
    /// Predicts `last + skipped_frames + 1`.
    Increment = 6,
    /// Predicts the GPS home latitude.
    HomeLat = 7,
    /// Predicts the constant `1500`.
    FifteenHundred = 8,
    /// Predicts the `vbat_reference` header value.
    VBatReference = 9,
    /// Currently unhandled by `apply`, which predicts `0` and logs a debug message.
    /// Presumably meant to be the last main frame's time (name-based) — TODO confirm.
    LastMainFrameTime = 10,
    /// Predicts the `min` of the `motor_output_range` header.
    MinMotor = 11,
    /// Predicts the GPS home longitude. It has no numeric string of its own; presumably
    /// substituted for `HomeLat` on the longitude field elsewhere — TODO confirm.
    HomeLon = 256,
}
23
24impl Predictor {
25 pub(crate) fn apply(
26 self,
27 value: u32,
28 signed: bool,
29 current: Option<&[u32]>,
30 ctx: &PredictorContext,
31 ) -> u32 {
32 let _span = if signed {
33 tracing::trace_span!(
34 "Predictor::apply",
35 ?self,
36 value = value.cast_signed(),
37 last = ctx.last.map(u32::cast_signed),
38 last_last = ctx.last_last.map(u32::cast_signed),
39 skipped_frames = ctx.skipped_frames,
40 )
41 } else {
42 tracing::trace_span!(
43 "Predictor::apply",
44 ?self,
45 value,
46 last = ctx.last,
47 last_last = ctx.last_last,
48 skipped_frames = ctx.skipped_frames
49 )
50 };
51 let _span = _span.enter();
52
53 let diff = match self {
54 Self::Zero => 0,
55 Self::Previous => ctx.last.unwrap_or(0),
56 Self::StraightLine => {
57 if signed {
58 straight_line(
59 ctx.last.map(u32::cast_signed),
60 ctx.last_last.map(u32::cast_signed),
61 )
62 .cast_unsigned()
63 } else {
64 straight_line(ctx.last, ctx.last_last)
65 }
66 }
67 Self::Average2 => {
68 if signed {
69 average(
70 ctx.last.map(u32::cast_signed),
71 ctx.last_last.map(u32::cast_signed),
72 )
73 .cast_unsigned()
74 } else {
75 average(ctx.last, ctx.last_last)
76 }
77 }
78 Self::MinThrottle => ctx.headers.min_throttle.unwrap().into(),
79 Self::Motor0 => current.map_or_else(
80 || {
81 tracing::debug!("found {self:?} without current values");
82 0
83 },
84 |current| ctx.headers.main_frame_def().get_motor_0_from(current),
85 ),
86 Self::Increment => {
87 if signed {
88 ctx.skipped_frames
89 .wrapping_add(1)
90 .wrapping_add(ctx.last.unwrap_or(0))
91 } else {
92 let skipped_frames = i32::try_from(ctx.skipped_frames)
93 .expect("never skip more than i32::MAX frames");
94 skipped_frames
95 .wrapping_add(1)
96 .wrapping_add(ctx.last.unwrap_or(0).cast_signed())
97 .cast_unsigned()
98 }
99 }
100 Self::HomeLat | Self::HomeLon => ctx.gps_home.map_or_else(
101 || {
102 tracing::debug!("found {self:?} without gps home");
103 0
105 },
106 |home| {
107 if self == Self::HomeLat {
108 home.latitude.cast_unsigned()
109 } else {
110 home.longitude.cast_unsigned()
111 }
112 },
113 ),
114 Self::FifteenHundred => 1500,
115 Self::VBatReference => ctx.headers.vbat_reference.unwrap().into(),
116 Self::LastMainFrameTime => {
117 tracing::debug!("found unhandled {self:?}");
118 0
119 }
120 Self::MinMotor => ctx.headers.motor_output_range.unwrap().min.into(),
121 };
122
123 if signed {
124 let signed = value.cast_signed().wrapping_add(diff.cast_signed());
125 tracing::trace!(return = signed);
126 signed.cast_unsigned()
127 } else {
128 let x = value.wrapping_add(diff);
129 tracing::trace!(return = x);
130 x
131 }
132 }
133
134 pub(crate) fn from_num_str(s: &str) -> Option<Self> {
135 match s {
136 "0" => Some(Self::Zero),
137 "1" => Some(Self::Previous),
138 "2" => Some(Self::StraightLine),
139 "3" => Some(Self::Average2),
140 "4" => Some(Self::MinThrottle),
141 "5" => Some(Self::Motor0),
142 "6" => Some(Self::Increment),
143 "7" => Some(Self::HomeLat), "8" => Some(Self::FifteenHundred),
145 "9" => Some(Self::VBatReference),
146 "10" => Some(Self::LastMainFrameTime),
147 "11" => Some(Self::MinMotor),
148 _ => None,
149 }
150 }
151}
152
/// State that [`Predictor::apply`] consults when computing a prediction for one field.
#[derive(Debug, Clone)]
pub(crate) struct PredictorContext<'a, 'data> {
    /// Parsed headers. `apply` reads `min_throttle`, `vbat_reference`,
    /// `motor_output_range` and the main frame definition from here.
    headers: &'a Headers<'data>,
    /// The field's previous value as raw `u32` bits, if known.
    last: Option<u32>,
    /// The value before `last`, as raw `u32` bits, if known.
    last_last: Option<u32>,
    /// Skipped-frame count; read only by the `Increment` predictor.
    skipped_frames: u32,
    /// GPS home position; read only by the `HomeLat`/`HomeLon` predictors.
    gps_home: Option<GpsPosition>,
}
161
162impl<'a, 'data> PredictorContext<'a, 'data> {
163 pub(crate) const fn new(headers: &'a Headers<'data>) -> Self {
164 Self {
165 headers,
166 last: None,
167 last_last: None,
168 skipped_frames: 0,
169 gps_home: None,
170 }
171 }
172
173 pub(crate) const fn with_skipped(headers: &'a Headers<'data>, skipped_frames: u32) -> Self {
174 Self {
175 headers,
176 last: None,
177 last_last: None,
178 skipped_frames,
179 gps_home: None,
180 }
181 }
182
183 pub(crate) const fn with_home(
184 headers: &'a Headers<'data>,
185 gps_home: Option<GpsPosition>,
186 ) -> Self {
187 Self {
188 headers,
189 last: None,
190 last_last: None,
191 skipped_frames: 0,
192 gps_home,
193 }
194 }
195
196 pub(crate) fn set_last(&mut self, last: Option<u32>) {
197 self.last = last;
198 }
199
200 pub(crate) fn set_last_2(&mut self, last: Option<u32>, last_last: Option<u32>) {
201 self.last = last;
202 self.last_last = last_last;
203 }
204}
205
206#[inline]
207pub(crate) fn straight_line<T>(last: Option<T>, last_last: Option<T>) -> T
208where
209 T: NarrowInteger + Default,
210{
211 match (last, last_last) {
212 (Some(last), Some(last_last)) => {
213 let fallback = last;
214
215 let result = {
216 let last = last.widen();
217 let last_last = last_last.widen();
218 let sum = last + last;
219
220 if let Some(diff) = sum.checked_sub(last_last) {
221 diff
222 } else {
223 return fallback;
224 }
225 };
226 T::try_from(result).unwrap_or(fallback)
227 }
228 (Some(last), None) => last,
229 _ => T::default(),
230 }
231}
232
233#[inline]
234fn average<T: Integer + Default>(last: Option<T>, last_last: Option<T>) -> T {
235 let last = last.unwrap_or_default();
236 last_last.map_or(last, |last_last| last.midpoint(last_last))
237}
238
/// Integer operations the generic predictor helpers (`straight_line`, `average`) rely on.
///
/// Implemented for the primitive integer types by `impl_integer!`, which forwards each method
/// to the primitive's inherent method of the same name.
pub(crate) trait Integer: Copy + Add<Output = Self> + Sub<Output = Self> {
    /// Subtraction returning `None` instead of overflowing.
    fn checked_sub(self, rhs: Self) -> Option<Self>;
    /// Average of `self` and `rhs`, without overflowing on large operands.
    fn midpoint(self, rhs: Self) -> Self;
}
243
/// An [`Integer`] with a wider companion type ([`NarrowInteger::Wide`]) for intermediate
/// results that may not fit in `Self`.
///
/// Values convert into `Wide` infallibly (`From`) and back fallibly (`TryFrom`). In this file
/// `Wide` is always the next-larger type of the same signedness (e.g. `i8` -> `i16`).
pub(crate) trait NarrowInteger: Integer
where
    Self: TryFrom<Self::Wide>,
    Self::Wide: From<Self> + Integer,
{
    /// The wider integer type used for intermediate arithmetic.
    type Wide;

    /// Converts `self` into [`NarrowInteger::Wide`] without loss.
    fn widen(self) -> Self::Wide {
        self.into()
    }
}
255
/// Implements [`Integer`] for each listed type by forwarding to the type's inherent methods.
/// When a type is written as `$t => $wide`, it also implements [`NarrowInteger`] with
/// `Wide = $wide`.
macro_rules! impl_integer {
    ($($t:ty $(=> $wide:ty)?),* $(,)?) => {$(
        impl Integer for $t {
            fn checked_sub(self, rhs: Self) -> Option<Self> {
                // `<$t>::` resolves to the primitive's inherent method (inherent items take
                // precedence over trait items), so this does not recurse into the trait method.
                <$t>::checked_sub(self, rhs)
            }

            fn midpoint(self, rhs: Self) -> Self {
                <$t>::midpoint(self, rhs)
            }
        }

        $(impl NarrowInteger for $t {
            type Wide = $wide;
        })?
    )*};
}
273
274impl_integer!(u8 => u16, i8 => i16);
275impl_integer!(u16 => u32, i16 => i32);
276impl_integer!(u32 => u64, i32 => i64);
277impl_integer!(u64 => u128, i64 => i128);
278impl_integer!(u128, i128);
279
#[cfg(test)]
mod tests {
    use test_case::case;

    #[case(None, None => 0)]
    #[case(None, Some(5) => 0 ; "missing last")]
    #[case(Some(10), None => 10)]
    #[case(Some(-2), None => -2)]
    #[case(Some(12), Some(10) => 14)]
    #[case(Some(10), Some(12) => 8)]
    #[case(Some(0), Some(i8::MAX) => -i8::MAX)]
    // 2 * 0 - (-128) = 128 exceeds i8::MAX, so `last` is returned.
    #[case(Some(0), Some(i8::MIN) => 0 ; "overflow from subtraction")]
    // 2 * 126 - 0 = 252 exceeds i8::MAX, so `last` is returned.
    #[case(Some(126), Some(0) => 126 ; "overflow")]
    // 2 * (-128) - 0 = -256 is below i8::MIN, so `last` is returned.
    #[case(Some(i8::MIN), Some(0) => i8::MIN ; "underflow")]
    fn straight_line_signed(last: Option<i8>, last_last: Option<i8>) -> i8 {
        super::straight_line(last, last_last)
    }

    #[case(None, None => 0)]
    #[case(Some(7), None => 7)]
    #[case(Some(2), Some(2) => 2)]
    #[case(Some(12), Some(10) => 14)]
    #[case(Some(10), Some(12) => 8)]
    #[case(Some(0), Some(u8::MIN) => 0 ; "zeros")]
    // 2 * 254 - 0 = 508 exceeds u8::MAX, so `last` is returned.
    #[case(Some(u8::MAX - 1), Some(0) => 254 ; "overflow")]
    // 2 * 0 - 255 is negative (unsigned underflow), so `last` is returned.
    #[case(Some(0), Some(u8::MAX) => 0 ; "negative result")]
    fn straight_line_unsigned(last: Option<u8>, last_last: Option<u8>) -> u8 {
        super::straight_line(last, last_last)
    }

    #[case(None, None => 0)]
    #[case(Some(-1), None => -1)]
    #[case(Some(2), Some(-1) => 0)]
    #[case(Some(-3), Some(0) => -1 ; "rounds toward zero")]
    #[case(Some(i32::MAX), Some(1) => 0x4000_0000 ; "overflow")]
    fn average_signed(last: Option<i32>, last_last: Option<i32>) -> i32 {
        super::average(last, last_last)
    }

    #[case(None, None => 0)]
    #[case(Some(1), None => 1)]
    #[case(Some(2), Some(10) => 6)]
    #[case(Some(1), Some(2) => 1 ; "rounds down")]
    #[case(Some(u32::MAX), Some(1) => 0x8000_0000 ; "overflow")]
    fn average_unsigned(last: Option<u32>, last_last: Option<u32>) -> u32 {
        super::average(last, last_last)
    }
}