// dfir_datalog_core/join_plan.rs

1use std::collections::btree_map::Entry;
2use std::collections::{BTreeMap, HashMap};
3
4use dfir_lang::diagnostic::{Diagnostic, Level};
5use dfir_lang::graph::FlatGraphBuilder;
6use dfir_lang::parse::Pipeline;
7use proc_macro2::Span;
8use rust_sitter::Spanned;
9use syn::{parse_quote, parse_quote_spanned};
10
11use crate::grammar::datalog::{BoolExpr, BoolOp, ExtractExpr, InputRelationExpr, IntExpr};
12use crate::util::{Counter, repeat_tuple};
13
14/// Captures the tree of joins used to compute contributions from a single rule.
15pub enum JoinPlan<'a> {
16    /// A single relation without any joins, leaves of the tree.
17    /// Second element is whether this is a persisted relation.
18    Source(&'a Spanned<InputRelationExpr>, bool),
19    /// A join between two subtrees.
20    Join(Box<JoinPlan<'a>>, Box<JoinPlan<'a>>),
21    AntiJoin(Box<JoinPlan<'a>>, Box<JoinPlan<'a>>),
22    Predicate(Vec<&'a Spanned<BoolExpr>>, Box<JoinPlan<'a>>),
23    /// A join between some relation and a magic relation that emits values between
24    /// 0 and some value in the input relation (upper-exclusive).
25    MagicNatLt(Box<JoinPlan<'a>>, ExtractExpr, ExtractExpr),
26}
27
28/// Tracks the DFIR node that corresponds to a subtree of a join plan.
29pub struct IntermediateJoinNode {
30    /// The name of the DFIR node that this join outputs to.
31    pub name: syn::Ident,
32    /// If true, the correct dataflow for this node ends in a `persist::<'static>()` operator.
33    pub persisted: bool,
34    /// If this join node outputs data through a `tee()` operator, this is the index to consume the node with.
35    /// (this is only used for cases where we are directly reading a relation)
36    pub tee_idx: Option<isize>,
37    /// A mapping from variables in the rule to the index of the corresponding element in the flattened tuples this node emits.
38    pub variable_mapping: BTreeMap<String, usize>,
39    /// Tuple indices that that correspond to wildcard, unused values.
40    pub wildcard_indices: Vec<usize>,
41    /// The type of the flattened tuples this node emits.
42    pub tuple_type: syn::Type,
43    /// The span corresponding to the original sources resulting in this node.
44    pub span: Span,
45}
46
/// Identifies which input of a two-input join operator a pipeline feeds.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum JoinSide {
    /// The first (left-hand) input of the join.
    Left,
    /// The second (right-hand) input of the join.
    Right,
}

