1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
use std::collections::BTreeMap;
use std::marker::PhantomData;

use dfir_lang::graph::{partition_graph, DfirGraph};
use dfir_rs::scheduled::graph::Dfir;
use proc_macro2::TokenStream;
use quote::quote;
use stageleft::runtime_support::FreeVariableWithContext;
use stageleft::QuotedWithContext;

use crate::staging_util::Invariant;

pub struct CompiledFlow<'a, ID> {
    pub(super) hydroflow_ir: BTreeMap<usize, DfirGraph>,
    pub(super) extra_stmts: BTreeMap<usize, Vec<syn::Stmt>>,
    pub(super) _phantom: Invariant<'a, ID>,
}

impl<ID> CompiledFlow<'_, ID> {
    pub fn hydroflow_ir(&self) -> &BTreeMap<usize, DfirGraph> {
        &self.hydroflow_ir
    }

    pub fn take_ir(self) -> BTreeMap<usize, DfirGraph> {
        self.hydroflow_ir
    }
}

impl<'a> CompiledFlow<'a, usize> {
    pub fn with_dynamic_id(
        self,
        id: impl QuotedWithContext<'a, usize, ()>,
    ) -> CompiledFlowWithId<'a> {
        let hydroflow_crate = proc_macro_crate::crate_name("hydro_lang")
            .expect("hydro_lang should be present in `Cargo.toml`");
        let root = match hydroflow_crate {
            proc_macro_crate::FoundCrate::Itself => quote! { hydro_lang::dfir_rs },
            proc_macro_crate::FoundCrate::Name(name) => {
                let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
                quote! { #ident::dfir_rs }
            }
        };

        let mut conditioned_tokens = None;
        for (subgraph_id, flat_graph) in self.hydroflow_ir {
            let partitioned_graph =
                partition_graph(flat_graph).expect("Failed to partition (cycle detected).");

            let mut diagnostics = Vec::new();
            let tokens = partitioned_graph.as_code(&root, true, quote::quote!(), &mut diagnostics);
            let my_extra_stmts = self
                .extra_stmts
                .get(&subgraph_id)
                .cloned()
                .unwrap_or_default();

            if let Some(conditioned_tokens) = conditioned_tokens.as_mut() {
                *conditioned_tokens = syn::parse_quote! {
                    #conditioned_tokens else if __given_id == #subgraph_id {
                        #(#my_extra_stmts)*
                        #tokens
                    }
                };
            } else {
                conditioned_tokens = Some(syn::parse_quote! {
                    if __given_id == #subgraph_id {
                        #(#my_extra_stmts)*
                        #tokens
                    }
                });
            }
        }

        let conditioned_tokens: TokenStream = conditioned_tokens.unwrap();
        let id = id.splice_untyped();
        CompiledFlowWithId {
            tokens: syn::parse_quote!({
                let __given_id = #id;
                #conditioned_tokens else {
                    panic!("Invalid node id: {}", __given_id);
                }
            }),
            _phantom: PhantomData,
        }
    }
}

impl<'a, Ctx> QuotedWithContext<'a, Dfir<'a>, Ctx> for CompiledFlow<'a, ()> {}

impl<'a, Ctx> FreeVariableWithContext<Ctx> for CompiledFlow<'a, ()> {
    type O = Dfir<'a>;

    fn to_tokens(mut self, _ctx: &Ctx) -> (Option<TokenStream>, Option<TokenStream>) {
        let hydroflow_crate = proc_macro_crate::crate_name("hydro_lang")
            .expect("hydro_lang should be present in `Cargo.toml`");
        let root = match hydroflow_crate {
            proc_macro_crate::FoundCrate::Itself => quote! { hydro_lang::dfir_rs },
            proc_macro_crate::FoundCrate::Name(name) => {
                let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
                quote! { #ident::dfir_rs }
            }
        };

        if self.hydroflow_ir.len() != 1 {
            panic!("Expected exactly one subgraph in the Hydroflow IR");
        }

        let flat_graph = self.hydroflow_ir.remove(&0).unwrap();
        let partitioned_graph =
            partition_graph(flat_graph).expect("Failed to partition (cycle detected).");

        let mut diagnostics = Vec::new();
        let tokens = partitioned_graph.as_code(&root, true, quote::quote!(), &mut diagnostics);

        (None, Some(tokens))
    }
}

pub struct CompiledFlowWithId<'a> {
    tokens: TokenStream,
    _phantom: Invariant<'a>,
}

impl<'a, Ctx> QuotedWithContext<'a, Dfir<'a>, Ctx> for CompiledFlowWithId<'a> {}

impl<'a, Ctx> FreeVariableWithContext<Ctx> for CompiledFlowWithId<'a> {
    type O = Dfir<'a>;

    fn to_tokens(self, _ctx: &Ctx) -> (Option<TokenStream>, Option<TokenStream>) {
        (None, Some(self.tokens))
    }
}