// blackbox_log/predictor.rs

1use core::ops::{Add, Sub};
2
3use super::frame::GpsPosition;
4use crate::Headers;
5
/// A field predictor: how a field's raw decoded value is combined with a predicted value.
///
/// The discriminants are the numeric ids understood by [`Predictor::from_num_str`]
/// (`"0"` ..= `"11"`), except [`Predictor::HomeLon`], which has no textual id there.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u16)]
pub(crate) enum Predictor {
    /// Predicts `0`, so the raw value is used unchanged.
    Zero = 0,
    /// Predicts the field's previous value (`0` if there is none).
    Previous = 1,
    /// Extrapolates a straight line through the previous two values.
    StraightLine = 2,
    /// Predicts the midpoint of the previous two values.
    Average2 = 3,
    /// Predicts the `min_throttle` header value.
    MinThrottle = 4,
    /// Predicts the motor 0 value taken from the current frame's values.
    Motor0 = 5,
    /// Predicts the previous value plus one plus the number of skipped frames.
    Increment = 6,
    /// Predicts the GPS home latitude.
    HomeLat = 7,
    /// Predicts the constant `1500`.
    FifteenHundred = 8,
    /// Predicts the `vbat_reference` header value.
    VBatReference = 9,
    /// Not handled yet; currently predicts `0`.
    LastMainFrameTime = 10,
    /// Predicts the minimum of the motor output range header.
    MinMotor = 11,
    /// Predicts the GPS home longitude.
    HomeLon = 256,
}
23
24impl Predictor {
25    pub(crate) fn apply(
26        self,
27        value: u32,
28        signed: bool,
29        current: Option<&[u32]>,
30        ctx: &PredictorContext,
31    ) -> u32 {
32        let _span = if signed {
33            tracing::trace_span!(
34                "Predictor::apply",
35                ?self,
36                value = value.cast_signed(),
37                last = ctx.last.map(u32::cast_signed),
38                last_last = ctx.last_last.map(u32::cast_signed),
39                skipped_frames = ctx.skipped_frames,
40            )
41        } else {
42            tracing::trace_span!(
43                "Predictor::apply",
44                ?self,
45                value,
46                last = ctx.last,
47                last_last = ctx.last_last,
48                skipped_frames = ctx.skipped_frames
49            )
50        };
51        let _span = _span.enter();
52
53        let diff = match self {
54            Self::Zero => 0,
55            Self::Previous => ctx.last.unwrap_or(0),
56            Self::StraightLine => {
57                if signed {
58                    straight_line(
59                        ctx.last.map(u32::cast_signed),
60                        ctx.last_last.map(u32::cast_signed),
61                    )
62                    .cast_unsigned()
63                } else {
64                    straight_line(ctx.last, ctx.last_last)
65                }
66            }
67            Self::Average2 => {
68                if signed {
69                    average(
70                        ctx.last.map(u32::cast_signed),
71                        ctx.last_last.map(u32::cast_signed),
72                    )
73                    .cast_unsigned()
74                } else {
75                    average(ctx.last, ctx.last_last)
76                }
77            }
78            Self::MinThrottle => ctx.headers.min_throttle.unwrap().into(),
79            Self::Motor0 => current.map_or_else(
80                || {
81                    tracing::debug!("found {self:?} without current values");
82                    0
83                },
84                |current| ctx.headers.main_frame_def().get_motor_0_from(current),
85            ),
86            Self::Increment => {
87                if signed {
88                    ctx.skipped_frames
89                        .wrapping_add(1)
90                        .wrapping_add(ctx.last.unwrap_or(0))
91                } else {
92                    let skipped_frames = i32::try_from(ctx.skipped_frames)
93                        .expect("never skip more than i32::MAX frames");
94                    skipped_frames
95                        .wrapping_add(1)
96                        .wrapping_add(ctx.last.unwrap_or(0).cast_signed())
97                        .cast_unsigned()
98                }
99            }
100            Self::HomeLat | Self::HomeLon => ctx.gps_home.map_or_else(
101                || {
102                    tracing::debug!("found {self:?} without gps home");
103                    // TODO: invalidate result
104                    0
105                },
106                |home| {
107                    if self == Self::HomeLat {
108                        home.latitude.cast_unsigned()
109                    } else {
110                        home.longitude.cast_unsigned()
111                    }
112                },
113            ),
114            Self::FifteenHundred => 1500,
115            Self::VBatReference => ctx.headers.vbat_reference.unwrap().into(),
116            Self::LastMainFrameTime => {
117                tracing::debug!("found unhandled {self:?}");
118                0
119            }
120            Self::MinMotor => ctx.headers.motor_output_range.unwrap().min.into(),
121        };
122
123        if signed {
124            let signed = value.cast_signed().wrapping_add(diff.cast_signed());
125            tracing::trace!(return = signed);
126            signed.cast_unsigned()
127        } else {
128            let x = value.wrapping_add(diff);
129            tracing::trace!(return = x);
130            x
131        }
132    }
133
134    pub(crate) fn from_num_str(s: &str) -> Option<Self> {
135        match s {
136            "0" => Some(Self::Zero),
137            "1" => Some(Self::Previous),
138            "2" => Some(Self::StraightLine),
139            "3" => Some(Self::Average2),
140            "4" => Some(Self::MinThrottle),
141            "5" => Some(Self::Motor0),
142            "6" => Some(Self::Increment),
143            "7" => Some(Self::HomeLat), // TODO: check that lat = 0, lon = 1
144            "8" => Some(Self::FifteenHundred),
145            "9" => Some(Self::VBatReference),
146            "10" => Some(Self::LastMainFrameTime),
147            "11" => Some(Self::MinMotor),
148            _ => None,
149        }
150    }
151}
152
153#[derive(Debug, Clone)]
154pub(crate) struct PredictorContext<'a, 'data> {
155    headers: &'a Headers<'data>,
156    last: Option<u32>,
157    last_last: Option<u32>,
158    skipped_frames: u32,
159    gps_home: Option<GpsPosition>,
160}
161
162impl<'a, 'data> PredictorContext<'a, 'data> {
163    pub(crate) const fn new(headers: &'a Headers<'data>) -> Self {
164        Self {
165            headers,
166            last: None,
167            last_last: None,
168            skipped_frames: 0,
169            gps_home: None,
170        }
171    }
172
173    pub(crate) const fn with_skipped(headers: &'a Headers<'data>, skipped_frames: u32) -> Self {
174        Self {
175            headers,
176            last: None,
177            last_last: None,
178            skipped_frames,
179            gps_home: None,
180        }
181    }
182
183    pub(crate) const fn with_home(
184        headers: &'a Headers<'data>,
185        gps_home: Option<GpsPosition>,
186    ) -> Self {
187        Self {
188            headers,
189            last: None,
190            last_last: None,
191            skipped_frames: 0,
192            gps_home,
193        }
194    }
195
196    pub(crate) fn set_last(&mut self, last: Option<u32>) {
197        self.last = last;
198    }
199
200    pub(crate) fn set_last_2(&mut self, last: Option<u32>, last_last: Option<u32>) {
201        self.last = last;
202        self.last_last = last_last;
203    }
204}
205
206#[inline]
207pub(crate) fn straight_line<T>(last: Option<T>, last_last: Option<T>) -> T
208where
209    T: NarrowInteger + Default,
210{
211    match (last, last_last) {
212        (Some(last), Some(last_last)) => {
213            let fallback = last;
214
215            let result = {
216                let last = last.widen();
217                let last_last = last_last.widen();
218                let sum = last + last;
219
220                if let Some(diff) = sum.checked_sub(last_last) {
221                    diff
222                } else {
223                    return fallback;
224                }
225            };
226            T::try_from(result).unwrap_or(fallback)
227        }
228        (Some(last), None) => last,
229        _ => T::default(),
230    }
231}
232
233#[inline]
234fn average<T: Integer + Default>(last: Option<T>, last_last: Option<T>) -> T {
235    let last = last.unwrap_or_default();
236    last_last.map_or(last, |last_last| last.midpoint(last_last))
237}
238
/// The minimal integer interface the predictor helpers need.
///
/// `impl_integer!` implements it for the primitive integers by forwarding each method to
/// the primitive's inherent method of the same name.
pub(crate) trait Integer: Copy + Sub<Output = Self> + Add<Output = Self> {
    /// Midpoint of `self` and `rhs`; mirrors the primitives' inherent `midpoint`.
    fn midpoint(self, rhs: Self) -> Self;

