1use std::collections::{HashMap, HashSet};
2use std::ops::Deref;
3
4pub use dfir_lang::diagnostic;
5use dfir_lang::diagnostic::{Diagnostic, Level};
6use dfir_lang::graph::{DfirGraph, FlatGraphBuilder, eliminate_extra_unions_tees, partition_graph};
7use dfir_lang::parse::{
8 DfirStatement, IndexInt, Indexing, Pipeline, PipelineLink, PipelineStatement, PortIndex,
9};
10use proc_macro2::{Span, TokenStream};
11use rust_sitter::errors::{ParseError, ParseErrorReason};
12use syn::{Token, parse_quote, parse_quote_spanned};
13
14mod grammar;
15mod join_plan;
16mod util;
17
18use grammar::datalog::{
19 Aggregation, Atom, Declaration, Ident, IntExpr, Program, Rule, RuleType, TargetExpr,
20};
21use join_plan::{IntermediateJoinNode, JoinPlan};
22use util::{Counter, repeat_tuple};
23
/// Names of built-in ("magic") relations.
///
/// `gen_dfir_graph` rejects declarations and rule targets that use one of these names, and
/// `compute_join_plan` leaves positive body atoms with these names out of the join tree,
/// wrapping the plan in a dedicated node (currently only `less_than`) instead.
static MAGIC_RELATIONS: [&str; 1] = ["less_than"];
25
26pub fn parse_pipeline(
27 code_str: &rust_sitter::Spanned<String>,
28 get_span: &impl Fn((usize, usize)) -> Span,
29) -> Result<Pipeline, Vec<Diagnostic>> {
30 syn::LitStr::new(&code_str.value, get_span(code_str.span))
31 .parse()
32 .map_err(|err| {
33 vec![Diagnostic {
34 span: err.span(),
35 level: Level::Error,
36 message: format!("Failed to parse input pipeline: {}", err),
37 }]
38 })
39}
40
41pub fn parse_static(
42 code_str: &rust_sitter::Spanned<String>,
43 get_span: &impl Fn((usize, usize)) -> Span,
44) -> Result<syn::Expr, Vec<Diagnostic>> {
45 syn::LitStr::new(&code_str.value, get_span(code_str.span))
46 .parse()
47 .map_err(|err| {
48 vec![Diagnostic {
49 span: err.span(),
50 level: Level::Error,
51 message: format!("Failed to parse static expression: {}", err),
52 }]
53 })
54}
55
56pub fn gen_dfir_graph(literal: proc_macro2::Literal) -> Result<DfirGraph, Vec<Diagnostic>> {
57 let offset = {
58 let source_str = literal.to_string();
60 let mut source_chars = source_str.chars();
61 if Some('r') != source_chars.next() {
62 return Err(vec![Diagnostic {
63 span: literal.span(),
64 level: Level::Error,
65 message:
66 r##"Input must be a raw string `r#"..."#` for correct diagnostic messages."##
67 .to_owned(),
68 }]);
69 }
70 let hashes = source_chars.take_while(|&c| '#' == c).count();
71 2 + hashes
72 };
73
74 let get_span = |(start, end): (usize, usize)| {
75 let subspan = literal.subspan(start + offset..end + offset);
76 subspan.unwrap_or(Span::call_site())
77 };
78
79 let str_node: syn::LitStr = parse_quote!(#literal);
80 let actual_str = str_node.value();
81 let program: Program =
82 grammar::datalog::parse(&actual_str).map_err(|e| handle_errors(e, &get_span))?;
83
84 let mut inputs = Vec::new();
85 let mut outputs = Vec::new();
86 let mut persists = HashSet::new();
87 let mut asyncs = Vec::new();
88 let mut rules = Vec::new();
89 let mut statics = Vec::new();
90
91 for stmt in &program.rules {
92 match stmt {
93 Declaration::Input(_, ident, hf_code) => {
94 assert!(!MAGIC_RELATIONS.contains(&ident.name.as_str()));
95 inputs.push((ident, hf_code))
96 }
97 Declaration::Output(_, ident, hf_code) => {
98 assert!(!MAGIC_RELATIONS.contains(&ident.name.as_str()));
99 outputs.push((ident, hf_code))
100 }
101 Declaration::Persist(_, ident) => {
102 persists.insert(ident.name.clone());
103 }
104 Declaration::Async(_, ident, send_hf, recv_hf) => {
105 assert!(!MAGIC_RELATIONS.contains(&ident.name.as_str()));
106 asyncs.push((ident, send_hf, recv_hf))
107 }
108 Declaration::Rule(rule) => {
109 assert!(!MAGIC_RELATIONS.contains(&rule.target.name.name.as_str()));
110 rules.push(rule)
111 }
112 Declaration::Static(_, ident, hf_code) => {
113 assert!(!MAGIC_RELATIONS.contains(&ident.name.as_str()));
114 statics.push((ident, hf_code));
115 }
116 }
117 }
118
119 let mut flat_graph_builder = FlatGraphBuilder::new();
120 let mut tee_counter = HashMap::new();
121 let mut union_counter = HashMap::new();
122
123 let mut created_rules = HashSet::new();
124 for decl in &program.rules {
125 let target_ident = match decl {
126 Declaration::Input(_, ident, _) => ident.clone(),
127 Declaration::Output(_, ident, _) => ident.clone(),
128 Declaration::Persist(_, ident) => ident.clone(),
129 Declaration::Async(_, ident, _, _) => ident.clone(),
130 Declaration::Rule(rule) => rule.target.name.clone(),
131 Declaration::Static(_, ident, _) => ident.clone(),
132 };
133
134 if !created_rules.contains(&target_ident.value) {
135 created_rules.insert(target_ident.value.clone());
136 let insert_name = syn::Ident::new(
137 &format!("{}_insert", target_ident.name),
138 get_span(target_ident.span),
139 );
140 let read_name = syn::Ident::new(&target_ident.name, get_span(target_ident.span));
141
142 if persists.contains(&target_ident.value.name) {
143 flat_graph_builder
145 .add_statement(parse_quote_spanned!{get_span(target_ident.span)=> #insert_name = union() -> unique::<'tick>(); });
146 flat_graph_builder
147 .add_statement(parse_quote_spanned!{get_span(target_ident.span)=> #read_name = difference::<'tick, 'static>() -> tee(); });
148 flat_graph_builder
149 .add_statement(parse_quote_spanned!{get_span(target_ident.span)=> #insert_name -> [pos] #read_name; });
150 flat_graph_builder
151 .add_statement(parse_quote_spanned!{get_span(target_ident.span)=> #read_name -> defer_tick() -> [neg] #read_name; });
152 } else {
153 flat_graph_builder
154 .add_statement(parse_quote_spanned!{get_span(target_ident.span)=> #insert_name = union() -> unique::<'tick>(); });
155 flat_graph_builder
156 .add_statement(parse_quote_spanned!{get_span(target_ident.span)=> #read_name = #insert_name -> tee(); });
157 }
158 }
159 }
160
161 for (target, hf_code) in inputs {
162 let my_union_index = union_counter
163 .entry(target.name.clone())
164 .or_insert_with(|| 0..)
165 .next()
166 .expect("Out of union indices");
167
168 let my_union_index_lit =
169 syn::LitInt::new(&format!("{}", my_union_index), get_span(target.span));
170 let name = syn::Ident::new(&format!("{}_insert", target.name), get_span(target.span));
171
172 let input_pipeline: Pipeline = parse_pipeline(&hf_code.code, &get_span)?;
173
174 flat_graph_builder.add_statement(parse_quote_spanned! {get_span(target.span)=>
175 #input_pipeline -> [#my_union_index_lit] #name;
176 });
177 }
178
179 for (target, hf_code) in outputs {
180 let my_tee_index = tee_counter
181 .entry(target.name.clone())
182 .or_insert_with(|| 0..)
183 .next()
184 .expect("Out of tee indices");
185
186 let my_tee_index_lit =
187 syn::LitInt::new(&format!("{}", my_tee_index), get_span(target.span));
188 let target_ident = syn::Ident::new(&target.name, get_span(target.span));
189
190 let output_pipeline: Pipeline = parse_pipeline(&hf_code.code, &get_span)?;
191 let output_pipeline = if persists.contains(&target.name) {
192 parse_quote_spanned! {get_span(target.span)=> persist::<'static>() -> #output_pipeline}
193 } else {
194 output_pipeline
195 };
196
197 flat_graph_builder.add_statement(parse_quote_spanned! {get_span(target.span)=>
198 #target_ident [#my_tee_index_lit] -> #output_pipeline;
199 });
200 }
201
202 for (target, send_hf, recv_hf) in asyncs {
203 let async_send_pipeline = format!("{}_async_send", target.name);
204 let async_send_pipeline = syn::Ident::new(&async_send_pipeline, get_span(target.span));
205
206 let recv_union_index = union_counter
207 .entry(target.name.clone())
208 .or_insert_with(|| 0..)
209 .next()
210 .expect("Out of union indices");
211
212 let recv_union_index_lit =
213 syn::LitInt::new(&format!("{}", recv_union_index), get_span(target.span));
214 let target_ident =
215 syn::Ident::new(&format!("{}_insert", target.name), get_span(target.span));
216
217 let send_pipeline: Pipeline = parse_pipeline(&send_hf.code, &get_span)?;
218 let recv_pipeline: Pipeline = parse_pipeline(&recv_hf.code, &get_span)?;
219
220 flat_graph_builder.add_statement(parse_quote_spanned! {get_span(target.span)=>
221 #async_send_pipeline = union() -> unique::<'tick>() -> #send_pipeline;
222 });
223
224 flat_graph_builder.add_statement(parse_quote_spanned! {get_span(target.span)=>
225 #recv_pipeline -> [#recv_union_index_lit] #target_ident;
226 });
227 }
228
229 for (target, hf_code) in statics {
230 let my_union_index = union_counter
231 .entry(target.name.clone())
232 .or_insert_with(|| 0..)
233 .next()
234 .expect("Out of union indices");
235
236 let my_union_index_lit =
237 syn::LitInt::new(&format!("{}", my_union_index), get_span(target.span));
238 let name = syn::Ident::new(&format!("{}_insert", target.name), get_span(target.span));
239
240 let static_expression: syn::Expr = parse_static(&hf_code.code, &get_span)?;
241
242 flat_graph_builder.add_statement(parse_quote_spanned! {get_span(target.span)=>
243 source_iter(#static_expression) -> persist::<'static>() -> [#my_union_index_lit] #name;
244 });
245 }
246
247 let mut next_join_idx = 0..;
248 let mut diagnostics = Vec::new();
249 for rule in rules {
250 let plan = compute_join_plan(&rule.sources, &persists);
251 generate_rule(
252 plan,
253 rule,
254 &mut flat_graph_builder,
255 &mut tee_counter,
256 &mut union_counter,
257 &mut next_join_idx,
258 &persists,
259 &mut diagnostics,
260 &get_span,
261 );
262 }
263
264 if !diagnostics.is_empty() {
265 return Err(diagnostics);
266 }
267
268 let (mut flat_graph, _uses, mut diagnostics) = flat_graph_builder.build();
269
270 diagnostics.retain(Diagnostic::is_error);
271 if !diagnostics.is_empty() {
272 return Err(diagnostics);
273 }
274
275 if let Err(err) = flat_graph.merge_modules() {
276 diagnostics.push(err);
277 return Err(diagnostics);
278 }
279
280 eliminate_extra_unions_tees(&mut flat_graph);
281 Ok(flat_graph)
282}
283
284fn handle_errors(
285 errors: Vec<ParseError>,
286 get_span: &impl Fn((usize, usize)) -> Span,
287) -> Vec<Diagnostic> {
288 let mut diagnostics = vec![];
289 for error in errors {
290 let reason = error.reason;
291 let my_span = get_span((error.start, error.end));
292 match reason {
293 ParseErrorReason::UnexpectedToken(msg) => {
294 diagnostics.push(Diagnostic::spanned(
295 my_span,
296 Level::Error,
297 format!("Unexpected Token: '{msg}'", msg = msg),
298 ));
299 }
300 ParseErrorReason::MissingToken(msg) => {
301 diagnostics.push(Diagnostic::spanned(
302 my_span,
303 Level::Error,
304 format!("Missing Token: '{msg}'", msg = msg),
305 ));
306 }
307 ParseErrorReason::FailedNode(parse_errors) => {
308 if parse_errors.is_empty() {
309 diagnostics.push(Diagnostic::spanned(
310 my_span,
311 Level::Error,
312 "Failed to parse",
313 ));
314 } else {
315 diagnostics.extend(handle_errors(parse_errors, get_span));
316 }
317 }
318 }
319 }
320
321 diagnostics
322}
323
324pub fn dfir_graph_to_program(flat_graph: DfirGraph, root: TokenStream) -> TokenStream {
325 let partitioned_graph =
326 partition_graph(flat_graph).expect("Failed to partition (cycle detected).");
327
328 let mut diagnostics = Vec::new();
329 let code_tokens = partitioned_graph.as_code(&root, true, quote::quote!(), &mut diagnostics);
330 assert_eq!(
331 0,
332 diagnostics.len(),
333 "Operator diagnostic occured during codegen"
334 );
335
336 code_tokens
337}
338
339#[expect(clippy::too_many_arguments, reason = "internal code")]
340fn generate_rule(
341 plan: JoinPlan<'_>,
342 rule: &rust_sitter::Spanned<Rule>,
343 flat_graph_builder: &mut FlatGraphBuilder,
344 tee_counter: &mut HashMap<String, Counter>,
345 union_counter: &mut HashMap<String, Counter>,
346 next_join_idx: &mut Counter,
347 persists: &HashSet<String>,
348 diagnostics: &mut Vec<Diagnostic>,
349 get_span: &impl Fn((usize, usize)) -> Span,
350) {
351 let target = &rule.target.name;
352 let target_ident = syn::Ident::new(&format!("{}_insert", target.name), get_span(target.span));
353
354 let out_expanded = join_plan::expand_join_plan(
355 &plan,
356 flat_graph_builder,
357 tee_counter,
358 next_join_idx,
359 rule.span,
360 diagnostics,
361 get_span,
362 );
363
364 let after_join = apply_aggregations(
365 rule,
366 &out_expanded,
367 persists.contains(&target.name),
368 diagnostics,
369 get_span,
370 );
371
372 let my_union_index = union_counter
373 .entry(target.name.clone())
374 .or_insert_with(|| 0..)
375 .next()
376 .expect("Out of union indices");
377
378 let my_union_index_lit = syn::LitInt::new(&format!("{}", my_union_index), Span::call_site());
379
380 let after_join_and_send: Pipeline = match rule.rule_type.value {
381 RuleType::Sync(_) => {
382 if rule.target.at_node.is_some() {
383 panic!("Rule must be async to send data to other nodes")
384 }
385
386 parse_quote_spanned!(get_span(rule.rule_type.span)=> #after_join -> [#my_union_index_lit] #target_ident)
387 }
388 RuleType::NextTick(_) => {
389 if rule.target.at_node.is_some() {
390 panic!("Rule must be async to send data to other nodes")
391 }
392
393 parse_quote_spanned!(get_span(rule.rule_type.span)=> #after_join -> defer_tick() -> [#my_union_index_lit] #target_ident)
394 }
395 RuleType::Async(_) => {
396 if rule.target.at_node.is_none() {
397 panic!("Async rules are only for sending data to other nodes")
398 }
399
400 let exprs_get_data = rule
401 .target
402 .fields
403 .iter()
404 .enumerate()
405 .map(|(i, f)| -> syn::Expr {
406 let syn_index = syn::Index::from(i);
407 parse_quote_spanned!(get_span(f.span)=> v.#syn_index)
408 });
409
410 let syn_target_index = syn::Index::from(rule.target.fields.len());
411
412 let v_type = repeat_tuple::<syn::Type, syn::Type>(
413 || parse_quote!(_),
414 rule.target.fields.len() + 1,
415 );
416
417 let send_pipeline_ident = syn::Ident::new(
418 &format!("{}_async_send", &rule.target.name.name),
419 get_span(rule.target.name.span),
420 );
421
422 parse_quote_spanned!(get_span(rule.rule_type.span)=> #after_join -> map(|v: #v_type| (v.#syn_target_index, (#(#exprs_get_data, )*))) -> #send_pipeline_ident)
423 }
424 };
425
426 let out_name = out_expanded.name;
427 let out_indexing = out_expanded.tee_idx.map(|i| Indexing {
430 bracket_token: syn::token::Bracket::default(),
431 index: PortIndex::Int(IndexInt {
432 value: i,
433 span: Span::call_site(),
434 }),
435 });
436 flat_graph_builder.add_statement(DfirStatement::Pipeline(PipelineStatement {
437 pipeline: Pipeline::Link(PipelineLink {
438 lhs: Box::new(parse_quote!(#out_name #out_indexing)), arrow: parse_quote!(->),
440 rhs: Box::new(after_join_and_send),
441 }),
442 semi_token: Token),
443 }));
444}
445
446fn compute_join_plan<'a>(sources: &'a [Atom], persisted_rules: &HashSet<String>) -> JoinPlan<'a> {
447 let mut plan: JoinPlan = sources
449 .iter()
450 .filter_map(|x| match x {
451 Atom::PosRelation(e) => {
452 if !MAGIC_RELATIONS.contains(&e.name.name.as_str()) {
453 Some(JoinPlan::Source(e, persisted_rules.contains(&e.name.name)))
454 } else {
455 None
456 }
457 }
458 _ => None,
459 })
460 .reduce(|a, b| JoinPlan::Join(Box::new(a), Box::new(b)))
461 .unwrap();
462
463 plan = sources
464 .iter()
465 .filter_map(|x| match x {
466 Atom::NegRelation(_, e) => {
467 Some(JoinPlan::Source(e, persisted_rules.contains(&e.name.name)))
468 }
469 _ => None,
470 })
471 .fold(plan, |a, b| JoinPlan::AntiJoin(Box::new(a), Box::new(b)));
472
473 let predicates = sources
474 .iter()
475 .filter_map(|x| match x {
476 Atom::Predicate(e) => Some(e),
477 _ => None,
478 })
479 .collect::<Vec<_>>();
480
481 if !predicates.is_empty() {
482 plan = JoinPlan::Predicate(predicates, Box::new(plan))
483 }
484
485 plan = sources.iter().fold(plan, |acc, atom| match atom {
486 Atom::PosRelation(e) => {
487 if MAGIC_RELATIONS.contains(&e.name.name.as_str()) {
488 match e.name.name.as_str() {
489 "less_than" => JoinPlan::MagicNatLt(
490 Box::new(acc),
491 e.fields[0].value.clone(),
492 e.fields[1].value.clone(),
493 ),
494 o => panic!("Unknown magic relation {}", o),
495 }
496 } else {
497 acc
498 }
499 }
500 _ => acc,
501 });
502
503 plan
504}
505
506pub(crate) fn gen_value_expr(
507 expr: &IntExpr,
508 lookup_ident: &mut impl FnMut(&rust_sitter::Spanned<Ident>) -> syn::Expr,
509 get_span: &dyn Fn((usize, usize)) -> Span,
510) -> syn::Expr {
511 match expr {
512 IntExpr::Ident(ident) => lookup_ident(ident),
513 IntExpr::Integer(i) => syn::Expr::Lit(syn::ExprLit {
514 attrs: Vec::new(),
515 lit: syn::Lit::Int(syn::LitInt::new(&i.to_string(), get_span(i.span))),
516 }),
517 IntExpr::Parenthesized(_, e, _) => {
518 let inner = gen_value_expr(e, lookup_ident, get_span);
519 parse_quote!((#inner))
520 }
521 IntExpr::Add(l, _, r) => {
522 let l = gen_value_expr(l, lookup_ident, get_span);
523 let r = gen_value_expr(r, lookup_ident, get_span);
524 parse_quote!(#l + #r)
525 }
526 IntExpr::Sub(l, _, r) => {
527 let l = gen_value_expr(l, lookup_ident, get_span);
528 let r = gen_value_expr(r, lookup_ident, get_span);
529 parse_quote!(#l - #r)
530 }
531 IntExpr::Mul(l, _, r) => {
532 let l = gen_value_expr(l, lookup_ident, get_span);
533 let r = gen_value_expr(r, lookup_ident, get_span);
534 parse_quote!(#l * #r)
535 }
536 IntExpr::Mod(l, _, r) => {
537 let l = gen_value_expr(l, lookup_ident, get_span);
538 let r = gen_value_expr(r, lookup_ident, get_span);
539 parse_quote!(#l % #r)
540 }
541 }
542}
543
544fn gen_target_expr(
545 expr: &TargetExpr,
546 lookup_ident: &mut impl FnMut(&rust_sitter::Spanned<Ident>) -> syn::Expr,
547 get_span: &dyn Fn((usize, usize)) -> Span,
548) -> syn::Expr {
549 match expr {
550 TargetExpr::Expr(expr) => gen_value_expr(expr, lookup_ident, get_span),
551 TargetExpr::Aggregation(Aggregation::Count(_)) => parse_quote!(()),
552 TargetExpr::Aggregation(Aggregation::CountUnique(_, _, keys, _))
553 | TargetExpr::Aggregation(Aggregation::CollectVec(_, _, keys, _)) => {
554 let keys = keys
555 .iter()
556 .map(|k| gen_value_expr(&IntExpr::Ident(k.clone()), lookup_ident, get_span))
557 .collect::<Vec<_>>();
558 parse_quote!((#(#keys),*))
559 }
560 TargetExpr::Aggregation(
561 Aggregation::Min(_, _, a, _)
562 | Aggregation::Max(_, _, a, _)
563 | Aggregation::Sum(_, _, a, _)
564 | Aggregation::Choose(_, _, a, _),
565 ) => gen_value_expr(&IntExpr::Ident(a.clone()), lookup_ident, get_span),
566 TargetExpr::Index(_, _, _) => unreachable!(),
567 }
568}
569
570fn apply_aggregations(
571 rule: &Rule,
572 out_expanded: &IntermediateJoinNode,
573 consumer_is_persist: bool,
574 diagnostics: &mut Vec<Diagnostic>,
575 get_span: &impl Fn((usize, usize)) -> Span,
576) -> Pipeline {
577 let mut field_use_count = HashMap::new();
578 for field in rule
579 .target
580 .fields
581 .iter()
582 .chain(rule.target.at_node.iter().map(|n| &n.node))
583 {
584 for ident in field.idents() {
585 field_use_count
586 .entry(ident.name.clone())
587 .and_modify(|e| *e += 1)
588 .or_insert(1);
589 }
590 }
591
592 let mut aggregations = vec![];
593 let mut fold_keyed_exprs = vec![];
594 let mut agg_exprs = vec![];
595
596 let mut field_use_cur = HashMap::new();
597 let mut has_index = false;
598
599 let mut copy_group_key_lookups: Vec<syn::Expr> = vec![];
600 let mut after_group_lookups: Vec<syn::Expr> = vec![];
601 let mut group_key_idx = 0;
602 let mut agg_idx = 0;
603
604 for field in rule
605 .target
606 .fields
607 .iter()
608 .chain(rule.target.at_node.iter().map(|n| &n.node))
609 {
610 if matches!(field.deref(), TargetExpr::Index(_, _, _)) {
611 has_index = true;
612 after_group_lookups
613 .push(parse_quote_spanned!(get_span(field.span)=> __enumerate_index));
614 } else {
615 let expr: syn::Expr = gen_target_expr(
616 field,
617 &mut |ident| {
618 if let Some(col) = out_expanded.variable_mapping.get(&ident.name) {
619 let cur_count = field_use_cur
620 .entry(ident.name.clone())
621 .and_modify(|e| *e += 1)
622 .or_insert(1);
623
624 let source_col_idx = syn::Index::from(*col);
625 let base = parse_quote_spanned!(get_span(ident.span)=> row.#source_col_idx);
626
627 if *cur_count < field_use_count[&ident.name]
628 && field_use_count[&ident.name] > 1
629 {
630 parse_quote!(#base.clone())
631 } else {
632 base
633 }
634 } else {
635 diagnostics.push(Diagnostic::spanned(
636 get_span(ident.span),
637 Level::Error,
638 format!("Could not find column {} in RHS of rule", &ident.name),
639 ));
640 parse_quote!(())
641 }
642 },
643 get_span,
644 );
645
646 match &field.value {
647 TargetExpr::Expr(_) => {
648 fold_keyed_exprs.push(expr);
649
650 let idx = syn::Index::from(group_key_idx);
651 after_group_lookups.push(parse_quote_spanned!(get_span(field.span)=> g.#idx));
652 copy_group_key_lookups
653 .push(parse_quote_spanned!(get_span(field.span)=> g.#idx));
654 group_key_idx += 1;
655 }
656 TargetExpr::Aggregation(a) => {
657 aggregations.push(a.clone());
658 agg_exprs.push(expr);
659
660 match a {
661 Aggregation::CountUnique(..) => {
662 let idx = syn::Index::from(agg_idx);
663 after_group_lookups.push(
664 parse_quote_spanned!(get_span(field.span)=> a.#idx.unwrap().1),
665 );
666 agg_idx += 1;
667 }
668 Aggregation::CollectVec(..) => {
669 let idx = syn::Index::from(agg_idx);
670 after_group_lookups
671 .push(parse_quote_spanned!(get_span(field.span)=> a.#idx.unwrap().into_iter().collect::<Vec<_>>()));
672 agg_idx += 1;
673 }
674 _ => {
675 let idx = syn::Index::from(agg_idx);
676 after_group_lookups
677 .push(parse_quote_spanned!(get_span(field.span)=> a.#idx.unwrap()));
678 agg_idx += 1;
679 }
680 }
681 }
682 TargetExpr::Index(_, _, _) => unreachable!(),
683 }
684 }
685 }
686
687 let flattened_tuple_type = &out_expanded.tuple_type;
688
689 let fold_keyed_input_type =
690 repeat_tuple::<syn::Type, syn::Type>(|| parse_quote!(_), fold_keyed_exprs.len());
691
692 let after_group_pipeline: Pipeline = if has_index {
693 if out_expanded.persisted && agg_exprs.is_empty() {
694 parse_quote!(enumerate::<'static>() -> map(|(__enumerate_index, (g, a)): (_, (#fold_keyed_input_type, _))| (#(#after_group_lookups, )*)))
696 } else {
697 parse_quote!(enumerate::<'tick>() -> map(|(__enumerate_index, (g, a)): (_, (#fold_keyed_input_type, _))| (#(#after_group_lookups, )*)))
698 }
699 } else {
700 parse_quote!(map(|(g, a): (#fold_keyed_input_type, _)| (#(#after_group_lookups, )*)))
701 };
702
703 let pre_fold_keyed_map: Pipeline = parse_quote!(map(|row: #flattened_tuple_type| ((#(#fold_keyed_exprs, )*), (#(#agg_exprs, )*))));
704
705 if agg_exprs.is_empty() {
706 if out_expanded.persisted && !consumer_is_persist {
707 parse_quote!(#pre_fold_keyed_map -> #after_group_pipeline -> persist::<'static>())
708 } else {
709 parse_quote!(#pre_fold_keyed_map -> #after_group_pipeline)
710 }
711 } else {
712 let agg_initial =
713 repeat_tuple::<syn::Expr, syn::Expr>(|| parse_quote!(None), agg_exprs.len());
714
715 let agg_input_type =
716 repeat_tuple::<syn::Type, syn::Type>(|| parse_quote!(_), agg_exprs.len());
717 let agg_type: syn::Type =
718 repeat_tuple::<syn::Type, syn::Type>(|| parse_quote!(Option<_>), agg_exprs.len());
719
720 let fold_keyed_stmts: Vec<syn::Stmt> = aggregations
721 .iter()
722 .enumerate()
723 .map(|(i, agg)| {
724 let idx = syn::Index::from(i);
725 let old_at_index: syn::Expr = parse_quote!(old.#idx);
726 let val_at_index: syn::Expr = parse_quote!(val.#idx);
727
728 let agg_expr: syn::Expr = match &agg {
729 Aggregation::Min(..) => {
730 parse_quote!(std::cmp::min(prev, #val_at_index))
731 }
732 Aggregation::Max(..) => {
733 parse_quote!(std::cmp::max(prev, #val_at_index))
734 }
735 Aggregation::Sum(..) => {
736 parse_quote!(prev + #val_at_index)
737 }
738 Aggregation::Count(..) => {
739 parse_quote!(prev + 1)
740 }
741 Aggregation::CountUnique(..) => {
742 parse_quote!({
743 let prev: (dfir_rs::rustc_hash::FxHashSet<_>, _) = prev;
744 let mut set: dfir_rs::rustc_hash::FxHashSet<_> = prev.0;
745 if set.insert(#val_at_index) {
746 (set, prev.1 + 1)
747 } else {
748 (set, prev.1)
749 }
750 })
751 }
752 Aggregation::CollectVec(..) => {
753 parse_quote!({
754 let mut set: dfir_rs::rustc_hash::FxHashSet<_> = prev;
755 set.insert(#val_at_index);
756 set
757 })
758 }
759 Aggregation::Choose(..) => {
760 parse_quote!(prev) }
762 };
763
764 let agg_initial: syn::Expr = match &agg {
765 Aggregation::Min(..)
766 | Aggregation::Max(..)
767 | Aggregation::Sum(..)
768 | Aggregation::Choose(..) => {
769 parse_quote!(#val_at_index)
770 }
771 Aggregation::Count(..) => {
772 parse_quote!(1)
773 }
774 Aggregation::CountUnique(..) => {
775 parse_quote!({
776 let mut set = dfir_rs::rustc_hash::FxHashSet::<_>::default();
777 set.insert(#val_at_index);
778 (set, 1)
779 })
780 }
781 Aggregation::CollectVec(..) => {
782 parse_quote!({
783 let mut set = dfir_rs::rustc_hash::FxHashSet::<_>::default();
784 set.insert(#val_at_index);
785 set
786 })
787 }
788 };
789
790 parse_quote! {
791 #old_at_index = if let Some(prev) = #old_at_index.take() {
792 Some(#agg_expr)
793 } else {
794 Some(#agg_initial)
795 };
796 }
797 })
798 .collect();
799
800 let fold_keyed_fn: syn::Expr = parse_quote!(|old: &mut #agg_type, val: #agg_input_type| {
801 #(#fold_keyed_stmts)*
802 });
803
804 if out_expanded.persisted {
805 parse_quote! {
806 #pre_fold_keyed_map -> fold_keyed::<'static, #fold_keyed_input_type, #agg_type>(|| #agg_initial, #fold_keyed_fn) -> #after_group_pipeline
807 }
808 } else {
809 parse_quote! {
810 #pre_fold_keyed_map -> fold_keyed::<'tick, #fold_keyed_input_type, #agg_type>(|| #agg_initial, #fold_keyed_fn) -> #after_group_pipeline
811 }
812 }
813 }
814}
815
/// Snapshot tests: each test compiles one or more Datalog programs with `gen_dfir_graph` and
/// compares the surface syntax of the resulting flat graph against a stored `insta` snapshot.
#[cfg(test)]
mod tests {
    use syn::parse_quote;

    use super::gen_dfir_graph;

    /// Compiles the given raw-string Datalog program (panicking if compilation fails) and
    /// asserts an `insta` snapshot, with suffix `surface_graph`, of the graph's surface syntax.
    macro_rules! test_snapshots {
        ($program:literal) => {
            let flat_graph = gen_dfir_graph(parse_quote!($program)).unwrap();

            let flat_graph_ref = &flat_graph;
            insta::with_settings!({snapshot_suffix => "surface_graph"}, {
                insta::assert_snapshot!(flat_graph_ref.surface_syntax_string());
            });
        };
    }

    // A single rule that swaps the two columns of its input.
    #[test]
    fn minimal_program() {
        test_snapshots!(
            r#"
 .input input `source_stream(input)`
 .output out `for_each(|v| out.send(v).unwrap())`

 out(y, x) :- input(x, y).
 "#
        );
    }

    // The same relation appears twice in one rule body.
    #[test]
    fn join_with_self() {
        test_snapshots!(
            r#"
 .input input `source_stream(input)`
 .output out `for_each(|v| out.send(v).unwrap())`

 out(x, y) :- input(x, y), input(y, x).
 "#
        );
    }

    // `_` placeholders in body atoms.
    #[test]
    fn wildcard_fields() {
        test_snapshots!(
            r#"
 .input input `source_stream(input)`
 .output out `for_each(|v| out.send(v).unwrap())`

 out(x) :- input(x, _), input(_, x).
 "#
        );
    }

    // Join of two different input relations.
    #[test]
    fn join_with_other() {
        test_snapshots!(
            r#"
 .input in1 `source_stream(in1)`
 .input in2 `source_stream(in2)`
 .output out `for_each(|v| out.send(v).unwrap())`

 out(x, y) :- in1(x, y), in2(y, x).
 "#
        );
    }

    // Two rules feeding the same target relation.
    #[test]
    fn multiple_contributors() {
        test_snapshots!(
            r#"
 .input in1 `source_stream(in1)`
 .input in2 `source_stream(in2)`
 .output out `for_each(|v| out.send(v).unwrap())`

 out(x, y) :- in1(x, y).
 out(x, y) :- in2(y, x).
 "#
        );
    }

    // A recursive rule: `reachable` appears in its own body.
    #[test]
    fn transitive_closure() {
        test_snapshots!(
            r#"
 .input edges `source_stream(edges)`
 .input seed_reachable `source_stream(seed_reachable)`
 .output reachable `for_each(|v| reachable.send(v).unwrap())`

 reachable(x) :- seed_reachable(x).
 reachable(y) :- reachable(x), edges(x, y).
 "#
        );
    }

    // Relations with a single column.
    #[test]
    fn single_column_program() {
        test_snapshots!(
            r#"
 .input in1 `source_stream(in1)`
 .input in2 `source_stream(in2)`
 .output out `for_each(|v| out.send(v).unwrap())`

 out(x) :- in1(x), in2(x).
 "#
        );
    }

    // A chain of three joined relations.
    #[test]
    fn triple_relation_join() {
        test_snapshots!(
            r#"
 .input in1 `source_stream(in1)`
 .input in2 `source_stream(in2)`
 .input in3 `source_stream(in3)`
 .output out `for_each(|v| out.send(v).unwrap())`

 out(d, c, b, a) :- in1(a, b), in2(b, c), in3(c, d).
 "#
        );
    }

    // The same variable repeated within a single body atom; two snapshots are taken.
    #[test]
    fn local_constraints() {
        test_snapshots!(
            r#"
 .input input `source_stream(input)`
 .output out `for_each(|v| out.send(v).unwrap())`

 out(x, x) :- input(x, x).
 "#
        );

        test_snapshots!(
            r#"
 .input input `source_stream(input)`
 .output out `for_each(|v| out.send(v).unwrap())`

 out(x, x, y, y) :- input(x, x, y, y).
 "#
        );
    }

    // Comparison predicates in the rule body.
    #[test]
    fn test_simple_filter() {
        test_snapshots!(
            r#"
 .input input `source_stream(input)`
 .output out `for_each(|v| out.send(v).unwrap())`

 out(x, y) :- input(x, y), ( x > y ), ( y == x ).
 "#
        );
    }

    // A negated relation in the rule body.
    #[test]
    fn test_anti_join() {
        test_snapshots!(
            r#"
 .input ints_1 `source_stream(ints_1)`
 .input ints_2 `source_stream(ints_2)`
 .input ints_3 `source_stream(ints_3)`
 .output result `for_each(|v| result.send(v).unwrap())`

 result(x, z) :- ints_1(x, y), ints_2(y, z), !ints_3(y)
 "#
        );
    }

    // `max` aggregation next to a plain column.
    #[test]
    fn test_max() {
        test_snapshots!(
            r#"
 .input ints `source_stream(ints)`
 .output result `for_each(|v| result.send(v).unwrap())`

 result(max(a), b) :- ints(a, b)
 "#
        );
    }

    // Every target field is an aggregation.
    #[test]
    fn test_max_all() {
        test_snapshots!(
            r#"
 .input ints `source_stream(ints)`
 .output result `for_each(|v| result.send(v).unwrap())`

 result(max(a), max(b)) :- ints(a, b)
 "#
        );
    }

    // An async rule (`:~`) with an `@node` target.
    #[test]
    fn test_send_to_node() {
        test_snapshots!(
            r#"
 .input ints `source_stream(ints)`
 .output result `for_each(|v| result.send(v).unwrap())`
 .async result `for_each(|(node, data)| async_send_result(node, data))` `source_stream(async_receive_result)`

 result@b(a) :~ ints(a, b)
 "#
        );
    }

    // `count`, `sum` and `choose` aggregations, a `:+` rule, and a `#` comment line.
    #[test]
    fn test_aggregations_and_comments() {
        test_snapshots!(
            r#"
 # david doesn't think this line of code will execute
 .input ints `source_stream(ints)`
 .output result `for_each(|v| result.send(v).unwrap())`
 .output result2 `for_each(|v| result2.send(v).unwrap())`

 result(count(a), b) :- ints(a, b)
 result(sum(a), b) :+ ints(a, b)
 result2(choose(a), b) :- ints(a, b)
 "#
        );
    }

    // An arithmetic expression used alongside an aggregation.
    #[test]
    fn test_aggregations_fold_keyed_expr() {
        test_snapshots!(
            r#"
 .input ints `source_stream(ints)`
 .output result `for_each(|v| result.send(v).unwrap())`

 result(a % 2, sum(b)) :- ints(a, b)
 "#
        );
    }

    // The same variable used twice in the target.
    #[test]
    fn test_non_copy_but_clone() {
        test_snapshots!(
            r#"
 .input strings `source_stream(strings)`
 .output result `for_each(|v| result.send(v).unwrap())`

 result(a, a) :- strings(a)
 "#
        );
    }

    // Integer literals and arithmetic in target fields.
    #[test]
    fn test_expr_lhs() {
        test_snapshots!(
            r#"
 .input ints `source_stream(ints)`
 .output result `for_each(|v| result.send(v).unwrap())`

 result(123) :- ints(a)
 result(a + 123) :- ints(a)
 result(a + a) :- ints(a)
 result(123 - a) :- ints(a)
 result(123 % (a + 5)) :- ints(a)
 result(a * 5) :- ints(a)
 "#
        );
    }

    // Arithmetic on either side of a predicate.
    #[test]
    fn test_expr_predicate() {
        test_snapshots!(
            r#"
 .input ints `source_stream(ints)`
 .output result `for_each(|v| result.send(v).unwrap())`

 result(1) :- ints(a), (a == 0)
 result(2) :- ints(a), (a != 0)
 result(3) :- ints(a), (a - 1 == 0)
 result(4) :- ints(a), (a - 1 == 1 - 1)
 "#
        );
    }

    // Persisted inputs and intermediates mixed with non-persisted relations.
    #[test]
    fn test_persist() {
        test_snapshots!(
            r#"
 .input ints1 `source_stream(ints1)`
 .persist ints1

 .input ints2 `source_stream(ints2)`
 .persist ints2

 .input ints3 `source_stream(ints3)`

 .output result `for_each(|v| result.send(v).unwrap())`
 .output result2 `for_each(|v| result2.send(v).unwrap())`
 .output result3 `for_each(|v| result3.send(v).unwrap())`
 .output result4 `for_each(|v| result4.send(v).unwrap())`

 result(a, b, c) :- ints1(a), ints2(b), ints3(c)
 result2(a) :- ints1(a), !ints2(a)

 intermediate(a) :- ints1(a)
 result3(a) :- intermediate(a)

 .persist intermediate_persist
 intermediate_persist(a) :- ints1(a)
 result4(a) :- intermediate_persist(a)
 "#
        );
    }

    // Aggregating over a persisted relation that is filled by a rule.
    #[test]
    fn test_persist_uniqueness() {
        test_snapshots!(
            r#"
 .persist ints1

 .input ints2 `source_stream(ints2)`

 ints1(a) :- ints2(a)

 .output result `for_each(|v| result.send(v).unwrap())`

 result(count(a)) :- ints1(a)
 "#
        );
    }

    // `count(*)` versus `count(a)` over a join that contains a wildcard column.
    #[test]
    fn test_wildcard_join_count() {
        test_snapshots!(
            r#"
 .input ints1 `source_stream(ints1)`
 .input ints2 `source_stream(ints2)`

 .output result `for_each(|v| result.send(v).unwrap())`
 .output result2 `for_each(|v| result2.send(v).unwrap())`

 result(count(*)) :- ints1(a, _), ints2(a)
 result2(count(a)) :- ints1(a, _), ints2(a)
 "#
        );
    }

    // `index()` with and without aggregations, over persisted and non-persisted sources.
    #[test]
    fn test_index() {
        test_snapshots!(
            r#"
 .input ints `source_stream(ints)`

 .output result `for_each(|v| result.send(v).unwrap())`
 .output result2 `for_each(|v| result2.send(v).unwrap())`
 .output result3 `for_each(|v| result3.send(v).unwrap())`
 .output result4 `for_each(|v| result4.send(v).unwrap())`

 .persist result5
 .output result5 `for_each(|v| result5.send(v).unwrap())`

 result(a, b, index()) :- ints(a, b)
 result2(a, count(b), index()) :- ints(a, b)

 .persist ints_persisted
 ints_persisted(a, b) :- ints(a, b)

 result3(a, b, index()) :- ints_persisted(a, b)
 result4(a, count(b), index()) :- ints_persisted(a, b)
 result5(a, b, index()) :- ints_persisted(a, b)
 "#
        );
    }

    // `collect_vec` over two keys.
    #[test]
    fn test_collect_vec() {
        test_snapshots!(
            r#"
 .input ints1 `source_stream(ints1)`
 .input ints2 `source_stream(ints2)`

 .output result `for_each(|v| result.send(v).unwrap())`

 result(collect_vec(a, b)) :- ints1(a), ints2(b)
 "#
        );
    }

    // A `*b` field in a body atom.
    #[test]
    fn test_flatten() {
        test_snapshots!(
            r#"
 .input ints1 `source_stream(ints1)`

 .output result `for_each(|v| result.send(v).unwrap())`

 result(a, b) :- ints1(a, *b)
 "#
        );
    }

    // A tuple pattern `(a, b)` as a body field.
    #[test]
    fn test_detuple() {
        test_snapshots!(
            r#"
 .input ints1 `source_stream(ints1)`

 .output result `for_each(|v| result.send(v).unwrap())`

 result(a, b) :- ints1((a, b))
 "#
        );
    }

    // Two tuple patterns in one body atom.
    #[test]
    fn test_multi_detuple() {
        test_snapshots!(
            r#"
 .input ints1 `source_stream(ints1)`

 .output result `for_each(|v| result.send(v).unwrap())`

 result(a, b, c, d) :- ints1((a, b), (c, d))
 "#
        );
    }

    // `*` applied to a tuple pattern.
    #[test]
    fn test_flat_then_detuple() {
        test_snapshots!(
            r#"
 .input ints1 `source_stream(ints1)`

 .output result `for_each(|v| result.send(v).unwrap())`

 result(a, b) :- ints1(*(a, b))
 "#
        );
    }

    // A tuple pattern whose elements are `*` fields.
    #[test]
    fn test_detuple_then_flat() {
        test_snapshots!(
            r#"
 .input ints1 `source_stream(ints1)`

 .output result `for_each(|v| result.send(v).unwrap())`

 result(a, b) :- ints1((*a, *b))
 "#
        );
    }
}