1use std::collections::HashMap;
4use std::fmt::{Debug, Display};
5use std::ops::{Bound, RangeBounds};
6use std::sync::OnceLock;
7
8use documented::DocumentedVariants;
9use proc_macro2::{Ident, Literal, Span, TokenStream};
10use quote::quote_spanned;
11use serde::{Deserialize, Serialize};
12use slotmap::Key;
13use syn::punctuated::Punctuated;
14use syn::{Expr, Token, parse_quote_spanned};
15
16use super::{
17 GraphLoopId, GraphNode, GraphNodeId, GraphSubgraphId, OpInstGenerics, OperatorInstance,
18 PortIndexValue,
19};
20use crate::diagnostic::{Diagnostic, Level};
21use crate::parse::{Operator, PortIndex};
22
/// Kind of delay that may be attached to an operator input.
///
/// [`OperatorConstraints::input_delaytype_fn`] returns one of these (or `None`) for a given input
/// port. The derived ordering follows declaration order, so variants must not be reordered
/// casually.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum DelayType {
    // NOTE(review): the per-variant descriptions are inferred from the variant names only; the
    // scheduling code that interprets them is not part of this section — confirm there.
    /// Presumably: the input is delayed across a stratum boundary.
    Stratum,
    /// Presumably: a monotone accumulation whose delay is optional.
    MonotoneAccum,
    /// Presumably: the input is delayed until the next tick.
    Tick,
    /// Presumably: like [`DelayType::Tick`], but without forcing a new tick to start.
    TickLazy,
}
35
36pub enum PortListSpec {
38 Variadic,
40 Fixed(Punctuated<PortIndex, Token![,]>),
42}
43
44pub struct OperatorConstraints {
46 pub name: &'static str,
48 pub categories: &'static [OperatorCategory],
50
51 pub hard_range_inn: &'static dyn RangeTrait<usize>,
54 pub soft_range_inn: &'static dyn RangeTrait<usize>,
56 pub hard_range_out: &'static dyn RangeTrait<usize>,
58 pub soft_range_out: &'static dyn RangeTrait<usize>,
60 pub num_args: usize,
62 pub persistence_args: &'static dyn RangeTrait<usize>,
64 pub type_args: &'static dyn RangeTrait<usize>,
68 pub is_external_input: bool,
71 pub has_singleton_output: bool,
75 pub flo_type: Option<FloType>,
77
78 pub ports_inn: Option<fn() -> PortListSpec>,
80 pub ports_out: Option<fn() -> PortListSpec>,
82
83 pub input_delaytype_fn: fn(&PortIndexValue) -> Option<DelayType>,
85 pub write_fn: WriteFn,
87}
88
89pub type WriteFn =
91 fn(&WriteContextArgs<'_>, &mut Vec<Diagnostic>) -> Result<OperatorWriteOutput, ()>;
92
93impl Debug for OperatorConstraints {
94 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95 f.debug_struct("OperatorConstraints")
96 .field("name", &self.name)
97 .field("hard_range_inn", &self.hard_range_inn)
98 .field("soft_range_inn", &self.soft_range_inn)
99 .field("hard_range_out", &self.hard_range_out)
100 .field("soft_range_out", &self.soft_range_out)
101 .field("num_args", &self.num_args)
102 .field("persistence_args", &self.persistence_args)
103 .field("type_args", &self.type_args)
104 .field("is_external_input", &self.is_external_input)
105 .field("ports_inn", &self.ports_inn)
106 .field("ports_out", &self.ports_out)
107 .finish()
111 }
112}
113
114#[derive(Default)]
116#[non_exhaustive]
117pub struct OperatorWriteOutput {
118 pub write_prologue: TokenStream,
122 pub write_prologue_after: TokenStream,
125 pub write_iterator: TokenStream,
132 pub write_iterator_after: TokenStream,
134}
135
136pub const RANGE_ANY: &'static dyn RangeTrait<usize> = &(0..);
138pub const RANGE_0: &'static dyn RangeTrait<usize> = &(0..=0);
140pub const RANGE_1: &'static dyn RangeTrait<usize> = &(1..=1);
142
143pub fn identity_write_iterator_fn(
146 &WriteContextArgs {
147 root,
148 op_span,
149 ident,
150 inputs,
151 outputs,
152 is_pull,
153 op_inst:
154 OperatorInstance {
155 generics: OpInstGenerics { type_args, .. },
156 ..
157 },
158 ..
159 }: &WriteContextArgs,
160) -> TokenStream {
161 let generic_type = type_args
162 .first()
163 .map(quote::ToTokens::to_token_stream)
164 .unwrap_or(quote_spanned!(op_span=> _));
165
166 if is_pull {
167 let input = &inputs[0];
168 quote_spanned! {op_span=>
169 let #ident = {
170 fn check_input<St, Item>(stream: St) -> impl #root::futures::stream::Stream<Item = Item>
171 where
172 St: #root::futures::stream::Stream<Item = Item>,
173 {
174 stream
175 }
176 check_input::<_, #generic_type>(#input)
177 };
178 }
179 } else {
180 let output = &outputs[0];
181 quote_spanned! {op_span=>
182 let #ident = {
183 fn check_output<Si, Item>(sink: Si) -> impl #root::futures::sink::Sink<Item, Error = #root::Never>
184 where
185 Si: #root::futures::sink::Sink<Item, Error = #root::Never>,
186 {
187 sink
188 }
189 check_output::<_, #generic_type>(#output)
190 };
191 }
192 }
193}
194
195pub const IDENTITY_WRITE_FN: WriteFn = |write_context_args, _| {
197 let write_iterator = identity_write_iterator_fn(write_context_args);
198 Ok(OperatorWriteOutput {
199 write_iterator,
200 ..Default::default()
201 })
202};
203
204pub fn null_write_iterator_fn(
207 &WriteContextArgs {
208 root,
209 op_span,
210 ident,
211 inputs,
212 outputs,
213 is_pull,
214 op_inst:
215 OperatorInstance {
216 generics: OpInstGenerics { type_args, .. },
217 ..
218 },
219 ..
220 }: &WriteContextArgs,
221) -> TokenStream {
222 let default_type = parse_quote_spanned! {op_span=> _};
223 let iter_type = type_args.first().unwrap_or(&default_type);
224
225 if is_pull {
226 quote_spanned! {op_span=>
227 let #ident = #root::futures::stream::poll_fn(move |_cx| {
228 #(
230 let #inputs = #root::futures::stream::Stream::poll_next(::std::pin::pin!(#inputs), _cx);
231 )*
232 #(
233 let _ = ::std::task::ready!(#inputs);
234 )*
235 ::std::task::Poll::Ready(::std::option::Option::None)
236 });
237 }
238 } else {
239 quote_spanned! {op_span=>
240 #[allow(clippy::let_unit_value)]
241 let _ = (#(#outputs),*);
242 let #ident = #root::sinktools::for_each::ForEach::new::<#iter_type>(::std::mem::drop::<#iter_type>);
243 }
244 }
245}
246
247pub const NULL_WRITE_FN: WriteFn = |write_context_args, _| {
250 let write_iterator = null_write_iterator_fn(write_context_args);
251 Ok(OperatorWriteOutput {
252 write_iterator,
253 ..Default::default()
254 })
255};
256
/// Declares the operator submodules and the [`OPERATORS`] table in one go.
///
/// Takes a list of `module::CONSTANT,` entries (each with a trailing comma). For every entry it
/// declares `pub(crate) mod module;` and adds `module::CONSTANT` to `OPERATORS`, preserving the
/// order of the list.
macro_rules! declare_ops {
    ( $( $module:ident :: $constraints:ident, )* ) => {
        $( pub(crate) mod $module; )*

        /// Constraints of every declared operator, in declaration order.
        pub const OPERATORS: &[OperatorConstraints] = &[
            $( $module :: $constraints, )*
        ];
    };
}
266declare_ops![
267 all_iterations::ALL_ITERATIONS,
268 all_once::ALL_ONCE,
269 anti_join::ANTI_JOIN,
270 assert::ASSERT,
271 assert_eq::ASSERT_EQ,
272 batch::BATCH,
273 chain::CHAIN,
274 chain_first_n::CHAIN_FIRST_N,
275 _counter::_COUNTER,
276 cross_join::CROSS_JOIN,
277 cross_join_multiset::CROSS_JOIN_MULTISET,
278 cross_singleton::CROSS_SINGLETON,
279 demux_enum::DEMUX_ENUM,
280 dest_file::DEST_FILE,
281 dest_sink::DEST_SINK,
282 dest_sink_serde::DEST_SINK_SERDE,
283 difference::DIFFERENCE,
284 enumerate::ENUMERATE,
285 filter::FILTER,
286 filter_map::FILTER_MAP,
287 flat_map::FLAT_MAP,
288 flatten::FLATTEN,
289 fold::FOLD,
290 fold_no_replay::FOLD_NO_REPLAY,
291 for_each::FOR_EACH,
292 identity::IDENTITY,
293 initialize::INITIALIZE,
294 inspect::INSPECT,
295 join::JOIN,
296 join_fused::JOIN_FUSED,
297 join_fused_lhs::JOIN_FUSED_LHS,
298 join_fused_rhs::JOIN_FUSED_RHS,
299 join_multiset::JOIN_MULTISET,
300 fold_keyed::FOLD_KEYED,
301 reduce_keyed::REDUCE_KEYED,
302 repeat_n::REPEAT_N,
303 lattice_bimorphism::LATTICE_BIMORPHISM,
305 _lattice_fold_batch::_LATTICE_FOLD_BATCH,
306 lattice_fold::LATTICE_FOLD,
307 _lattice_join_fused_join::_LATTICE_JOIN_FUSED_JOIN,
308 lattice_reduce::LATTICE_REDUCE,
309 map::MAP,
310 union::UNION,
311 multiset_delta::MULTISET_DELTA,
312 next_iteration::NEXT_ITERATION,
313 next_stratum::NEXT_STRATUM,
314 defer_signal::DEFER_SIGNAL,
315 defer_tick::DEFER_TICK,
316 defer_tick_lazy::DEFER_TICK_LAZY,
317 null::NULL,
318 partition::PARTITION,
319 persist::PERSIST,
320 persist_mut::PERSIST_MUT,
321 persist_mut_keyed::PERSIST_MUT_KEYED,
322 prefix::PREFIX,
323 resolve_futures::RESOLVE_FUTURES,
324 resolve_futures_blocking::RESOLVE_FUTURES_BLOCKING,
325 resolve_futures_blocking_ordered::RESOLVE_FUTURES_BLOCKING_ORDERED,
326 resolve_futures_ordered::RESOLVE_FUTURES_ORDERED,
327 reduce::REDUCE,
328 reduce_no_replay::REDUCE_NO_REPLAY,
329 scan::SCAN,
330 spin::SPIN,
331 sort::SORT,
332 sort_by_key::SORT_BY_KEY,
333 source_file::SOURCE_FILE,
334 source_interval::SOURCE_INTERVAL,
335 source_iter::SOURCE_ITER,
336 source_json::SOURCE_JSON,
337 source_stdin::SOURCE_STDIN,
338 source_stream::SOURCE_STREAM,
339 source_stream_serde::SOURCE_STREAM_SERDE,
340 state::STATE,
341 state_by::STATE_BY,
342 tee::TEE,
343 unique::UNIQUE,
344 unzip::UNZIP,
345 zip::ZIP,
346 zip_longest::ZIP_LONGEST,
347];
348
349pub fn operator_lookup() -> &'static HashMap<&'static str, &'static OperatorConstraints> {
351 pub static OPERATOR_LOOKUP: OnceLock<HashMap<&'static str, &'static OperatorConstraints>> =
352 OnceLock::new();
353 OPERATOR_LOOKUP.get_or_init(|| OPERATORS.iter().map(|op| (op.name, op)).collect())
354}
355pub fn find_node_op_constraints(node: &GraphNode) -> Option<&'static OperatorConstraints> {
357 if let GraphNode::Operator(operator) = node {
358 find_op_op_constraints(operator)
359 } else {
360 None
361 }
362}
363pub fn find_op_op_constraints(operator: &Operator) -> Option<&'static OperatorConstraints> {
365 let name = &*operator.name_string();
366 operator_lookup().get(name).copied()
367}
368
369#[derive(Clone)]
371pub struct WriteContextArgs<'a> {
372 pub root: &'a TokenStream,
374 pub context: &'a Ident,
377 pub df_ident: &'a Ident,
381 pub subgraph_id: GraphSubgraphId,
383 pub node_id: GraphNodeId,
385 pub loop_id: Option<GraphLoopId>,
387 pub op_span: Span,
389 pub op_tag: Option<String>,
391 pub work_fn: &'a Ident,
393 pub work_fn_async: &'a Ident,
395
396 pub ident: &'a Ident,
398 pub is_pull: bool,
400 pub inputs: &'a [Ident],
402 pub outputs: &'a [Ident],
404 pub singleton_output_ident: &'a Ident,
406
407 pub op_name: &'static str,
409 pub op_inst: &'a OperatorInstance,
411 pub arguments: &'a Punctuated<Expr, Token![,]>,
417 pub arguments_handles: &'a Punctuated<Expr, Token![,]>,
419}
420impl WriteContextArgs<'_> {
421 pub fn make_ident(&self, suffix: impl AsRef<str>) -> Ident {
427 Ident::new(
428 &format!(
429 "sg_{:?}_node_{:?}_{}",
430 self.subgraph_id.data(),
431 self.node_id.data(),
432 suffix.as_ref(),
433 ),
434 self.op_span,
435 )
436 }
437
438 pub fn persistence_as_state_lifespan(&self, persistence: Persistence) -> Option<TokenStream> {
441 let root = self.root;
442 let variant =
443 persistence.as_state_lifespan_variant(self.subgraph_id, self.loop_id, self.op_span)?;
444 Some(quote_spanned! {self.op_span=>
445 #root::scheduled::graph::StateLifespan::#variant
446 })
447 }
448
449 pub fn persistence_args_disallow_mutable<const N: usize>(
451 &self,
452 diagnostics: &mut Vec<Diagnostic>,
453 ) -> [Persistence; N] {
454 let len = self.op_inst.generics.persistence_args.len();
455 if 0 != len && 1 != len && N != len {
456 diagnostics.push(Diagnostic::spanned(
457 self.op_span,
458 Level::Error,
459 format!(
460 "The operator `{}` only accepts 0, 1, or {} persistence arguments",
461 self.op_name, N
462 ),
463 ));
464 }
465
466 let default_persistence = if self.loop_id.is_some() {
467 Persistence::None
468 } else {
469 Persistence::Tick
470 };
471 let mut out = [default_persistence; N];
472 self.op_inst
473 .generics
474 .persistence_args
475 .iter()
476 .copied()
477 .cycle() .take(N)
479 .enumerate()
480 .filter(|&(_i, p)| {
481 if p == Persistence::Mutable {
482 diagnostics.push(Diagnostic::spanned(
483 self.op_span,
484 Level::Error,
485 format!(
486 "An implementation of `'{}` does not exist",
487 p.to_str_lowercase()
488 ),
489 ));
490 false
491 } else {
492 true
493 }
494 })
495 .for_each(|(i, p)| {
496 out[i] = p;
497 });
498 out
499 }
500}
501
/// Object-safe view of a range: its two bounds, a containment check and an English description.
///
/// Requires `Send + Sync + Debug` so that `&'static dyn RangeTrait<_>` values can be stored in
/// shared statics and printed.
pub trait RangeTrait<T: ?Sized>: Send + Sync + Debug {
    /// Lower bound of the range.
    fn start_bound(&self) -> Bound<&T>;
    /// Upper bound of the range.
    fn end_bound(&self) -> Bound<&T>;
    /// Whether `item` lies within the range.
    fn contains(&self, item: &T) -> bool
    where
        T: PartialOrd<T>;