    /// Subtraction that returns `None` instead of overflowing.
    fn checked_sub(self, rhs: Self) -> Option<Self>;
}
243
244pub(crate) trait NarrowInteger: Integer
245where
246    Self: TryFrom<Self::Wide>,
247    Self::Wide: From<Self> + Integer,
248{
249    type Wide;
250
251    fn widen(self) -> Self::Wide {
252        self.into()
253    }
254}
255
/// Implements `Integer` for each listed primitive type by forwarding to its inherent
/// methods. Also implements `NarrowInteger` when a `=> Wide` type follows the type.
macro_rules! impl_integer {
    ($($ty:ty $(=> $wider:ty)?),* $(,)?) => {
        $(
            impl Integer for $ty {
                // `<$ty>::` form: a `ty` fragment cannot start a path directly.
                fn checked_sub(self, rhs: Self) -> Option<Self> {
                    <$ty>::checked_sub(self, rhs)
                }

                fn midpoint(self, rhs: Self) -> Self {
                    <$ty>::midpoint(self, rhs)
                }
            }

            $(
                impl NarrowInteger for $ty {
                    type Wide = $wider;
                }
            )?
        )*
    };
}
273
274impl_integer!(u8 => u16, i8 => i16);
275impl_integer!(u16 => u32, i16 => i32);
276impl_integer!(u32 => u64, i32 => i64);
277impl_integer!(u64 => u128, i64 => i128);
278impl_integer!(u128, i128);
279
#[cfg(test)]
mod tests {
    use test_case::case;

