1use std::collections::HashMap;
4use std::fmt::{Debug, Display};
5use std::ops::{Bound, RangeBounds};
6use std::sync::OnceLock;
7
8use documented::DocumentedVariants;
9use proc_macro2::{Ident, Literal, Span, TokenStream};
10use quote::quote_spanned;
11use serde::{Deserialize, Serialize};
12use slotmap::Key;
13use syn::punctuated::Punctuated;
14use syn::{Expr, Token, parse_quote_spanned};
15
16use super::{
17 GraphLoopId, GraphNode, GraphNodeId, GraphSubgraphId, OpInstGenerics, OperatorInstance,
18 PortIndexValue,
19};
20use crate::diagnostic::{Diagnostic, Level};
21use crate::parse::{Operator, PortIndex};
22
/// The kind of delay an operator can require on one of its inputs.
///
/// Values of this type are produced (wrapped in `Option`) by
/// [`OperatorConstraints::input_delaytype_fn`]. The derived ordering follows declaration order,
/// so reordering the variants changes comparison results.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum DelayType {
    /// Stratum-level delay.
    /// NOTE(review): presumably the input must be collected over the preceding stratum — confirm.
    Stratum,
    /// Monotone accumulation.
    /// NOTE(review): presumably a delay that is optional for correctness — confirm.
    MonotoneAccum,
    /// Tick-level delay.
    /// NOTE(review): presumably the input must be collected over the previous tick — confirm.
    Tick,
    /// Lazy tick-level delay.
    /// NOTE(review): presumably like [`Self::Tick`] but without forcing a new tick — confirm.
    TickLazy,
}
35
36pub enum PortListSpec {
38 Variadic,
40 Fixed(Punctuated<PortIndex, Token![,]>),
42}
43
44pub struct OperatorConstraints {
46 pub name: &'static str,
48 pub categories: &'static [OperatorCategory],
50
51 pub hard_range_inn: &'static dyn RangeTrait<usize>,
54 pub soft_range_inn: &'static dyn RangeTrait<usize>,
56 pub hard_range_out: &'static dyn RangeTrait<usize>,
58 pub soft_range_out: &'static dyn RangeTrait<usize>,
60 pub num_args: usize,
62 pub persistence_args: &'static dyn RangeTrait<usize>,
64 pub type_args: &'static dyn RangeTrait<usize>,
68 pub is_external_input: bool,
71 pub has_singleton_output: bool,
75 pub flo_type: Option<FloType>,
77
78 pub ports_inn: Option<fn() -> PortListSpec>,
80 pub ports_out: Option<fn() -> PortListSpec>,
82
83 pub input_delaytype_fn: fn(&PortIndexValue) -> Option<DelayType>,
85 pub write_fn: WriteFn,
87}
88
89pub type WriteFn =
91 fn(&WriteContextArgs<'_>, &mut Vec<Diagnostic>) -> Result<OperatorWriteOutput, ()>;
92
93impl Debug for OperatorConstraints {
94 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95 f.debug_struct("OperatorConstraints")
96 .field("name", &self.name)
97 .field("hard_range_inn", &self.hard_range_inn)
98 .field("soft_range_inn", &self.soft_range_inn)
99 .field("hard_range_out", &self.hard_range_out)
100 .field("soft_range_out", &self.soft_range_out)
101 .field("num_args", &self.num_args)
102 .field("persistence_args", &self.persistence_args)
103 .field("type_args", &self.type_args)
104 .field("is_external_input", &self.is_external_input)
105 .field("ports_inn", &self.ports_inn)
106 .field("ports_out", &self.ports_out)
107 .finish()
111 }
112}
113
114#[derive(Default)]
116#[non_exhaustive]
117pub struct OperatorWriteOutput {
118 pub write_prologue: TokenStream,
122 pub write_prologue_after: TokenStream,
125 pub write_iterator: TokenStream,
132 pub write_iterator_after: TokenStream,
134}
135
136pub const RANGE_ANY: &'static dyn RangeTrait<usize> = &(0..);
138pub const RANGE_0: &'static dyn RangeTrait<usize> = &(0..=0);
140pub const RANGE_1: &'static dyn RangeTrait<usize> = &(1..=1);
142
143pub fn identity_write_iterator_fn(
146 &WriteContextArgs {
147 root,
148 op_span,
149 ident,
150 inputs,
151 outputs,
152 is_pull,
153 op_inst:
154 OperatorInstance {
155 generics: OpInstGenerics { type_args, .. },
156 ..
157 },
158 ..
159 }: &WriteContextArgs,
160) -> TokenStream {
161 let generic_type = type_args
162 .first()
163 .map(quote::ToTokens::to_token_stream)
164 .unwrap_or(quote_spanned!(op_span=> _));
165
166 if is_pull {
167 let input = &inputs[0];
168 quote_spanned! {op_span=>
169 let #ident = {
170 fn check_input<Iter: ::std::iter::Iterator<Item = Item>, Item>(iter: Iter) -> impl ::std::iter::Iterator<Item = Item> { iter }
171 check_input::<_, #generic_type>(#input)
172 };
173 }
174 } else {
175 let output = &outputs[0];
176 quote_spanned! {op_span=>
177 let #ident = {
178 fn check_output<Push: #root::pusherator::Pusherator<Item = Item>, Item>(push: Push) -> impl #root::pusherator::Pusherator<Item = Item> { push }
179 check_output::<_, #generic_type>(#output)
180 };
181 }
182 }
183}
184
185pub const IDENTITY_WRITE_FN: WriteFn = |write_context_args, _| {
187 let write_iterator = identity_write_iterator_fn(write_context_args);
188 Ok(OperatorWriteOutput {
189 write_iterator,
190 ..Default::default()
191 })
192};
193
194pub fn null_write_iterator_fn(
197 &WriteContextArgs {
198 root,
199 op_span,
200 ident,
201 inputs,
202 outputs,
203 is_pull,
204 op_inst:
205 OperatorInstance {
206 generics: OpInstGenerics { type_args, .. },
207 ..
208 },
209 ..
210 }: &WriteContextArgs,
211) -> TokenStream {
212 let default_type = parse_quote_spanned! {op_span=> _};
213 let iter_type = type_args.first().unwrap_or(&default_type);
214
215 if is_pull {
216 quote_spanned! {op_span=>
217 #(
218 #inputs.for_each(std::mem::drop);
219 )*
220 let #ident = std::iter::empty::<#iter_type>();
221 }
222 } else {
223 quote_spanned! {op_span=>
224 #[allow(clippy::let_unit_value)]
225 let _ = (#(#outputs),*);
226 let #ident = #root::pusherator::null::Null::<#iter_type>::new();
227 }
228 }
229}
230
231pub const NULL_WRITE_FN: WriteFn = |write_context_args, _| {
234 let write_iterator = null_write_iterator_fn(write_context_args);
235 Ok(OperatorWriteOutput {
236 write_iterator,
237 ..Default::default()
238 })
239};
240
// Declares one `pub(crate)` submodule per listed operator and collects each operator's
// `OperatorConstraints` constant into the `OPERATORS` table, in the order listed.
// Every entry has the form `module_name::CONSTANT_NAME,` (trailing comma required).
macro_rules! declare_ops {
    ( $( $module:ident :: $constraints:ident, )* ) => {
        $( pub(crate) mod $module; )*

        /// Constraints of every declared operator, in declaration order.
        pub const OPERATORS: &[OperatorConstraints] = &[
            $( $module :: $constraints, )*
        ];
    };
}
250declare_ops![
251 all_iterations::ALL_ITERATIONS,
252 all_once::ALL_ONCE,
253 anti_join::ANTI_JOIN,
254 anti_join_multiset::ANTI_JOIN_MULTISET,
255 assert::ASSERT,
256 assert_eq::ASSERT_EQ,
257 batch::BATCH,
258 chain::CHAIN,
259 _counter::_COUNTER,
260 cross_join::CROSS_JOIN,
261 cross_join_multiset::CROSS_JOIN_MULTISET,
262 cross_singleton::CROSS_SINGLETON,
263 demux::DEMUX,
264 demux_enum::DEMUX_ENUM,
265 dest_file::DEST_FILE,
266 dest_sink::DEST_SINK,
267 dest_sink_serde::DEST_SINK_SERDE,
268 difference::DIFFERENCE,
269 difference_multiset::DIFFERENCE_MULTISET,
270 enumerate::ENUMERATE,
271 filter::FILTER,
272 filter_map::FILTER_MAP,
273 flat_map::FLAT_MAP,
274 flatten::FLATTEN,
275 fold::FOLD,
276 for_each::FOR_EACH,
277 identity::IDENTITY,
278 initialize::INITIALIZE,
279 inspect::INSPECT,
280 join::JOIN,
281 join_fused::JOIN_FUSED,
282 join_fused_lhs::JOIN_FUSED_LHS,
283 join_fused_rhs::JOIN_FUSED_RHS,
284 join_multiset::JOIN_MULTISET,
285 fold_keyed::FOLD_KEYED,
286 reduce_keyed::REDUCE_KEYED,
287 repeat_n::REPEAT_N,
288 lattice_bimorphism::LATTICE_BIMORPHISM,
290 _lattice_fold_batch::_LATTICE_FOLD_BATCH,
291 lattice_fold::LATTICE_FOLD,
292 _lattice_join_fused_join::_LATTICE_JOIN_FUSED_JOIN,
293 lattice_reduce::LATTICE_REDUCE,
294 map::MAP,
295 union::UNION,
296 multiset_delta::MULTISET_DELTA,
297 next_iteration::NEXT_ITERATION,
298 next_stratum::NEXT_STRATUM,
299 defer_signal::DEFER_SIGNAL,
300 defer_tick::DEFER_TICK,
301 defer_tick_lazy::DEFER_TICK_LAZY,
302 null::NULL,
303 partition::PARTITION,
304 persist::PERSIST,
305 persist_mut::PERSIST_MUT,
306 persist_mut_keyed::PERSIST_MUT_KEYED,
307 prefix::PREFIX,
308 resolve_futures::RESOLVE_FUTURES,
309 resolve_futures_ordered::RESOLVE_FUTURES_ORDERED,
310 py_udf::PY_UDF,
311 reduce::REDUCE,
312 spin::SPIN,
313 sort::SORT,
314 sort_by_key::SORT_BY_KEY,
315 source_file::SOURCE_FILE,
316 source_interval::SOURCE_INTERVAL,
317 source_iter::SOURCE_ITER,
318 source_json::SOURCE_JSON,
319 source_stdin::SOURCE_STDIN,
320 source_stream::SOURCE_STREAM,
321 source_stream_serde::SOURCE_STREAM_SERDE,
322 state::STATE,
323 state_by::STATE_BY,
324 tee::TEE,
325 unique::UNIQUE,
326 unzip::UNZIP,
327 zip::ZIP,
328 zip_longest::ZIP_LONGEST,
329];
330
331pub fn operator_lookup() -> &'static HashMap<&'static str, &'static OperatorConstraints> {
333 pub static OPERATOR_LOOKUP: OnceLock<HashMap<&'static str, &'static OperatorConstraints>> =
334 OnceLock::new();
335 OPERATOR_LOOKUP.get_or_init(|| OPERATORS.iter().map(|op| (op.name, op)).collect())
336}
337pub fn find_node_op_constraints(node: &GraphNode) -> Option<&'static OperatorConstraints> {
339 if let GraphNode::Operator(operator) = node {
340 find_op_op_constraints(operator)
341 } else {
342 None
343 }
344}
345pub fn find_op_op_constraints(operator: &Operator) -> Option<&'static OperatorConstraints> {
347 let name = &*operator.name_string();
348 operator_lookup().get(name).copied()
349}
350
351#[derive(Clone)]
353pub struct WriteContextArgs<'a> {
354 pub root: &'a TokenStream,
356 pub context: &'a Ident,
359 pub df_ident: &'a Ident,
363 pub subgraph_id: GraphSubgraphId,
365 pub node_id: GraphNodeId,
367 pub loop_id: Option<GraphLoopId>,
369 pub op_span: Span,
371 pub op_tag: Option<String>,
373 pub work_fn: &'a Ident,
375
376 pub ident: &'a Ident,
378 pub is_pull: bool,
380 pub inputs: &'a [Ident],
382 pub outputs: &'a [Ident],
384 pub singleton_output_ident: &'a Ident,
386
387 pub op_name: &'static str,
389 pub op_inst: &'a OperatorInstance,
391 pub arguments: &'a Punctuated<Expr, Token![,]>,
397 pub arguments_handles: &'a Punctuated<Expr, Token![,]>,
399}
400impl WriteContextArgs<'_> {
401 pub fn make_ident(&self, suffix: impl AsRef<str>) -> Ident {
407 Ident::new(
408 &format!(
409 "sg_{:?}_node_{:?}_{}",
410 self.subgraph_id.data(),
411 self.node_id.data(),
412 suffix.as_ref(),
413 ),
414 self.op_span,
415 )
416 }
417
418 pub fn persistence_as_state_lifespan(&self, persistence: Persistence) -> Option<TokenStream> {
421 let root = self.root;
422 let variant =
423 persistence.as_state_lifespan_variant(self.subgraph_id, self.loop_id, self.op_span)?;
424 Some(quote_spanned! {self.op_span=>
425 #root::scheduled::graph::StateLifespan::#variant
426 })
427 }
428
429 pub fn persistence_args_disallow_mutable<const N: usize>(
431 &self,
432 diagnostics: &mut Vec<Diagnostic>,
433 ) -> [Persistence; N] {
434 let len = self.op_inst.generics.persistence_args.len();
435 if 0 != len && 1 != len && N != len {
436 diagnostics.push(Diagnostic::spanned(
437 self.op_span,
438 Level::Error,
439 format!(
440 "The operator `{}` only accepts 0, 1, or {} persistence arguments",
441 self.op_name, N
442 ),
443 ));
444 }
445
446 let default_persistence = if self.loop_id.is_some() {
447 Persistence::None
448 } else {
449 Persistence::Tick
450 };
451 let mut out = [default_persistence; N];
452 self.op_inst
453 .generics
454 .persistence_args
455 .iter()
456 .copied()
457 .cycle() .take(N)
459 .enumerate()
460 .filter(|&(_i, p)| {
461 if p == Persistence::Mutable {
462 diagnostics.push(Diagnostic::spanned(
463 self.op_span,
464 Level::Error,
465 format!(
466 "An implementation of `'{}` does not exist",
467 p.to_str_lowercase()
468 ),
469 ));
470 false
471 } else {
472 true
473 }
474 })
475 .for_each(|(i, p)| {
476 out[i] = p;
477 });
478 out
479 }
480}
481
/// An object-safe stand-in for [`std::ops::RangeBounds`], so ranges of different concrete types
/// can be stored behind `&dyn RangeTrait<T>`.
pub trait RangeTrait<T>: Send + Sync + Debug
where
    T: ?Sized,
{
    /// Lower bound of the range.
    fn start_bound(&self) -> Bound<&T>;
    /// Upper bound of the range.
    fn end_bound(&self) -> Bound<&T>;
    /// Whether `item` lies within the range.
    fn contains(&self, item: &T) -> bool
    where
        T: PartialOrd<T>;

    /// Describes the range in words, e.g. `"exactly 1"` or `"at least 2"`, for use in messages
    /// shown to users.
    fn human_string(&self) -> String
    where
        T: Display + PartialEq,
    {
        let lower = self.start_bound();
        let upper = self.end_bound();
        match lower {
            Bound::Unbounded => match upper {
                Bound::Unbounded => "any number of".to_owned(),
                Bound::Included(x) => format!("at most {}", x),
                Bound::Excluded(x) => format!("less than {}", x),
            },
            Bound::Included(n) => match upper {
                // Both ends inclusive and equal: the range holds a single value.
                Bound::Included(x) if n == x => format!("exactly {}", n),
                Bound::Included(x) => format!("at least {} and at most {}", n, x),
                Bound::Excluded(x) => format!("at least {} and less than {}", n, x),
                Bound::Unbounded => format!("at least {}", n),
            },
            Bound::Excluded(n) => match upper {
                Bound::Included(x) => format!("more than {} and at most {}", n, x),
                Bound::Excluded(x) => format!("more than {} and less than {}", n, x),
                Bound::Unbounded => format!("more than {}", n),
            },
        }
    }
}
526
527impl<R, T> RangeTrait<T> for R
528where
529 R: RangeBounds<T> + Send + Sync + Debug,
530{
531 fn start_bound(&self) -> Bound<&T> {
532 self.start_bound()
533 }
534
535 fn end_bound(&self) -> Bound<&T> {
536 self.end_bound()
537 }
538
539 fn contains(&self, item: &T) -> bool
540 where
541 T: PartialOrd<T>,
542 {
543 self.contains(item)
544 }
545}
546
547#[derive(Clone, Copy, PartialOrd, Ord, PartialEq, Eq, Debug, Serialize, Deserialize)]
549pub enum Persistence {
550 None,
552 Loop,
554 Tick,
556 Static,
558 Mutable,
560}
561impl Persistence {
562 pub fn as_state_lifespan_variant(
564 self,
565 subgraph_id: GraphSubgraphId,
566 loop_id: Option<GraphLoopId>,
567 span: Span,
568 ) -> Option<TokenStream> {
569 match self {
570 Persistence::None => {
571 let sg_ident = subgraph_id.as_ident(span);
572 Some(quote_spanned!(span=> Subgraph(#sg_ident)))
573 }
574 Persistence::Loop => {
575 let loop_ident = loop_id
576 .expect("`Persistence::Loop` outside of a loop context.")
577 .as_ident(span);
578 Some(quote_spanned!(span=> Loop(#loop_ident)))
579 }
580 Persistence::Tick => Some(quote_spanned!(span=> Tick)),
581 Persistence::Static => None,
582 Persistence::Mutable => None,
583 }
584 }
585
586 pub fn to_str_lowercase(self) -> &'static str {
588 match self {
589 Persistence::None => "none",
590 Persistence::Tick => "tick",
591 Persistence::Loop => "loop",
592 Persistence::Static => "static",
593 Persistence::Mutable => "mutable",
594 }
595 }
596}
597
598fn make_missing_runtime_msg(op_name: &str) -> Literal {
600 Literal::string(&format!(
601 "`{}()` must be used within a Tokio runtime. For example, use `#[dfir_rs::main]` on your main method.",
602 op_name
603 ))
604}
605
606#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, DocumentedVariants)]
608pub enum OperatorCategory {
609 Map,
611 Filter,
613 Flatten,
615 Fold,
617 KeyedFold,
619 LatticeFold,
621 Persistence,
623 MultiIn,
625 MultiOut,
627 Source,
629 Sink,
631 Control,
633 CompilerFusionOperator,
635 Windowing,
637 Unwindowing,
639}
640impl OperatorCategory {
641 pub fn name(self) -> &'static str {
643 self.get_variant_docs().split_once(":").unwrap().0
644 }
645 pub fn description(self) -> &'static str {
647 self.get_variant_docs().split_once(":").unwrap().1
648 }
649}
650
/// Classification of an operator, stored in [`OperatorConstraints::flo_type`].
///
/// The derived ordering follows declaration order.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum FloType {
    /// A source operator.
    Source,
    /// A windowing operator.
    /// NOTE(review): presumably moves data into a loop context — confirm.
    Windowing,
    /// An un-windowing operator.
    /// NOTE(review): presumably moves data out of a loop context — confirm.
    Unwindowing,
    /// A next-iteration operator.
    /// NOTE(review): presumably passes data to the next loop iteration — confirm.
    NextIteration,
}