impl JoinSide {
    /// Returns the numeric input port of the join operator for this side:
    /// `0` for [`JoinSide::Left`] and `1` for [`JoinSide::Right`].
    fn index(&self) -> usize {
        match self {
            Self::Left => 0,
            Self::Right => 1,
        }
    }
}
60
61/// Generates a DFIR pipeline that transforms some input to a join
62/// to emit key-value tuples that can be fed into a join operator.
63fn emit_join_input_pipeline(
64    // The identifiers of the input node that the key should be populated with.
65    identifiers_to_join: &[String],
66    identifiers_to_not_join: &[String],
67    // The DFIR node that is one side of the join.
68    source_expanded: &IntermediateJoinNode,
69    // The DFIR node for the join operator.
70    join_node: &syn::Ident,
71    // Whether this node contributes to the left or right side of the join.
72    join_side: JoinSide,
73    // Whether the pipeline is for an anti-join.
74    anti_join: bool,
75    // The DFIR graph to emit the pipeline to.
76    flat_graph_builder: &mut FlatGraphBuilder,
77) {
78    let hash_keys: Vec<syn::Expr> = identifiers_to_join
79        .iter()
80        .map(|ident| {
81            if let Some(idx) = source_expanded.variable_mapping.get(ident) {
82                let idx_ident = syn::Index::from(*idx);
83                parse_quote!(_v.#idx_ident)
84            } else {
85                panic!("Could not find key that is being joined on: {:?}", ident);
86            }
87        })
88        .collect();
89
90    let not_hash_keys: Vec<syn::Expr> = identifiers_to_not_join
91        .iter()
92        .map(|ident| {
93            if let Some(idx) = source_expanded.variable_mapping.get(ident) {
94                let idx_ident = syn::Index::from(*idx);
95                parse_quote!(_v.#idx_ident)
96            } else {
97                panic!("Could not find key that is being joined on: {:?}", ident);
98            }
99        })
100        .chain(source_expanded.wildcard_indices.iter().map(|idx| {
101            let idx_ident = syn::Index::from(*idx);
102            parse_quote!(_v.#idx_ident)
103        }))
104        .collect();
105
106    let out_index = syn::Index::from(join_side.index());
107
108    let source_name = &source_expanded.name;
109    let source_type = &source_expanded.tuple_type;
110
111    let rhs: Pipeline = if anti_join {
112        match join_side {
113            JoinSide::Left => {
114                parse_quote_spanned!(source_expanded.span=> map(|_v: #source_type| ((#(#hash_keys, )*), (#(#not_hash_keys, )*))) -> [pos] #join_node)
115            }
116            JoinSide::Right => {
117                parse_quote_spanned!(source_expanded.span=> map(|_v: #source_type| (#(#hash_keys, )*)) -> [neg] #join_node)
118            }
119        }
120    } else {
121        parse_quote_spanned!(source_expanded.span=> map(|_v: #source_type| ((#(#hash_keys, )*), (#(#not_hash_keys, )*))) -> [#out_index] #join_node)
122    };
123
124    let rhs = if anti_join && source_expanded.persisted {
125        parse_quote_spanned!(source_expanded.span=> persist::<'static>() -> #rhs)
126    } else {
127        rhs
128    };
129
130    let statement = match source_expanded.tee_idx {
131        Some(i) => {
132            let in_index = syn::LitInt::new(&format!("{}", i), Span::call_site());
133            parse_quote_spanned! {source_expanded.span=> #source_name [#in_index] -> #rhs; }
134        }
135        None => parse_quote_spanned! {source_expanded.span=> #source_name -> #rhs; },
136    };
137
138    flat_graph_builder.add_statement(statement);
139}
140
141/// Given a mapping from variable names to their repeated indices, builds a Rust expression that
142/// tests whether the values at those indices are equal for each variable.
143///
144/// For example, `rel(a, b, a, a, b)` would give us the map `{ "a" => [0, 2, 3], "b" => [1, 4] }`.
145/// Then we would want to generate the code `row.0 == row.2 && row.0 == row.3 && row.1 == row.4`.
146fn build_local_constraint_conditions(constraints: &BTreeMap<String, Vec<usize>>) -> syn::Expr {
147    constraints
148        .values()
149        .flat_map(|indices| {
150            let equal_indices = indices
151                .iter()
152                .map(|i| syn::Index::from(*i))
153                .collect::<Vec<_>>();
154
155            let first_index = &equal_indices[0];
156
157            equal_indices
158                .iter()
159                .skip(1)
160                .map(|i| parse_quote!(row.#first_index == row.#i))
161                .collect::<Vec<_>>()
162        })
163        .reduce(|a: syn::Expr, b| parse_quote!(#a && #b))
164        .unwrap()
165}
166
167fn gen_predicate_value_expr(
168    expr: &IntExpr,
169    variable_mapping: &BTreeMap<String, usize>,
170    diagnostics: &mut Vec<Diagnostic>,
171    get_span: &dyn Fn((usize, usize)) -> Span,
172) -> syn::Expr {
173    crate::gen_value_expr(
174        expr,
175        &mut |ident| {
176            if let Some(col) = variable_mapping.get(&ident.name) {
177                let idx = syn::Index::from(*col);
178                parse_quote_spanned!(get_span(ident.span)=> row.#idx)
179            } else {
180                diagnostics.push(Diagnostic::spanned(
181                    get_span(ident.span),
182                    Level::Error,
183                    format!("Could not resolve reference to variable {}", &ident.name),
184                ));
185                parse_quote!(())
186            }
187        },
188        get_span,
189    )
190}
191
192/// Processes an extract expression to generate a DFIR pipeline that reads the input
193/// data from the IDB/EDB.
194///
195/// `row_width` is the number of elements in the tuples emitted by the **current** pipeline,
196/// with all transformations that have been applied while extracting variable so far. The
197/// `cur_row_offset` specifies the index of the current `ExtractExpr` in that tuple. If it
198/// is `None`, then we are the top-level expression and already have a tuple.
199///
200/// This function returns the number of elements in the tuple that will be emitted by the
201/// extraction of the `ExtractExpr`. So for a single variable, it will return `1`, for a
202/// tuple, it will return sum of the number of elements emitted by its children.
203#[expect(clippy::too_many_arguments, reason = "internal code")]
204fn process_extract(
205    extract: &ExtractExpr,
206    variable_mapping: &mut BTreeMap<String, usize>,
207    local_constraints: &mut BTreeMap<String, Vec<usize>>,
208    wildcard_indices: &mut Vec<usize>,
209    reader_pipeline: &mut Pipeline,
210    row_width: usize,
211    cur_row_offset: Option<usize>, // None if at the root and we are already a tuple
212    rule_span: Span,
213) -> usize {
214    match extract {
215        ExtractExpr::Underscore(_) => {
216            wildcard_indices.push(cur_row_offset.unwrap());
217            1
218        }
219        ExtractExpr::Ident(ident) => {
220            if let Entry::Vacant(e) = variable_mapping.entry(ident.name.clone()) {
221                e.insert(cur_row_offset.unwrap());
222            }
223
224            local_constraints
225                .entry(ident.name.clone())
226                .or_default()
227                .push(cur_row_offset.unwrap());
228
229            1
230        }
231        ExtractExpr::Flatten(_, expr) => {
232            let cur_row_offset = cur_row_offset.unwrap();
233            let tuple_elems_post_flat = (0..row_width)
234                .map(|i| {
235                    if i == cur_row_offset {
236                        parse_quote!(__flattened_element)
237                    } else {
238                        let idx: syn::Index = syn::Index::from(i);
239                        parse_quote!(::std::clone::Clone::clone(&row.#idx))
240                    }
241                })
242                .collect::<Vec<syn::Expr>>();
243
244            let flat_idx = syn::Index::from(cur_row_offset);
245
246            let mut row_types: Vec<syn::Type> = vec![];
247            for _ in 0..row_width {
248                row_types.push(parse_quote!(_));
249            }
250
251            let row_type: syn::Type = parse_quote!((#(#row_types, )*));
252
253            *reader_pipeline = parse_quote_spanned! {rule_span=>
254                #reader_pipeline -> flat_map(|row: #row_type| row.#flat_idx.into_iter().map(move |__flattened_element| (#(#tuple_elems_post_flat, )*)))
255            };
256
257            process_extract(
258                expr,
259                variable_mapping,
260                local_constraints,
261                wildcard_indices,
262                reader_pipeline,
263                row_width,
264                Some(cur_row_offset),
265                rule_span,
266            )
267        }
268        ExtractExpr::Untuple(_, tuple_elems, _) => {
269            let mut new_row_width = if let Some(cur_row_offset) = cur_row_offset {
270                let flat_idx = syn::Index::from(cur_row_offset);
271
272                let tuple_elems_post_flat = (0..row_width)
273                    .flat_map(|i| {
274                        if i == cur_row_offset {
275                            (0..tuple_elems.len())
276                                .map(|tuple_i| {
277                                    let idx: syn::Index = syn::Index::from(tuple_i);
278                                    parse_quote!(row_untuple.#flat_idx.#idx)
279                                })
280                                .collect::<Vec<_>>()
281                        } else {
282                            let idx: syn::Index = syn::Index::from(i);
283                            vec![parse_quote!(row_untuple.#idx)]
284                        }
285                    })
286                    .collect::<Vec<syn::Expr>>();
287
288                let mut row_types: Vec<syn::Type> = vec![];
289                for _ in 0..row_width {
290                    row_types.push(parse_quote!(_));
291                }
292
293                let row_type: syn::Type = parse_quote!((#(#row_types, )*));
294
295                *reader_pipeline = parse_quote_spanned! {rule_span=>
296                    #reader_pipeline -> map(|row_untuple: #row_type| (#(#tuple_elems_post_flat, )*))
297                };
298
299                row_width - 1 + tuple_elems.len()
300            } else {
301                row_width
302            };
303
304            let base_offset = cur_row_offset.unwrap_or_default();
305            let mut expanded_row_elements = 0;
306            for expr in tuple_elems {
307                let expanded_width = process_extract(
308                    expr,
309                    variable_mapping,
310                    local_constraints,
311                    wildcard_indices,
312                    reader_pipeline,
313                    new_row_width,
314                    Some(base_offset + expanded_row_elements),
315                    rule_span,
316                );
317
318                // as we process each child of the tuple, the prefix of the
319                // tuple emitted by the pipeline will grow, so we need to update
320                // our cursor and the current overall width appropriately
321                expanded_row_elements += expanded_width;
322                new_row_width = new_row_width - 1 + expanded_width;
323            }
324
325            expanded_row_elements
326        }
327    }
328}
329
330/// Generates a DFIR pipeline that computes the output to a given [`JoinPlan`].
331pub fn expand_join_plan(
332    // The plan we are converting to a DFIR pipeline.
333    plan: &JoinPlan,
334    // The DFIR graph to emit the pipeline to.
335    flat_graph_builder: &mut FlatGraphBuilder,
336    tee_counter: &mut HashMap<String, Counter>,
337    next_join_idx: &mut Counter,
338    rule_span: (usize, usize),
339    diagnostics: &mut Vec<Diagnostic>,
340    get_span: &impl Fn((usize, usize)) -> Span,
341) -> IntermediateJoinNode {
342    match plan {
343        JoinPlan::Source(target, persisted) => {
344            // Because this is a node corresponding to some Datalog relation, we need to tee from it.
345            let tee_index = tee_counter
346                .entry(target.name.name.clone())
347                .or_insert_with(|| 0..)
348                .next()
349                .expect("Out of tee indices");
350
351            let relation_node = syn::Ident::new(&target.name.name, get_span(target.name.span));
352            let relation_idx = syn::LitInt::new(&tee_index.to_string(), Span::call_site());
353
354            let source_node = syn::Ident::new(
355                &format!(
356                    "source_reader_{}",
357                    next_join_idx.next().expect("Out of join indices")
358                ),
359                Span::call_site(),
360            );
361
362            let mut variable_mapping = BTreeMap::new();
363            let mut local_constraints = BTreeMap::new();
364            let mut wildcard_indices = vec![];
365
366            let mut pipeline: Pipeline = parse_quote_spanned! {get_span(rule_span)=>
367                #relation_node [#relation_idx]
368            };
369
370            let final_row_width = process_extract(
371                &ExtractExpr::Untuple((), target.fields.clone(), ()),
372                &mut variable_mapping,
373                &mut local_constraints,
374                &mut wildcard_indices,
375                &mut pipeline,
376                target.fields.len(),
377                None,
378                get_span(rule_span),
379            );
380
381            let mut row_types: Vec<syn::Type> = vec![];
382            for _ in 0..final_row_width {
383                row_types.push(parse_quote!(_));
384            }
385
386            let row_type = parse_quote!((#(#row_types, )*));
387
388            if local_constraints.values().any(|v| v.len() > 1) {
389                let conditions = build_local_constraint_conditions(&local_constraints);
390
391                pipeline = parse_quote_spanned! {get_span(rule_span)=>
392                    #pipeline -> filter(|row: &#row_type| #conditions)
393                };
394            }
395
396            flat_graph_builder.add_statement(parse_quote_spanned! {get_span(rule_span)=>
397                #source_node = #pipeline;
398            });
399
400            IntermediateJoinNode {
401                name: source_node,
402                persisted: *persisted,
403                tee_idx: None,
404                variable_mapping,
405                wildcard_indices,
406                tuple_type: row_type,
407                span: get_span(target.span),
408            }
409        }
410        JoinPlan::Join(lhs, rhs) | JoinPlan::AntiJoin(lhs, rhs) => {
411            let is_anti = matches!(plan, JoinPlan::AntiJoin(_, _));
412
413            let left_expanded = expand_join_plan(
414                lhs,
415                flat_graph_builder,
416                tee_counter,
417                next_join_idx,
418                rule_span,
419                diagnostics,
420                get_span,
421            );
422            let right_expanded = expand_join_plan(
423                rhs,
424                flat_graph_builder,
425                tee_counter,
426                next_join_idx,
427                rule_span,
428                diagnostics,
429                get_span,
430            );
431
432            let identifiers_to_join = right_expanded
433                .variable_mapping
434                .keys()
435                .filter(|i| left_expanded.variable_mapping.contains_key(*i))
436                .enumerate()
437                .map(|t| (t.1, t.0))
438                .collect::<BTreeMap<_, _>>();
439
440            let left_non_joined_identifiers = left_expanded
441                .variable_mapping
442                .keys()
443                .filter(|i| !right_expanded.variable_mapping.contains_key(*i))
444                .enumerate()
445                .map(|t| (t.1, t.0))
446                .collect::<BTreeMap<_, _>>();
447
448            let right_non_joined_identifiers = right_expanded
449                .variable_mapping
450                .keys()
451                .filter(|i| !left_expanded.variable_mapping.contains_key(*i))
452                .enumerate()
453                .map(|t| (t.1, t.0))
454                .collect::<BTreeMap<_, _>>();
455
456            let key_type =
457                repeat_tuple::<syn::Type, syn::Type>(|| parse_quote!(_), identifiers_to_join.len());
458
459            let left_type = repeat_tuple::<syn::Type, syn::Type>(
460                || parse_quote!(_),
461                left_non_joined_identifiers.len() + left_expanded.wildcard_indices.len(),
462            );
463            let right_type = repeat_tuple::<syn::Type, syn::Type>(
464                || parse_quote!(_),
465                right_non_joined_identifiers.len() + right_expanded.wildcard_indices.len(),
466            );
467
468            let join_node = syn::Ident::new(
469                &format!(
470                    "join_{}",
471                    next_join_idx.next().expect("Out of join indices")
472                ),
473                Span::call_site(),
474            );
475
476            // We start by defining the pipeline from the `join()` operator onwards. The main logic
477            // here is to flatten the tuples from the left and right sides of the join into a
478            // single tuple that is used by downstream joins or the final output.
479            let mut flattened_tuple_elems: Vec<syn::Expr> = vec![];
480            let mut flattened_mapping = BTreeMap::new();
481            let mut flattened_wildcard_indices = vec![];
482
483            for (ident, idx) in &identifiers_to_join {
484                if !flattened_mapping.contains_key(*ident) {
485                    let idx = syn::Index::from(*idx);
486                    let value_expr: syn::Expr = parse_quote!(kv.0.#idx);
487
488                    flattened_mapping.insert((*ident).clone(), flattened_tuple_elems.len());
489                    flattened_tuple_elems.push(value_expr);
490                }
491            }
492
493            if is_anti {
494                for (ident, idx) in &left_non_joined_identifiers {
495                    if !flattened_mapping.contains_key(*ident) {
496                        let idx = syn::Index::from(*idx);
497                        let value_expr: syn::Expr = parse_quote!(kv.1.#idx);
498
499                        flattened_mapping.insert((*ident).clone(), flattened_tuple_elems.len());
500                        flattened_tuple_elems.push(value_expr);
501                    }
502                }
503
504                for idx in 0..left_expanded.wildcard_indices.len() {
505                    let idx = syn::Index::from(left_non_joined_identifiers.len() + idx);
506                    let value_expr: syn::Expr = parse_quote!(kv.1.#idx);
507
508                    flattened_wildcard_indices.push(flattened_tuple_elems.len());
509                    flattened_tuple_elems.push(value_expr);
510                }
511            } else {
512                for (ident, source_idx) in left_non_joined_identifiers
513                    .keys()
514                    .map(|l| (l, 0))
515                    .chain(right_non_joined_identifiers.keys().map(|l| (l, 1)))
516                {
517                    if !flattened_mapping.contains_key(*ident) {
518                        let syn_source_index = syn::Index::from(source_idx);
519                        let source_expr: syn::Expr = parse_quote!(kv.1.#syn_source_index);
520                        let bindings = if source_idx == 0 {
521                            &left_non_joined_identifiers
522                        } else {
523                            &right_non_joined_identifiers
524                        };
525
526                        let source_col_idx = syn::Index::from(*bindings.get(ident).unwrap());
527
528                        flattened_mapping.insert((*ident).clone(), flattened_tuple_elems.len());
529                        flattened_tuple_elems.push(parse_quote!(#source_expr.#source_col_idx));
530                    }
531                }
532
533                for (idx, source_idx) in (0..left_expanded.wildcard_indices.len())
534                    .map(|i| (left_non_joined_identifiers.len() + i, 0))
535                    .chain(
536                        (0..right_expanded.wildcard_indices.len())
537                            .map(|i| (right_non_joined_identifiers.len() + i, 1)),
538                    )
539                {
540                    let syn_source_index = syn::Index::from(source_idx);
541                    let idx = syn::Index::from(idx);
542                    let value_expr: syn::Expr = parse_quote!(kv.1.#syn_source_index.#idx);
543
544                    flattened_wildcard_indices.push(flattened_tuple_elems.len());
545                    flattened_tuple_elems.push(value_expr);
546                }
547            }
548
549            let flatten_closure: syn::Expr = if is_anti {
550                parse_quote!(|kv: (#key_type, #left_type)| (#(#flattened_tuple_elems, )*))
551            } else {
552                parse_quote!(|kv: (#key_type, (#left_type, #right_type))| (#(#flattened_tuple_elems, )*))
553            };
554
555            let (lt_left, lt_right, is_persist): (syn::Lifetime, syn::Lifetime, bool) =
556                match (left_expanded.persisted, right_expanded.persisted, is_anti) {
557                    (true, false, false) => (parse_quote!('static), parse_quote!('tick), false),
558                    (false, true, false) => (parse_quote!('tick), parse_quote!('static), false),
559                    (true, true, false) => (parse_quote!('static), parse_quote!('static), true),
560                    _ => (parse_quote!('tick), parse_quote!('tick), false),
561                };
562
563            if is_anti {
564                // this is always a 'tick join, so we place a persist operator in the join input pipeline
565                flat_graph_builder.add_statement(parse_quote_spanned! {get_span(rule_span)=>
566                    #join_node = anti_join() -> map(#flatten_closure);
567                });
568            } else {
569                flat_graph_builder.add_statement(
570                    parse_quote_spanned! {get_span(rule_span)=>
571                        #join_node = join::<#lt_left, #lt_right, dfir_rs::compiled::pull::HalfMultisetJoinState>() -> map(#flatten_closure);
572                    }
573                );
574            }
575
576            let output_type = repeat_tuple::<syn::Type, syn::Type>(
577                || parse_quote!(_),
578                flattened_tuple_elems.len(),
579            );
580
581            let intermediate = IntermediateJoinNode {
582                name: join_node.clone(),
583                persisted: is_persist,
584                tee_idx: None,
585                variable_mapping: flattened_mapping,
586                wildcard_indices: flattened_wildcard_indices,
587                tuple_type: output_type,
588                span: left_expanded
589                    .span
590                    .join(right_expanded.span)
591                    .unwrap_or(get_span(rule_span)),
592            };
593
594            emit_join_input_pipeline(
595                &identifiers_to_join
596                    .keys()
597                    .cloned()
598                    .cloned()
599                    .collect::<Vec<_>>(),
600                &left_non_joined_identifiers
601                    .keys()
602                    .cloned()
603                    .cloned()
604                    .collect::<Vec<_>>(),
605                &left_expanded,
606                &join_node,
607                JoinSide::Left,
608                is_anti,
609                flat_graph_builder,
610            );
611
612            emit_join_input_pipeline(
613                &identifiers_to_join
614                    .keys()
615                    .cloned()
616                    .cloned()
617                    .collect::<Vec<_>>(),
618                &right_non_joined_identifiers
619                    .keys()
620                    .cloned()
621                    .cloned()
622                    .collect::<Vec<_>>(),
623                &right_expanded,
624                &join_node,
625                JoinSide::Right,
626                is_anti,
627                flat_graph_builder,
628            );
629
630            intermediate
631        }
632        JoinPlan::Predicate(predicates, inner) => {
633            let inner_expanded = expand_join_plan(
634                inner,
635                flat_graph_builder,
636                tee_counter,
637                next_join_idx,
638                rule_span,
639                diagnostics,
640                get_span,
641            );
642            let inner_name = inner_expanded.name.clone();
643            let row_type = inner_expanded.tuple_type;
644            let variable_mapping = &inner_expanded.variable_mapping;
645
646            let conditions = predicates
647                .iter()
648                .map(|p| {
649                    let l =
650                        gen_predicate_value_expr(&p.left, variable_mapping, diagnostics, get_span);
651                    let r =
652                        gen_predicate_value_expr(&p.right, variable_mapping, diagnostics, get_span);
653
654                    match &p.op {
655                        BoolOp::Lt(_) => parse_quote_spanned!(get_span(p.span)=> #l < #r),
656                        BoolOp::LtEq(_) => parse_quote_spanned!(get_span(p.span)=> #l <= #r),
657                        BoolOp::Gt(_) => parse_quote_spanned!(get_span(p.span)=> #l > #r),
658                        BoolOp::GtEq(_) => parse_quote_spanned!(get_span(p.span)=> #l >= #r),
659                        BoolOp::Eq(_) => parse_quote_spanned!(get_span(p.span)=> #l == #r),
660                        BoolOp::Neq(_) => parse_quote_spanned!(get_span(p.span)=> #l != #r),
661                    }
662                })
663                .reduce(|a: syn::Expr, b| parse_quote!(#a && #b))
664                .unwrap();
665
666            let predicate_filter_node = syn::Ident::new(
667                &format!(
668                    "predicate_{}_filter",
669                    next_join_idx.next().expect("Out of join indices")
670                ),
671                Span::call_site(),
672            );
673
674            flat_graph_builder.add_statement(parse_quote_spanned! { get_span(rule_span)=>
675                #predicate_filter_node = #inner_name -> filter(|row: &#row_type| #conditions );
676            });
677
678            IntermediateJoinNode {
679                name: predicate_filter_node,
680                persisted: inner_expanded.persisted,
681                tee_idx: None,
682                variable_mapping: inner_expanded.variable_mapping,
683                wildcard_indices: inner_expanded.wildcard_indices,
684                tuple_type: row_type,
685                span: get_span(rule_span),
686            }
687        }
688        JoinPlan::MagicNatLt(inner, less_than, threshold) => {
689            let magic_node = syn::Ident::new(
690                &format!(
691                    "magic_nat_lt_{}",
692                    next_join_idx.next().expect("Out of join indices")
693                ),
694                Span::call_site(),
695            );
696
697            let inner_expanded = expand_join_plan(
698                inner,
699                flat_graph_builder,
700                tee_counter,
701                next_join_idx,
702                rule_span,
703                diagnostics,
704                get_span,
705            );
706            let inner_name = inner_expanded.name.clone();
707            let row_type = inner_expanded.tuple_type;
708
709            match &less_than {
710                ExtractExpr::Ident(ident) => {
711                    if inner_expanded.variable_mapping.contains_key(&ident.name) {
712                        todo!(
713                            "The values generated by less_than cannot currently be used in other parts of the query"
714                        );
715                    }
716                }
717                ExtractExpr::Underscore(_) => {}
718                _ => panic!("The values generated by less_than must be a single variable"),
719            }
720
721            let threshold_name = if let ExtractExpr::Ident(threshold) = threshold {
722                threshold.name.clone()
723            } else {
724                panic!("The threshold must be a variable")
725            };
726
727            let threshold_index = inner_expanded
728                .variable_mapping
729                .get(&threshold_name)
730                .expect("Threshold variable not found in inner plan");
731            let threshold_index = syn::Index::from(*threshold_index);
732
733            let mut flattened_elements: Vec<syn::Expr> = vec![];
734            let mut flattened_mapping = BTreeMap::new();
735            let mut flattened_wildcard_indices = Vec::new();
736
737            for (ident, source_idx) in &inner_expanded.variable_mapping {
738                let syn_source_index = syn::Index::from(*source_idx);
739                flattened_mapping.insert(ident.clone(), flattened_elements.len());
740                flattened_elements.push(parse_quote!(row.#syn_source_index.clone()));
741            }
742
743            for wildcard_idx in &inner_expanded.wildcard_indices {
744                let syn_wildcard_idx = syn::Index::from(*wildcard_idx);
745                flattened_wildcard_indices.push(flattened_elements.len());
746                flattened_elements.push(parse_quote!(row.#syn_wildcard_idx.clone()));
747            }
748
749            if let ExtractExpr::Ident(less_than) = less_than {
750                if less_than.name == threshold_name {
751                    panic!("The threshold and less_than variables must be different")
752                }
753
754                flattened_mapping.insert(less_than.name.clone(), flattened_elements.len());
755            } else {
756                flattened_wildcard_indices.push(flattened_elements.len());
757            }
758
759            flattened_elements.push(parse_quote!(v));
760
761            flat_graph_builder.add_statement(parse_quote_spanned! {get_span(rule_span)=>
762                #magic_node = #inner_name -> flat_map(|row: #row_type| (0..(row.#threshold_index)).map(move |v| (#(#flattened_elements, )*)) );
763            });
764
765            IntermediateJoinNode {
766                name: magic_node,
767                persisted: inner_expanded.persisted,
768                tee_idx: None,
769                variable_mapping: flattened_mapping,
770                wildcard_indices: flattened_wildcard_indices,
771                tuple_type: repeat_tuple::<syn::Type, syn::Type>(
772                    || parse_quote!(_),
773                    flattened_elements.len(),
774                ),
775                span: get_span(rule_span),
776            }
777        }
778    }
779}