    // `straight_line` returns `2 * last - last_last`. It falls back to `last` whenever that
    // value is negative for an unsigned type or does not fit in the type. It returns `last`
    // alone when only `last` is known, and `0` when `last` is missing.
    #[case(None, None => 0)]
    #[case(None, Some(5) => 0 ; "no last value")]
    #[case(Some(10), None => 10)]
    #[case(Some(-2), None => -2)]
    #[case(Some(12), Some(10) => 14)]
    #[case(Some(10), Some(12) => 8)]
    #[case(Some(0), Some(i8::MAX) => -i8::MAX)]
    #[case(Some(0), Some(i8::MIN) => 0 ; "underflow")]
    #[case(Some(126), Some(0) => 126 ; "overflow")]
    fn straight_line_signed(last: Option<i8>, last_last: Option<i8>) -> i8 {
        super::straight_line(last, last_last)
    }

    #[case(None, None => 0)]
    #[case(Some(10), None => 10)]
    #[case(Some(2), Some(2) => 2)]
    #[case(Some(12), Some(10) => 14)]
    #[case(Some(10), Some(12) => 8)]
    #[case(Some(0), Some(u8::MIN) => 0 ; "underflow")]
    #[case(Some(u8::MAX - 1), Some(0) => 254 ; "overflow")]
    #[case(Some(0), Some(u8::MAX) => 0 ; "negative result")]
    #[case(Some(1), Some(3) => 1 ; "small negative result")]
    fn straight_line_unsigned(last: Option<u8>, last_last: Option<u8>) -> u8 {
        super::straight_line(last, last_last)
    }

    // `average` uses the primitives' `midpoint`, which never overflows and rounds toward zero.
    #[case(None, None => 0)]
    #[case(Some(-1), None => -1)]
    #[case(Some(2), Some(-1) => 0)]
    #[case(Some(-3), Some(-4) => -3 ; "rounds toward zero")]
    #[case(Some(i32::MAX), Some(1) => 0x4000_0000 ; "overflow")]
    fn average_signed(last: Option<i32>, last_last: Option<i32>) -> i32 {
        super::average(last, last_last)
    }

    #[case(None, None => 0)]
    #[case(Some(1), None => 1)]
    #[case(Some(2), Some(10) => 6)]
    #[case(Some(3), Some(4) => 3 ; "rounds down")]
    #[case(Some(u32::MAX), Some(1) => 0x8000_0000 ; "overflow")]
    fn average_unsigned(last: Option<u32>, last_last: Option<u32>) -> u32 {
        super::average(last, last_last)
    }
}