dfir_datalog_core/
lib.rs

1use std::collections::{HashMap, HashSet};
2use std::ops::Deref;
3
4pub use dfir_lang::diagnostic;
5use dfir_lang::diagnostic::{Diagnostic, Level};
6use dfir_lang::graph::{DfirGraph, FlatGraphBuilder, eliminate_extra_unions_tees, partition_graph};
7use dfir_lang::parse::{
8    DfirStatement, IndexInt, Indexing, Pipeline, PipelineLink, PipelineStatement, PortIndex,
9};
10use proc_macro2::{Span, TokenStream};
11use rust_sitter::errors::{ParseError, ParseErrorReason};
12use syn::{Token, parse_quote, parse_quote_spanned};
13
14mod grammar;
15mod join_plan;
16mod util;
17
18use grammar::datalog::{
19    Aggregation, Atom, Declaration, Ident, IntExpr, Program, Rule, RuleType, TargetExpr,
20};
21use join_plan::{IntermediateJoinNode, JoinPlan};
22use util::{Counter, repeat_tuple};
23
/// Relation names with built-in semantics that are not backed by declared data.
///
/// Positive atoms over these names are excluded from the ordinary join and lowered to dedicated
/// join-plan nodes instead (`less_than` becomes `JoinPlan::MagicNatLt`). Input, output, async and
/// static declarations, as well as rule heads, may not use these names.
static MAGIC_RELATIONS: [&str; 1] = ["less_than"];
25
26pub fn parse_pipeline(
27    code_str: &rust_sitter::Spanned<String>,
28    get_span: &impl Fn((usize, usize)) -> Span,
29) -> Result<Pipeline, Vec<Diagnostic>> {
30    syn::LitStr::new(&code_str.value, get_span(code_str.span))
31        .parse()
32        .map_err(|err| {
33            vec![Diagnostic {
34                span: err.span(),
35                level: Level::Error,
36                message: format!("Failed to parse input pipeline: {}", err),
37            }]
38        })
39}
40
41pub fn parse_static(
42    code_str: &rust_sitter::Spanned<String>,
43    get_span: &impl Fn((usize, usize)) -> Span,
44) -> Result<syn::Expr, Vec<Diagnostic>> {
45    syn::LitStr::new(&code_str.value, get_span(code_str.span))
46        .parse()
47        .map_err(|err| {
48            vec![Diagnostic {
49                span: err.span(),
50                level: Level::Error,
51                message: format!("Failed to parse static expression: {}", err),
52            }]
53        })
54}
55
56pub fn gen_dfir_graph(literal: proc_macro2::Literal) -> Result<DfirGraph, Vec<Diagnostic>> {
57    let offset = {
58        // This includes the quotes, i.e. 'r#"my test"#' or '"hello\nworld"'.
59        let source_str = literal.to_string();
60        let mut source_chars = source_str.chars();
61        if Some('r') != source_chars.next() {
62            return Err(vec![Diagnostic {
63                span: literal.span(),
64                level: Level::Error,
65                message:
66                    r##"Input must be a raw string `r#"..."#` for correct diagnostic messages."##
67                        .to_owned(),
68            }]);
69        }
70        let hashes = source_chars.take_while(|&c| '#' == c).count();
71        2 + hashes
72    };
73
74    let get_span = |(start, end): (usize, usize)| {
75        let subspan = literal.subspan(start + offset..end + offset);
76        subspan.unwrap_or(Span::call_site())
77    };
78
79    let str_node: syn::LitStr = parse_quote!(#literal);
80    let actual_str = str_node.value();
81    let program: Program =
82        grammar::datalog::parse(&actual_str).map_err(|e| handle_errors(e, &get_span))?;
83
84    let mut inputs = Vec::new();
85    let mut outputs = Vec::new();
86    let mut persists = HashSet::new();
87    let mut asyncs = Vec::new();
88    let mut rules = Vec::new();
89    let mut statics = Vec::new();
90
91    for stmt in &program.rules {
92        match stmt {
93            Declaration::Input(_, ident, hf_code) => {
94                assert!(!MAGIC_RELATIONS.contains(&ident.name.as_str()));
95                inputs.push((ident, hf_code))
96            }
97            Declaration::Output(_, ident, hf_code) => {
98                assert!(!MAGIC_RELATIONS.contains(&ident.name.as_str()));
99                outputs.push((ident, hf_code))
100            }
101            Declaration::Persist(_, ident) => {
102                persists.insert(ident.name.clone());
103            }
104            Declaration::Async(_, ident, send_hf, recv_hf) => {
105                assert!(!MAGIC_RELATIONS.contains(&ident.name.as_str()));
106                asyncs.push((ident, send_hf, recv_hf))
107            }
108            Declaration::Rule(rule) => {
109                assert!(!MAGIC_RELATIONS.contains(&rule.target.name.name.as_str()));
110                rules.push(rule)
111            }
112            Declaration::Static(_, ident, hf_code) => {
113                assert!(!MAGIC_RELATIONS.contains(&ident.name.as_str()));
114                statics.push((ident, hf_code));
115            }
116        }
117    }
118
119    let mut flat_graph_builder = FlatGraphBuilder::new();
120    let mut tee_counter = HashMap::new();
121    let mut union_counter = HashMap::new();
122
123    let mut created_rules = HashSet::new();
124    for decl in &program.rules {
125        let target_ident = match decl {
126            Declaration::Input(_, ident, _) => ident.clone(),
127            Declaration::Output(_, ident, _) => ident.clone(),
128            Declaration::Persist(_, ident) => ident.clone(),
129            Declaration::Async(_, ident, _, _) => ident.clone(),
130            Declaration::Rule(rule) => rule.target.name.clone(),
131            Declaration::Static(_, ident, _) => ident.clone(),
132        };
133
134        if !created_rules.contains(&target_ident.value) {
135            created_rules.insert(target_ident.value.clone());
136            let insert_name = syn::Ident::new(
137                &format!("{}_insert", target_ident.name),
138                get_span(target_ident.span),
139            );
140            let read_name = syn::Ident::new(&target_ident.name, get_span(target_ident.span));
141
142            if persists.contains(&target_ident.value.name) {
143                // read outputs the *new* values for this tick
144                flat_graph_builder
145                    .add_statement(parse_quote_spanned!{get_span(target_ident.span)=> #insert_name = union() -> unique::<'tick>(); });
146                flat_graph_builder
147                    .add_statement(parse_quote_spanned!{get_span(target_ident.span)=> #read_name = difference::<'tick, 'static>() -> tee(); });
148                flat_graph_builder
149                    .add_statement(parse_quote_spanned!{get_span(target_ident.span)=> #insert_name -> [pos] #read_name; });
150                flat_graph_builder
151                    .add_statement(parse_quote_spanned!{get_span(target_ident.span)=> #read_name -> defer_tick() -> [neg] #read_name; });
152            } else {
153                flat_graph_builder
154                    .add_statement(parse_quote_spanned!{get_span(target_ident.span)=> #insert_name = union() -> unique::<'tick>(); });
155                flat_graph_builder
156                    .add_statement(parse_quote_spanned!{get_span(target_ident.span)=> #read_name = #insert_name -> tee(); });
157            }
158        }
159    }
160
161    for (target, hf_code) in inputs {
162        let my_union_index = union_counter
163            .entry(target.name.clone())
164            .or_insert_with(|| 0..)
165            .next()
166            .expect("Out of union indices");
167
168        let my_union_index_lit =
169            syn::LitInt::new(&format!("{}", my_union_index), get_span(target.span));
170        let name = syn::Ident::new(&format!("{}_insert", target.name), get_span(target.span));
171
172        let input_pipeline: Pipeline = parse_pipeline(&hf_code.code, &get_span)?;
173
174        flat_graph_builder.add_statement(parse_quote_spanned! {get_span(target.span)=>
175            #input_pipeline -> [#my_union_index_lit] #name;
176        });
177    }
178
179    for (target, hf_code) in outputs {
180        let my_tee_index = tee_counter
181            .entry(target.name.clone())
182            .or_insert_with(|| 0..)
183            .next()
184            .expect("Out of tee indices");
185
186        let my_tee_index_lit =
187            syn::LitInt::new(&format!("{}", my_tee_index), get_span(target.span));
188        let target_ident = syn::Ident::new(&target.name, get_span(target.span));
189
190        let output_pipeline: Pipeline = parse_pipeline(&hf_code.code, &get_span)?;
191        let output_pipeline = if persists.contains(&target.name) {
192            parse_quote_spanned! {get_span(target.span)=> persist::<'static>() -> #output_pipeline}
193        } else {
194            output_pipeline
195        };
196
197        flat_graph_builder.add_statement(parse_quote_spanned! {get_span(target.span)=>
198            #target_ident [#my_tee_index_lit] -> #output_pipeline;
199        });
200    }
201
202    for (target, send_hf, recv_hf) in asyncs {
203        let async_send_pipeline = format!("{}_async_send", target.name);
204        let async_send_pipeline = syn::Ident::new(&async_send_pipeline, get_span(target.span));
205
206        let recv_union_index = union_counter
207            .entry(target.name.clone())
208            .or_insert_with(|| 0..)
209            .next()
210            .expect("Out of union indices");
211
212        let recv_union_index_lit =
213            syn::LitInt::new(&format!("{}", recv_union_index), get_span(target.span));
214        let target_ident =
215            syn::Ident::new(&format!("{}_insert", target.name), get_span(target.span));
216
217        let send_pipeline: Pipeline = parse_pipeline(&send_hf.code, &get_span)?;
218        let recv_pipeline: Pipeline = parse_pipeline(&recv_hf.code, &get_span)?;
219
220        flat_graph_builder.add_statement(parse_quote_spanned! {get_span(target.span)=>
221            #async_send_pipeline = union() -> unique::<'tick>() -> #send_pipeline;
222        });
223
224        flat_graph_builder.add_statement(parse_quote_spanned! {get_span(target.span)=>
225            #recv_pipeline -> [#recv_union_index_lit] #target_ident;
226        });
227    }
228
229    for (target, hf_code) in statics {
230        let my_union_index = union_counter
231            .entry(target.name.clone())
232            .or_insert_with(|| 0..)
233            .next()
234            .expect("Out of union indices");
235
236        let my_union_index_lit =
237            syn::LitInt::new(&format!("{}", my_union_index), get_span(target.span));
238        let name = syn::Ident::new(&format!("{}_insert", target.name), get_span(target.span));
239
240        let static_expression: syn::Expr = parse_static(&hf_code.code, &get_span)?;
241
242        flat_graph_builder.add_statement(parse_quote_spanned! {get_span(target.span)=>
243            source_iter(#static_expression) -> persist::<'static>() -> [#my_union_index_lit] #name;
244        });
245    }
246
247    let mut next_join_idx = 0..;
248    let mut diagnostics = Vec::new();
249    for rule in rules {
250        let plan = compute_join_plan(&rule.sources, &persists);
251        generate_rule(
252            plan,
253            rule,
254            &mut flat_graph_builder,
255            &mut tee_counter,
256            &mut union_counter,
257            &mut next_join_idx,
258            &persists,
259            &mut diagnostics,
260            &get_span,
261        );
262    }
263
264    if !diagnostics.is_empty() {
265        return Err(diagnostics);
266    }
267
268    let (mut flat_graph, _uses, mut diagnostics) = flat_graph_builder.build();
269
270    diagnostics.retain(Diagnostic::is_error);
271    if !diagnostics.is_empty() {
272        return Err(diagnostics);
273    }
274
275    if let Err(err) = flat_graph.merge_modules() {
276        diagnostics.push(err);
277        return Err(diagnostics);
278    }
279
280    eliminate_extra_unions_tees(&mut flat_graph);
281    Ok(flat_graph)
282}
283
284fn handle_errors(
285    errors: Vec<ParseError>,
286    get_span: &impl Fn((usize, usize)) -> Span,
287) -> Vec<Diagnostic> {
288    let mut diagnostics = vec![];
289    for error in errors {
290        let reason = error.reason;
291        let my_span = get_span((error.start, error.end));
292        match reason {
293            ParseErrorReason::UnexpectedToken(msg) => {
294                diagnostics.push(Diagnostic::spanned(
295                    my_span,
296                    Level::Error,
297                    format!("Unexpected Token: '{msg}'", msg = msg),
298                ));
299            }
300            ParseErrorReason::MissingToken(msg) => {
301                diagnostics.push(Diagnostic::spanned(
302                    my_span,
303                    Level::Error,
304                    format!("Missing Token: '{msg}'", msg = msg),
305                ));
306            }
307            ParseErrorReason::FailedNode(parse_errors) => {
308                if parse_errors.is_empty() {
309                    diagnostics.push(Diagnostic::spanned(
310                        my_span,
311                        Level::Error,
312                        "Failed to parse",
313                    ));
314                } else {
315                    diagnostics.extend(handle_errors(parse_errors, get_span));
316                }
317            }
318        }
319    }
320
321    diagnostics
322}
323
324pub fn dfir_graph_to_program(flat_graph: DfirGraph, root: TokenStream) -> TokenStream {
325    let partitioned_graph =
326        partition_graph(flat_graph).expect("Failed to partition (cycle detected).");
327
328    let mut diagnostics = Vec::new();
329    let code_tokens = partitioned_graph.as_code(&root, true, quote::quote!(), &mut diagnostics);
330    assert_eq!(
331        0,
332        diagnostics.len(),
333        "Operator diagnostic occured during codegen"
334    );
335
336    code_tokens
337}
338
339#[expect(clippy::too_many_arguments, reason = "internal code")]
340fn generate_rule(
341    plan: JoinPlan<'_>,
342    rule: &rust_sitter::Spanned<Rule>,
343    flat_graph_builder: &mut FlatGraphBuilder,
344    tee_counter: &mut HashMap<String, Counter>,
345    union_counter: &mut HashMap<String, Counter>,
346    next_join_idx: &mut Counter,
347    persists: &HashSet<String>,
348    diagnostics: &mut Vec<Diagnostic>,
349    get_span: &impl Fn((usize, usize)) -> Span,
350) {
351    let target = &rule.target.name;
352    let target_ident = syn::Ident::new(&format!("{}_insert", target.name), get_span(target.span));
353
354    let out_expanded = join_plan::expand_join_plan(
355        &plan,
356        flat_graph_builder,
357        tee_counter,
358        next_join_idx,
359        rule.span,
360        diagnostics,
361        get_span,
362    );
363
364    let after_join = apply_aggregations(
365        rule,
366        &out_expanded,
367        persists.contains(&target.name),
368        diagnostics,
369        get_span,
370    );
371
372    let my_union_index = union_counter
373        .entry(target.name.clone())
374        .or_insert_with(|| 0..)
375        .next()
376        .expect("Out of union indices");
377
378    let my_union_index_lit = syn::LitInt::new(&format!("{}", my_union_index), Span::call_site());
379
380    let after_join_and_send: Pipeline = match rule.rule_type.value {
381        RuleType::Sync(_) => {
382            if rule.target.at_node.is_some() {
383                panic!("Rule must be async to send data to other nodes")
384            }
385
386            parse_quote_spanned!(get_span(rule.rule_type.span)=> #after_join -> [#my_union_index_lit] #target_ident)
387        }
388        RuleType::NextTick(_) => {
389            if rule.target.at_node.is_some() {
390                panic!("Rule must be async to send data to other nodes")
391            }
392
393            parse_quote_spanned!(get_span(rule.rule_type.span)=> #after_join -> defer_tick() -> [#my_union_index_lit] #target_ident)
394        }
395        RuleType::Async(_) => {
396            if rule.target.at_node.is_none() {
397                panic!("Async rules are only for sending data to other nodes")
398            }
399
400            let exprs_get_data = rule
401                .target
402                .fields
403                .iter()
404                .enumerate()
405                .map(|(i, f)| -> syn::Expr {
406                    let syn_index = syn::Index::from(i);
407                    parse_quote_spanned!(get_span(f.span)=> v.#syn_index)
408                });
409
410            let syn_target_index = syn::Index::from(rule.target.fields.len());
411
412            let v_type = repeat_tuple::<syn::Type, syn::Type>(
413                || parse_quote!(_),
414                rule.target.fields.len() + 1,
415            );
416
417            let send_pipeline_ident = syn::Ident::new(
418                &format!("{}_async_send", &rule.target.name.name),
419                get_span(rule.target.name.span),
420            );
421
422            parse_quote_spanned!(get_span(rule.rule_type.span)=> #after_join -> map(|v: #v_type| (v.#syn_target_index, (#(#exprs_get_data, )*))) -> #send_pipeline_ident)
423        }
424    };
425
426    let out_name = out_expanded.name;
427    // If the output comes with a tee index, we must read with that. This only happens when we are
428    // directly outputting a transformation of a single relation on the RHS.
429    let out_indexing = out_expanded.tee_idx.map(|i| Indexing {
430        bracket_token: syn::token::Bracket::default(),
431        index: PortIndex::Int(IndexInt {
432            value: i,
433            span: Span::call_site(),
434        }),
435    });
436    flat_graph_builder.add_statement(DfirStatement::Pipeline(PipelineStatement {
437        pipeline: Pipeline::Link(PipelineLink {
438            lhs: Box::new(parse_quote!(#out_name #out_indexing)), // out_name[idx]
439            arrow: parse_quote!(->),
440            rhs: Box::new(after_join_and_send),
441        }),
442        semi_token: Token![;](Span::call_site()),
443    }));
444}
445
446fn compute_join_plan<'a>(sources: &'a [Atom], persisted_rules: &HashSet<String>) -> JoinPlan<'a> {
447    // TODO(shadaj): smarter plans
448    let mut plan: JoinPlan = sources
449        .iter()
450        .filter_map(|x| match x {
451            Atom::PosRelation(e) => {
452                if !MAGIC_RELATIONS.contains(&e.name.name.as_str()) {
453                    Some(JoinPlan::Source(e, persisted_rules.contains(&e.name.name)))
454                } else {
455                    None
456                }
457            }
458            _ => None,
459        })
460        .reduce(|a, b| JoinPlan::Join(Box::new(a), Box::new(b)))
461        .unwrap();
462
463    plan = sources
464        .iter()
465        .filter_map(|x| match x {
466            Atom::NegRelation(_, e) => {
467                Some(JoinPlan::Source(e, persisted_rules.contains(&e.name.name)))
468            }
469            _ => None,
470        })
471        .fold(plan, |a, b| JoinPlan::AntiJoin(Box::new(a), Box::new(b)));
472
473    let predicates = sources
474        .iter()
475        .filter_map(|x| match x {
476            Atom::Predicate(e) => Some(e),
477            _ => None,
478        })
479        .collect::<Vec<_>>();
480
481    if !predicates.is_empty() {
482        plan = JoinPlan::Predicate(predicates, Box::new(plan))
483    }
484
485    plan = sources.iter().fold(plan, |acc, atom| match atom {
486        Atom::PosRelation(e) => {
487            if MAGIC_RELATIONS.contains(&e.name.name.as_str()) {
488                match e.name.name.as_str() {
489                    "less_than" => JoinPlan::MagicNatLt(
490                        Box::new(acc),
491                        e.fields[0].value.clone(),
492                        e.fields[1].value.clone(),
493                    ),
494                    o => panic!("Unknown magic relation {}", o),
495                }
496            } else {
497                acc
498            }
499        }
500        _ => acc,
501    });
502
503    plan
504}
505
506pub(crate) fn gen_value_expr(
507    expr: &IntExpr,
508    lookup_ident: &mut impl FnMut(&rust_sitter::Spanned<Ident>) -> syn::Expr,
509    get_span: &dyn Fn((usize, usize)) -> Span,
510) -> syn::Expr {
511    match expr {
512        IntExpr::Ident(ident) => lookup_ident(ident),
513        IntExpr::Integer(i) => syn::Expr::Lit(syn::ExprLit {
514            attrs: Vec::new(),
515            lit: syn::Lit::Int(syn::LitInt::new(&i.to_string(), get_span(i.span))),
516        }),
517        IntExpr::Parenthesized(_, e, _) => {
518            let inner = gen_value_expr(e, lookup_ident, get_span);
519            parse_quote!((#inner))
520        }
521        IntExpr::Add(l, _, r) => {
522            let l = gen_value_expr(l, lookup_ident, get_span);
523            let r = gen_value_expr(r, lookup_ident, get_span);
524            parse_quote!(#l + #r)
525        }
526        IntExpr::Sub(l, _, r) => {
527            let l = gen_value_expr(l, lookup_ident, get_span);
528            let r = gen_value_expr(r, lookup_ident, get_span);
529            parse_quote!(#l - #r)
530        }
531        IntExpr::Mul(l, _, r) => {
532            let l = gen_value_expr(l, lookup_ident, get_span);
533            let r = gen_value_expr(r, lookup_ident, get_span);
534            parse_quote!(#l * #r)
535        }
536        IntExpr::Mod(l, _, r) => {
537            let l = gen_value_expr(l, lookup_ident, get_span);
538            let r = gen_value_expr(r, lookup_ident, get_span);
539            parse_quote!(#l % #r)
540        }
541    }
542}
543
544fn gen_target_expr(
545    expr: &TargetExpr,
546    lookup_ident: &mut impl FnMut(&rust_sitter::Spanned<Ident>) -> syn::Expr,
547    get_span: &dyn Fn((usize, usize)) -> Span,
548) -> syn::Expr {
549    match expr {
550        TargetExpr::Expr(expr) => gen_value_expr(expr, lookup_ident, get_span),
551        TargetExpr::Aggregation(Aggregation::Count(_)) => parse_quote!(()),
552        TargetExpr::Aggregation(Aggregation::CountUnique(_, _, keys, _))
553        | TargetExpr::Aggregation(Aggregation::CollectVec(_, _, keys, _)) => {
554            let keys = keys
555                .iter()
556                .map(|k| gen_value_expr(&IntExpr::Ident(k.clone()), lookup_ident, get_span))
557                .collect::<Vec<_>>();
558            parse_quote!((#(#keys),*))
559        }
560        TargetExpr::Aggregation(
561            Aggregation::Min(_, _, a, _)
562            | Aggregation::Max(_, _, a, _)
563            | Aggregation::Sum(_, _, a, _)
564            | Aggregation::Choose(_, _, a, _),
565        ) => gen_value_expr(&IntExpr::Ident(a.clone()), lookup_ident, get_span),
566        TargetExpr::Index(_, _, _) => unreachable!(),
567    }
568}
569
570fn apply_aggregations(
571    rule: &Rule,
572    out_expanded: &IntermediateJoinNode,
573    consumer_is_persist: bool,
574    diagnostics: &mut Vec<Diagnostic>,
575    get_span: &impl Fn((usize, usize)) -> Span,
576) -> Pipeline {
577    let mut field_use_count = HashMap::new();
578    for field in rule
579        .target
580        .fields
581        .iter()
582        .chain(rule.target.at_node.iter().map(|n| &n.node))
583    {
584        for ident in field.idents() {
585            field_use_count
586                .entry(ident.name.clone())
587                .and_modify(|e| *e += 1)
588                .or_insert(1);
589        }
590    }
591
592    let mut aggregations = vec![];
593    let mut fold_keyed_exprs = vec![];
594    let mut agg_exprs = vec![];
595
596    let mut field_use_cur = HashMap::new();
597    let mut has_index = false;
598
599    let mut copy_group_key_lookups: Vec<syn::Expr> = vec![];
600    let mut after_group_lookups: Vec<syn::Expr> = vec![];
601    let mut group_key_idx = 0;
602    let mut agg_idx = 0;
603
604    for field in rule
605        .target
606        .fields
607        .iter()
608        .chain(rule.target.at_node.iter().map(|n| &n.node))
609    {
610        if matches!(field.deref(), TargetExpr::Index(_, _, _)) {
611            has_index = true;
612            after_group_lookups
613                .push(parse_quote_spanned!(get_span(field.span)=> __enumerate_index));
614        } else {
615            let expr: syn::Expr = gen_target_expr(
616                field,
617                &mut |ident| {
618                    if let Some(col) = out_expanded.variable_mapping.get(&ident.name) {
619                        let cur_count = field_use_cur
620                            .entry(ident.name.clone())
621                            .and_modify(|e| *e += 1)
622                            .or_insert(1);
623
624                        let source_col_idx = syn::Index::from(*col);
625                        let base = parse_quote_spanned!(get_span(ident.span)=> row.#source_col_idx);
626
627                        if *cur_count < field_use_count[&ident.name]
628                            && field_use_count[&ident.name] > 1
629                        {
630                            parse_quote!(#base.clone())
631                        } else {
632                            base
633                        }
634                    } else {
635                        diagnostics.push(Diagnostic::spanned(
636                            get_span(ident.span),
637                            Level::Error,
638                            format!("Could not find column {} in RHS of rule", &ident.name),
639                        ));
640                        parse_quote!(())
641                    }
642                },
643                get_span,
644            );
645
646            match &field.value {
647                TargetExpr::Expr(_) => {
648                    fold_keyed_exprs.push(expr);
649
650                    let idx = syn::Index::from(group_key_idx);
651                    after_group_lookups.push(parse_quote_spanned!(get_span(field.span)=> g.#idx));
652                    copy_group_key_lookups
653                        .push(parse_quote_spanned!(get_span(field.span)=> g.#idx));
654                    group_key_idx += 1;
655                }
656                TargetExpr::Aggregation(a) => {
657                    aggregations.push(a.clone());
658                    agg_exprs.push(expr);
659
660                    match a {
661                        Aggregation::CountUnique(..) => {
662                            let idx = syn::Index::from(agg_idx);
663                            after_group_lookups.push(
664                                parse_quote_spanned!(get_span(field.span)=> a.#idx.unwrap().1),
665                            );
666                            agg_idx += 1;
667                        }
668                        Aggregation::CollectVec(..) => {
669                            let idx = syn::Index::from(agg_idx);
670                            after_group_lookups
671                                .push(parse_quote_spanned!(get_span(field.span)=> a.#idx.unwrap().into_iter().collect::<Vec<_>>()));
672                            agg_idx += 1;
673                        }
674                        _ => {
675                            let idx = syn::Index::from(agg_idx);
676                            after_group_lookups
677                                .push(parse_quote_spanned!(get_span(field.span)=> a.#idx.unwrap()));
678                            agg_idx += 1;
679                        }
680                    }
681                }
682                TargetExpr::Index(_, _, _) => unreachable!(),
683            }
684        }
685    }
686
687    let flattened_tuple_type = &out_expanded.tuple_type;
688
689    let fold_keyed_input_type =
690        repeat_tuple::<syn::Type, syn::Type>(|| parse_quote!(_), fold_keyed_exprs.len());
691
692    let after_group_pipeline: Pipeline = if has_index {
693        if out_expanded.persisted && agg_exprs.is_empty() {
694            // if there is an aggregation, we will use a group which replays so we should use `'tick` instead
695            parse_quote!(enumerate::<'static>() -> map(|(__enumerate_index, (g, a)): (_, (#fold_keyed_input_type, _))| (#(#after_group_lookups, )*)))
696        } else {
697            parse_quote!(enumerate::<'tick>() -> map(|(__enumerate_index, (g, a)): (_, (#fold_keyed_input_type, _))| (#(#after_group_lookups, )*)))
698        }
699    } else {
700        parse_quote!(map(|(g, a): (#fold_keyed_input_type, _)| (#(#after_group_lookups, )*)))
701    };
702
703    let pre_fold_keyed_map: Pipeline = parse_quote!(map(|row: #flattened_tuple_type| ((#(#fold_keyed_exprs, )*), (#(#agg_exprs, )*))));
704
705    if agg_exprs.is_empty() {
706        if out_expanded.persisted && !consumer_is_persist {
707            parse_quote!(#pre_fold_keyed_map -> #after_group_pipeline -> persist::<'static>())
708        } else {
709            parse_quote!(#pre_fold_keyed_map -> #after_group_pipeline)
710        }
711    } else {
712        let agg_initial =
713            repeat_tuple::<syn::Expr, syn::Expr>(|| parse_quote!(None), agg_exprs.len());
714
715        let agg_input_type =
716            repeat_tuple::<syn::Type, syn::Type>(|| parse_quote!(_), agg_exprs.len());
717        let agg_type: syn::Type =
718            repeat_tuple::<syn::Type, syn::Type>(|| parse_quote!(Option<_>), agg_exprs.len());
719
720        let fold_keyed_stmts: Vec<syn::Stmt> = aggregations
721            .iter()
722            .enumerate()
723            .map(|(i, agg)| {
724                let idx = syn::Index::from(i);
725                let old_at_index: syn::Expr = parse_quote!(old.#idx);
726                let val_at_index: syn::Expr = parse_quote!(val.#idx);
727
728                let agg_expr: syn::Expr = match &agg {
729                    Aggregation::Min(..) => {
730                        parse_quote!(std::cmp::min(prev, #val_at_index))
731                    }
732                    Aggregation::Max(..) => {
733                        parse_quote!(std::cmp::max(prev, #val_at_index))
734                    }
735                    Aggregation::Sum(..) => {
736                        parse_quote!(prev + #val_at_index)
737                    }
738                    Aggregation::Count(..) => {
739                        parse_quote!(prev + 1)
740                    }
741                    Aggregation::CountUnique(..) => {
742                        parse_quote!({
743                            let prev: (dfir_rs::rustc_hash::FxHashSet<_>, _) = prev;
744                            let mut set: dfir_rs::rustc_hash::FxHashSet<_> = prev.0;
745                            if set.insert(#val_at_index) {
746                                (set, prev.1 + 1)
747                            } else {
748                                (set, prev.1)
749                            }
750                        })
751                    }
752                    Aggregation::CollectVec(..) => {
753                        parse_quote!({
754                            let mut set: dfir_rs::rustc_hash::FxHashSet<_> = prev;
755                            set.insert(#val_at_index);
756                            set
757                        })
758                    }
759                    Aggregation::Choose(..) => {
760                        parse_quote!(prev) // choose = select any 1 element from the relation. By default we select the 1st.
761                    }
762                };
763
764                let agg_initial: syn::Expr = match &agg {
765                    Aggregation::Min(..)
766                    | Aggregation::Max(..)
767                    | Aggregation::Sum(..)
768                    | Aggregation::Choose(..) => {
769                        parse_quote!(#val_at_index)
770                    }
771                    Aggregation::Count(..) => {
772                        parse_quote!(1)
773                    }
774                    Aggregation::CountUnique(..) => {
775                        parse_quote!({
776                            let mut set = dfir_rs::rustc_hash::FxHashSet::<_>::default();
777                            set.insert(#val_at_index);
778                            (set, 1)
779                        })
780                    }
781                    Aggregation::CollectVec(..) => {
782                        parse_quote!({
783                            let mut set = dfir_rs::rustc_hash::FxHashSet::<_>::default();
784                            set.insert(#val_at_index);
785                            set
786                        })
787                    }
788                };
789
790                parse_quote! {
791                    #old_at_index = if let Some(prev) = #old_at_index.take() {
792                        Some(#agg_expr)
793                    } else {
794                        Some(#agg_initial)
795                    };
796                }
797            })
798            .collect();
799
800        let fold_keyed_fn: syn::Expr = parse_quote!(|old: &mut #agg_type, val: #agg_input_type| {
801            #(#fold_keyed_stmts)*
802        });
803
804        if out_expanded.persisted {
805            parse_quote! {
806                #pre_fold_keyed_map -> fold_keyed::<'static, #fold_keyed_input_type, #agg_type>(|| #agg_initial, #fold_keyed_fn) -> #after_group_pipeline
807            }
808        } else {
809            parse_quote! {
810                #pre_fold_keyed_map -> fold_keyed::<'tick, #fold_keyed_input_type, #agg_type>(|| #agg_initial, #fold_keyed_fn) -> #after_group_pipeline
811            }
812        }
813    }
814}
815
// Snapshot tests for `gen_dfir_graph`: each test compiles a small Datalog program and
// compares the generated DFIR graph's surface syntax against a stored insta snapshot.
#[cfg(test)]
mod tests {
    use syn::parse_quote;

    use super::gen_dfir_graph;

    // Compiles a Datalog program literal with `gen_dfir_graph` and snapshot-tests the result.
    //
    // `$program` is a string literal (the tests below use raw strings). `parse_quote!` turns it
    // into the `proc_macro2::Literal` that `gen_dfir_graph` takes. The returned
    // `Result<DfirGraph, Vec<Diagnostic>>` is `unwrap()`ed, so any diagnostic fails the test
    // with a panic. The graph's `surface_syntax_string()` is asserted with insta under the
    // snapshot suffix `surface_graph`.
    //
    // The transcriber expands to plain statements rather than a block. A second invocation in
    // the same test therefore just shadows `flat_graph`.
    //
    // NOTE(review): insta presumably derives the snapshot name from the enclosing test fn and
    // auto-numbers repeated assertions within one test (see `local_constraints`). Renaming
    // tests or reordering invocations would then orphan the stored `.snap` files; confirm.
    macro_rules! test_snapshots {
        ($program:literal) => {
            let flat_graph = gen_dfir_graph(parse_quote!($program)).unwrap();

            let flat_graph_ref = &flat_graph;
            insta::with_settings!({snapshot_suffix => "surface_graph"}, {
                insta::assert_snapshot!(flat_graph_ref.surface_syntax_string());
            });
        };
    }

    /// Smallest end-to-end program: one rule that copies `input` into `out` with its two
    /// columns swapped.
    #[test]
    fn minimal_program() {
        test_snapshots!(
            r#"
            .input input `source_stream(input)`
            .output out `for_each(|v| out.send(v).unwrap())`

            out(y, x) :- input(x, y).
            "#
        );
    }

    /// A relation joined with itself, matching on swapped columns.
    #[test]
    fn join_with_self() {
        test_snapshots!(
            r#"
            .input input `source_stream(input)`
            .output out `for_each(|v| out.send(v).unwrap())`

            out(x, y) :- input(x, y), input(y, x).
            "#
        );
    }

    /// `_` wildcard fields in the atoms of a self-join.
    #[test]
    fn wildcard_fields() {
        test_snapshots!(
            r#"
            .input input `source_stream(input)`
            .output out `for_each(|v| out.send(v).unwrap())`

            out(x) :- input(x, _), input(_, x).
            "#
        );
    }

    /// A join between two distinct input relations.
    #[test]
    fn join_with_other() {
        test_snapshots!(
            r#"
            .input in1 `source_stream(in1)`
            .input in2 `source_stream(in2)`
            .output out `for_each(|v| out.send(v).unwrap())`

            out(x, y) :- in1(x, y), in2(y, x).
            "#
        );
    }

    /// Two separate rules that both derive facts into the same relation `out`.
    #[test]
    fn multiple_contributors() {
        test_snapshots!(
            r#"
            .input in1 `source_stream(in1)`
            .input in2 `source_stream(in2)`
            .output out `for_each(|v| out.send(v).unwrap())`

            out(x, y) :- in1(x, y).
            out(x, y) :- in2(y, x).
            "#
        );
    }

    /// A recursive rule: `reachable` appears in its own body (transitive closure over `edges`).
    #[test]
    fn transitive_closure() {
        test_snapshots!(
            r#"
            .input edges `source_stream(edges)`
            .input seed_reachable `source_stream(seed_reachable)`
            .output reachable `for_each(|v| reachable.send(v).unwrap())`

            reachable(x) :- seed_reachable(x).
            reachable(y) :- reachable(x), edges(x, y).
            "#
        );
    }

    /// A join of two single-column relations.
    #[test]
    fn single_column_program() {
        test_snapshots!(
            r#"
            .input in1 `source_stream(in1)`
            .input in2 `source_stream(in2)`
            .output out `for_each(|v| out.send(v).unwrap())`

            out(x) :- in1(x), in2(x).
            "#
        );
    }

    /// A three-way chain join (`in1`–`in2` on `b`, `in2`–`in3` on `c`).
    #[test]
    fn triple_relation_join() {
        test_snapshots!(
            r#"
            .input in1 `source_stream(in1)`
            .input in2 `source_stream(in2)`
            .input in3 `source_stream(in3)`
            .output out `for_each(|v| out.send(v).unwrap())`

            out(d, c, b, a) :- in1(a, b), in2(b, c), in3(c, d).
            "#
        );
    }

    /// A variable repeated within a single atom (e.g. `input(x, x)`), which constrains
    /// those fields to be equal.
    ///
    /// This test contains two snapshot assertions; their order matters for snapshot naming.
    #[test]
    fn local_constraints() {
        test_snapshots!(
            r#"
            .input input `source_stream(input)`
            .output out `for_each(|v| out.send(v).unwrap())`

            out(x, x) :- input(x, x).
            "#
        );

        test_snapshots!(
            r#"
            .input input `source_stream(input)`
            .output out `for_each(|v| out.send(v).unwrap())`

            out(x, x, y, y) :- input(x, x, y, y).
            "#
        );
    }

    /// Parenthesized predicate expressions (`( x > y )`, `( y == x )`) in a rule body.
    #[test]
    fn test_simple_filter() {
        test_snapshots!(
            r#"
            .input input `source_stream(input)`
            .output out `for_each(|v| out.send(v).unwrap())`

            out(x, y) :- input(x, y), ( x > y ), ( y == x ).
            "#
        );
    }

    /// A negated atom (`!ints_3(y)`) after a two-relation join.
    #[test]
    fn test_anti_join() {
        test_snapshots!(
            r#"
            .input ints_1 `source_stream(ints_1)`
            .input ints_2 `source_stream(ints_2)`
            .input ints_3 `source_stream(ints_3)`
            .output result `for_each(|v| result.send(v).unwrap())`

            result(x, z) :- ints_1(x, y), ints_2(y, z), !ints_3(y)
            "#
        );
    }

    /// The `max` aggregation, with the non-aggregated head field `b` as the grouping column.
    #[test]
    fn test_max() {
        test_snapshots!(
            r#"
            .input ints `source_stream(ints)`
            .output result `for_each(|v| result.send(v).unwrap())`

            result(max(a), b) :- ints(a, b)
            "#
        );
    }

    /// Every head field aggregated (`max(a), max(b)`), so there is no plain grouping column.
    #[test]
    fn test_max_all() {
        test_snapshots!(
            r#"
            .input ints `source_stream(ints)`
            .output result `for_each(|v| result.send(v).unwrap())`

            result(max(a), max(b)) :- ints(a, b)
            "#
        );
    }

    /// The `:~` rule operator with an `@b` location on the head.
    ///
    /// `result` is declared with `.async`, which supplies both a sending `for_each` pipeline
    /// and a receiving `source_stream` pipeline.
    #[test]
    fn test_send_to_node() {
        test_snapshots!(
            r#"
            .input ints `source_stream(ints)`
            .output result `for_each(|v| result.send(v).unwrap())`
            .async result `for_each(|(node, data)| async_send_result(node, data))` `source_stream(async_receive_result)`

            result@b(a) :~ ints(a, b)
            "#
        );
    }

    /// A `#` comment line in the program text, plus the `count`, `sum` and `choose`
    /// aggregations.
    ///
    /// The `sum` rule uses the `:+` operator instead of `:-`.
    #[test]
    fn test_aggregations_and_comments() {
        test_snapshots!(
            r#"
            # david doesn't think this line of code will execute
            .input ints `source_stream(ints)`
            .output result `for_each(|v| result.send(v).unwrap())`
            .output result2 `for_each(|v| result2.send(v).unwrap())`

            result(count(a), b) :- ints(a, b)
            result(sum(a), b) :+ ints(a, b)
            result2(choose(a), b) :- ints(a, b)
            "#
        );
    }

    /// A computed (non-variable) grouping column, `a % 2`, alongside a `sum(b)` aggregation.
    #[test]
    fn test_aggregations_fold_keyed_expr() {
        test_snapshots!(
            r#"
            .input ints `source_stream(ints)`
            .output result `for_each(|v| result.send(v).unwrap())`

            result(a % 2, sum(b)) :- ints(a, b)
            "#
        );
    }

    /// The same variable projected twice in the head.
    ///
    /// Per the test name, the intent appears to be values that are `Clone` but not `Copy`
    /// (presumably `String`s from `strings`); confirm.
    #[test]
    fn test_non_copy_but_clone() {
        test_snapshots!(
            r#"
            .input strings `source_stream(strings)`
            .output result `for_each(|v| result.send(v).unwrap())`

            result(a, a) :- strings(a)
            "#
        );
    }

    /// Expressions in head fields: a bare literal, and `+`, `-`, `%` (with parentheses) and `*`
    /// combining variables and literals.
    #[test]
    fn test_expr_lhs() {
        test_snapshots!(
            r#"
            .input ints `source_stream(ints)`
            .output result `for_each(|v| result.send(v).unwrap())`

            result(123) :- ints(a)
            result(a + 123) :- ints(a)
            result(a + a) :- ints(a)
            result(123 - a) :- ints(a)
            result(123 % (a + 5)) :- ints(a)
            result(a * 5) :- ints(a)
            "#
        );
    }

    /// Predicates using `==` and `!=`, including arithmetic on either side of the comparison.
    #[test]
    fn test_expr_predicate() {
        test_snapshots!(
            r#"
            .input ints `source_stream(ints)`
            .output result `for_each(|v| result.send(v).unwrap())`

            result(1) :- ints(a), (a == 0)
            result(2) :- ints(a), (a != 0)
            result(3) :- ints(a), (a - 1 == 0)
            result(4) :- ints(a), (a - 1 == 1 - 1)
            "#
        );
    }

    /// `.persist` declarations in several combinations:
    /// - persisted inputs (`ints1`, `ints2`) and a non-persisted input (`ints3`) in one join;
    /// - an anti-join against a persisted relation;
    /// - derived relations with `.persist` (`intermediate_persist`) and without
    ///   (`intermediate`).
    #[test]
    fn test_persist() {
        test_snapshots!(
            r#"
            .input ints1 `source_stream(ints1)`
            .persist ints1

            .input ints2 `source_stream(ints2)`
            .persist ints2

            .input ints3 `source_stream(ints3)`

            .output result `for_each(|v| result.send(v).unwrap())`
            .output result2 `for_each(|v| result2.send(v).unwrap())`
            .output result3 `for_each(|v| result3.send(v).unwrap())`
            .output result4 `for_each(|v| result4.send(v).unwrap())`

            result(a, b, c) :- ints1(a), ints2(b), ints3(c)
            result2(a) :- ints1(a), !ints2(a)

            intermediate(a) :- ints1(a)
            result3(a) :- intermediate(a)

            .persist intermediate_persist
            intermediate_persist(a) :- ints1(a)
            result4(a) :- intermediate_persist(a)
            "#
        );
    }

    /// A `.persist` relation that is not an `.input` but is derived from one, then aggregated
    /// with `count`.
    ///
    /// Per the test name, presumably this checks that persisted facts are not double-counted;
    /// confirm against the snapshot.
    #[test]
    fn test_persist_uniqueness() {
        test_snapshots!(
            r#"
            .persist ints1

            .input ints2 `source_stream(ints2)`

            ints1(a) :- ints2(a)

            .output result `for_each(|v| result.send(v).unwrap())`

            result(count(a)) :- ints1(a)
            "#
        );
    }

    /// `count(*)` versus `count(a)` over the same join, where one atom has a wildcard field.
    #[test]
    fn test_wildcard_join_count() {
        test_snapshots!(
            r#"
            .input ints1 `source_stream(ints1)`
            .input ints2 `source_stream(ints2)`

            .output result `for_each(|v| result.send(v).unwrap())`
            .output result2 `for_each(|v| result2.send(v).unwrap())`

            result(count(*)) :- ints1(a, _), ints2(a)
            result2(count(a)) :- ints1(a, _), ints2(a)
            "#
        );
    }

    /// The `index()` head field, in several combinations:
    /// - with and without an accompanying `count` aggregation;
    /// - reading from a non-persisted input and from a persisted derived relation;
    /// - writing into a persisted output (`result5`).
    #[test]
    fn test_index() {
        test_snapshots!(
            r#"
            .input ints `source_stream(ints)`

            .output result `for_each(|v| result.send(v).unwrap())`
            .output result2 `for_each(|v| result2.send(v).unwrap())`
            .output result3 `for_each(|v| result3.send(v).unwrap())`
            .output result4 `for_each(|v| result4.send(v).unwrap())`

            .persist result5
            .output result5 `for_each(|v| result5.send(v).unwrap())`

            result(a, b, index()) :- ints(a, b)
            result2(a, count(b), index()) :- ints(a, b)

            .persist ints_persisted
            ints_persisted(a, b) :- ints(a, b)

            result3(a, b, index()) :- ints_persisted(a, b)
            result4(a, count(b), index()) :- ints_persisted(a, b)
            result5(a, b, index()) :- ints_persisted(a, b)
            "#
        );
    }

    /// The `collect_vec` aggregation over two variables, drawn from two relations that share
    /// no variables.
    #[test]
    fn test_collect_vec() {
        test_snapshots!(
            r#"
            .input ints1 `source_stream(ints1)`
            .input ints2 `source_stream(ints2)`

            .output result `for_each(|v| result.send(v).unwrap())`

            result(collect_vec(a, b)) :- ints1(a), ints2(b)
            "#
        );
    }

    /// The `*b` flatten syntax on an atom field.
    #[test]
    fn test_flatten() {
        test_snapshots!(
            r#"
            .input ints1 `source_stream(ints1)`

            .output result `for_each(|v| result.send(v).unwrap())`

            result(a, b) :- ints1(a, *b)
            "#
        );
    }

    /// Tuple destructuring of a single atom field: `ints1((a, b))`.
    #[test]
    fn test_detuple() {
        test_snapshots!(
            r#"
            .input ints1 `source_stream(ints1)`

            .output result `for_each(|v| result.send(v).unwrap())`

            result(a, b) :- ints1((a, b))
            "#
        );
    }

    /// Tuple destructuring of two atom fields: `ints1((a, b), (c, d))`.
    #[test]
    fn test_multi_detuple() {
        test_snapshots!(
            r#"
            .input ints1 `source_stream(ints1)`

            .output result `for_each(|v| result.send(v).unwrap())`

            result(a, b, c, d) :- ints1((a, b), (c, d))
            "#
        );
    }

    /// Flatten applied to a destructured tuple: `ints1(*(a, b))`.
    #[test]
    fn test_flat_then_detuple() {
        test_snapshots!(
            r#"
            .input ints1 `source_stream(ints1)`

            .output result `for_each(|v| result.send(v).unwrap())`

            result(a, b) :- ints1(*(a, b))
            "#
        );
    }

    /// Flatten applied to each element inside a destructured tuple: `ints1((*a, *b))`.
    #[test]
    fn test_detuple_then_flat() {
        test_snapshots!(
            r#"
            .input ints1 `source_stream(ints1)`

            .output result `for_each(|v| result.send(v).unwrap())`

            result(a, b) :- ints1((*a, *b))
            "#
        );
    }
}