dfir_lang/graph/ops/
assert_eq.rs
1use quote::quote_spanned;
2use syn::parse_quote_spanned;
3
4use super::{
5 OperatorCategory, OperatorConstraints, WriteContextArgs, RANGE_0, RANGE_1,
6};
7
8pub const ASSERT_EQ: OperatorConstraints = OperatorConstraints {
29 name: "assert_eq",
30 categories: &[OperatorCategory::Control],
31 hard_range_inn: RANGE_1,
32 soft_range_inn: RANGE_1,
33 hard_range_out: &(0..=1),
34 soft_range_out: &(0..=1),
35 num_args: 1,
36 persistence_args: RANGE_0,
37 type_args: RANGE_0,
38 is_external_input: false,
39 has_singleton_output: false,
40 flo_type: None,
41 ports_inn: None,
42 ports_out: None,
43 input_delaytype_fn: |_| None,
44 write_fn: |wc @ &WriteContextArgs {
45 context,
46 df_ident,
47 op_span,
48 arguments,
49 ..
50 },
51 diagnostics| {
52 let assert_index_ident = wc.make_ident("assert_index");
53
54 let arg = &arguments[0];
55
56 let inspect_fn = parse_quote_spanned! {op_span=>
57 |item| {
58 fn __constrain_types<T>(array: &impl ::std::ops::Index<usize, Output = T>, index: usize) -> &T {
60 &array[index]
61 }
62
63 unsafe {
64 let index = #context.state_ref_unchecked(#assert_index_ident).get();
66 ::std::assert_eq!(__constrain_types(&#arg, index), item, "Item (right) at index {} does not equal expected (left).", index);
67 #context.state_ref_unchecked(#assert_index_ident).set(index + 1);
68 };
69 }
70 };
71
72 let wc = WriteContextArgs {
73 arguments: &inspect_fn,
74 ..wc.clone()
75 };
76
77 let mut owo = (super::inspect::INSPECT.write_fn)(&wc, diagnostics)?;
78
79 let write_prologue = owo.write_prologue;
80 owo.write_prologue = quote_spanned! {op_span=>
81 let #assert_index_ident = #df_ident.add_state(
82 ::std::cell::Cell::new(0usize)
83 );
84
85 #write_prologue
86 };
87
88 Ok(owo)
89 },
90};