    /// Describes the range in words, e.g. `"exactly 1"`, `"at least 2"`, or
    /// `"at least 1 and at most 3"`. A fully unbounded range reads `"any number of"`.
    fn human_string(&self) -> String
    where
        T: Display + PartialEq,
    {
        let bounds = (self.start_bound(), self.end_bound());
        match bounds {
            // No lower bound.
            (Bound::Unbounded, Bound::Unbounded) => String::from("any number of"),
            (Bound::Unbounded, Bound::Included(hi)) => format!("at most {}", hi),
            (Bound::Unbounded, Bound::Excluded(hi)) => format!("less than {}", hi),

            // Inclusive lower bound; the degenerate `lo..=lo` case must be tested first.
            (Bound::Included(lo), Bound::Included(hi)) if lo == hi => format!("exactly {}", lo),
            (Bound::Included(lo), Bound::Included(hi)) => {
                format!("at least {} and at most {}", lo, hi)
            }
            (Bound::Included(lo), Bound::Excluded(hi)) => {
                format!("at least {} and less than {}", lo, hi)
            }
            (Bound::Included(lo), Bound::Unbounded) => format!("at least {}", lo),

            // Exclusive lower bound.
            (Bound::Excluded(lo), Bound::Included(hi)) => {
                format!("more than {} and at most {}", lo, hi)
            }
            (Bound::Excluded(lo), Bound::Excluded(hi)) => {
                format!("more than {} and less than {}", lo, hi)
            }
            (Bound::Excluded(lo), Bound::Unbounded) => format!("more than {}", lo),
        }
    }
}
546
547impl<R, T> RangeTrait<T> for R
548where
549 R: RangeBounds<T> + Send + Sync + Debug,
550{
551 fn start_bound(&self) -> Bound<&T> {
552 self.start_bound()
553 }
554
555 fn end_bound(&self) -> Bound<&T> {
556 self.end_bound()
557 }
558
559 fn contains(&self, item: &T) -> bool
560 where
561 T: PartialOrd<T>,
562 {
563 self.contains(item)
564 }
565}
566
567#[derive(Clone, Copy, PartialOrd, Ord, PartialEq, Eq, Debug, Serialize, Deserialize)]
569pub enum Persistence {
570 None,
572 Loop,
574 Tick,
576 Static,
578 Mutable,
580}
581impl Persistence {
582 pub fn as_state_lifespan_variant(
584 self,
585 subgraph_id: GraphSubgraphId,
586 loop_id: Option<GraphLoopId>,
587 span: Span,
588 ) -> Option<TokenStream> {
589 match self {
590 Persistence::None => {
591 let sg_ident = subgraph_id.as_ident(span);
592 Some(quote_spanned!(span=> Subgraph(#sg_ident)))
593 }
594 Persistence::Loop => {
595 let loop_ident = loop_id
596 .expect("`Persistence::Loop` outside of a loop context.")
597 .as_ident(span);
598 Some(quote_spanned!(span=> Loop(#loop_ident)))
599 }
600 Persistence::Tick => Some(quote_spanned!(span=> Tick)),
601 Persistence::Static => None,
602 Persistence::Mutable => None,
603 }
604 }
605
606 pub fn to_str_lowercase(self) -> &'static str {
608 match self {
609 Persistence::None => "none",
610 Persistence::Tick => "tick",
611 Persistence::Loop => "loop",
612 Persistence::Static => "static",
613 Persistence::Mutable => "mutable",
614 }
615 }
616}
617
618fn make_missing_runtime_msg(op_name: &str) -> Literal {
620 Literal::string(&format!(
621 "`{}()` must be used within a Tokio runtime. For example, use `#[dfir_rs::main]` on your main method.",
622 op_name
623 ))
624}
625
626#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, DocumentedVariants)]
628pub enum OperatorCategory {
629 Map,
631 Filter,
633 Flatten,
635 Fold,
637 KeyedFold,
639 LatticeFold,
641 Persistence,
643 MultiIn,
645 MultiOut,
647 Source,
649 Sink,
651 Control,
653 CompilerFusionOperator,
655 Windowing,
657 Unwindowing,
659}
660impl OperatorCategory {
661 pub fn name(self) -> &'static str {
663 self.get_variant_docs().split_once(":").unwrap().0
664 }
665 pub fn description(self) -> &'static str {
667 self.get_variant_docs().split_once(":").unwrap().1
668 }
669}
670
/// Optional classification of an operator, stored in [`OperatorConstraints::flo_type`].
///
/// The derived ordering follows declaration order.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum FloType {
    // NOTE(review): the per-variant descriptions are inferred from the variant names only; the
    // code that interprets them is not part of this section — confirm there.
    /// Presumably: the operator is a source.
    Source,
    /// Presumably: the operator windows data on its way into a loop.
    Windowing,
    /// Presumably: the operator un-windows data on its way out of a loop.
    Unwindowing,
    /// Presumably: the operator passes data to the next loop iteration.
    NextIteration,
}