1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
use proc_macro2::{Ident, TokenStream};
use quote::{quote_spanned, ToTokens};
use syn::spanned::Spanned;
use syn::{parse_quote, Expr, ExprCall};

use super::{
    DelayType, OpInstGenerics, OperatorCategory, OperatorConstraints,
    OperatorInstance, OperatorWriteOutput, Persistence, WriteContextArgs, RANGE_0, RANGE_1,
};
use crate::diagnostic::{Diagnostic, Level};

/// > 2 input streams of type <(K, V1)> and <(K, V2)>, 1 output stream of type <(K, (V1, V2))>
///
/// `join_fused` takes two arguments, they are the configuration options for the left hand side and right hand side inputs respectively.
/// There are three available configuration options, they are `Reduce`: if the input type is the same as the accumulator type,
/// `Fold`: if the input type is different from the accumulator type, and the accumulator type has a sensible default value, and
/// `FoldFrom`: if the input type is different from the accumulator type, and the accumulator needs to be derived from the first input value.
/// Examples of all three configuration options are below:
/// ```dfir,ignore
/// // Left hand side input will use fold, right hand side input will use reduce,
/// join_fused(Fold(|| "default value", |x, y| *x += y), Reduce(|x, y| *x -= y))
///
/// // Left hand side input will use FoldFrom, and the right hand side input will use Reduce again
/// join_fused(FoldFrom(|x| "conversion function", |x, y| *x += y), Reduce(|x, y| *x *= y))
/// ```
/// The three currently supported fused operator types are `Fold(Fn() -> A, Fn(A, T) -> A)`, `Reduce(Fn(A, A) -> A)`, and `FoldFrom(Fn(T) -> A, Fn(A, T) -> A)`
///
/// `join_fused` first performs a fold_keyed/reduce_keyed operation on each input stream before performing joining. See `join()`. There is currently no equivalent for `FoldFrom` in dfir operators.
///
/// For example, the following two dfir programs are equivalent, the former would optimize into the latter:
///
/// ```dfir
/// source_iter(vec![("key", 0), ("key", 1), ("key", 2)])
///     -> reduce_keyed(|x: &mut _, y| *x += y)
///     -> [0]my_join;
/// source_iter(vec![("key", 2), ("key", 3)])
///     -> fold_keyed(|| 1, |x: &mut _, y| *x *= y)
///     -> [1]my_join;
/// my_join = join_multiset()
///     -> assert_eq([("key", (3, 6))]);
/// ```
///
/// ```dfir
/// source_iter(vec![("key", 0), ("key", 1), ("key", 2)])
///     -> [0]my_join;
/// source_iter(vec![("key", 2), ("key", 3)])
///     -> [1]my_join;
/// my_join = join_fused(Reduce(|x, y| *x += y), Fold(|| 1, |x, y| *x *= y))
///     -> assert_eq([("key", (3, 6))]);
/// ```
///
/// Here is an example of using FoldFrom to derive the accumulator from the first value:
///
/// ```dfir
/// source_iter(vec![("key", 0), ("key", 1), ("key", 2)])
///     -> [0]my_join;
/// source_iter(vec![("key", 2), ("key", 3)])
///     -> [1]my_join;
/// my_join = join_fused(FoldFrom(|x| x + 3, |x, y| *x += y), Fold(|| 1, |x, y| *x *= y))
///     -> assert_eq([("key", (6, 6))]);
/// ```
///
/// The benefit of this is that the state between the reducing/folding operator and the join is merged together.
///
/// `join_fused` follows the same persistence rules as `join` and all other operators. By default, both the left hand side and right hand side are `'tick` persistence. They can be set to `'static` persistence
/// by specifying `'static` in the type arguments of the operator.
///
/// For `join_fused::<'static>`, the operator will replay all _keys_ that the join has ever seen each tick, and not only the new matches from that specific tick.
/// This means that it behaves identically to if `persist::<'static>()` were placed before the inputs and the persistence of
/// the operator were left at the default `'tick`. For example, the two following examples have identical behavior:
///
/// ```dfir
/// source_iter(vec![("key", 0), ("key", 1), ("key", 2)]) -> persist::<'static>() -> [0]my_join;
/// source_iter(vec![("key", 2)]) -> my_union;
/// source_iter(vec![("key", 3)]) -> defer_tick() -> my_union;
/// my_union = union() -> persist::<'static>() -> [1]my_join;
///
/// my_join = join_fused(Reduce(|x, y| *x += y), Fold(|| 1, |x, y| *x *= y))
///     -> assert_eq([("key", (3, 2)), ("key", (3, 6))]);
/// ```
///
/// ```dfir
/// source_iter(vec![("key", 0), ("key", 1), ("key", 2)]) -> [0]my_join;
/// source_iter(vec![("key", 2)]) -> my_union;
/// source_iter(vec![("key", 3)]) -> defer_tick() -> my_union;
/// my_union = union() -> [1]my_join;
///
/// my_join = join_fused::<'static>(Reduce(|x, y| *x += y), Fold(|| 1, |x, y| *x *= y))
///     -> assert_eq([("key", (3, 2)), ("key", (3, 6))]);
/// ```
pub const JOIN_FUSED: OperatorConstraints = OperatorConstraints {
    name: "join_fused",
    categories: &[OperatorCategory::MultiIn],
    // Exactly two inputs, addressed as ports `0` and `1` (see `ports_inn` below).
    hard_range_inn: &(2..=2),
    soft_range_inn: &(2..=2),
    hard_range_out: RANGE_1,
    soft_range_out: RANGE_1,
    // One `Fold(..)` / `FoldFrom(..)` / `Reduce(..)` configuration per input side.
    num_args: 2,
    // Zero, one (applied to both sides) or two (lhs, rhs) persistence arguments;
    // see `parse_persistences`.
    persistence_args: &(0..=2),
    type_args: RANGE_0,
    is_external_input: false,
    has_singleton_output: false,
    flo_type: None,
    ports_inn: Some(|| super::PortListSpec::Fixed(parse_quote! { 0, 1 })),
    ports_out: None,
    // Every input is stratum-delayed.
    input_delaytype_fn: |_| Some(DelayType::Stratum),
    write_fn: |wc @ &WriteContextArgs {
                   context,
                   op_span,
                   ident,
                   inputs,
                   is_pull,
                   op_inst:
                       OperatorInstance {
                           generics:
                               OpInstGenerics {
                                   persistence_args, ..
                               },
                           ..
                       },
                   arguments,
                   ..
               },
               diagnostics| {
        assert!(is_pull);

        let persistences = parse_persistences(persistence_args);

        // Parse the per-side configuration; any problem is reported through `diagnostics` and
        // aborts code generation for this operator.
        let lhs_join_options =
            parse_argument(&arguments[0]).map_err(|err| diagnostics.push(err))?;
        let rhs_join_options =
            parse_argument(&arguments[1]).map_err(|err| diagnostics.push(err))?;

        let (lhs_joindata_ident, lhs_borrow_ident, lhs_prologue, lhs_borrow) =
            make_joindata(wc, persistences[0], &lhs_join_options, "lhs")
                .map_err(|err| diagnostics.push(err))?;

        let (rhs_joindata_ident, rhs_borrow_ident, rhs_prologue, rhs_borrow) =
            make_joindata(wc, persistences[1], &rhs_join_options, "rhs")
                .map_err(|err| diagnostics.push(err))?;

        // State for both sides is registered once, in the prologue.
        let write_prologue = quote_spanned! {op_span=>
            #lhs_prologue
            #rhs_prologue
        };

        let lhs = &inputs[0];
        let rhs = &inputs[1];

        // Each side's accumulation code is spanned to its own argument so that compiler errors
        // inside the user-supplied closures point at the matching `join_fused` argument.
        let arg0_span = arguments[0].span();
        let arg1_span = arguments[1].span();

        let lhs_tokens = match lhs_join_options {
            JoinOptions::FoldFrom(lhs_from, lhs_fold) => quote_spanned! {arg0_span=>
                #lhs_borrow.fold_into(#lhs, #lhs_fold, #lhs_from);
            },
            JoinOptions::Fold(lhs_default, lhs_fold) => quote_spanned! {arg0_span=>
                #lhs_borrow.fold_into(#lhs, #lhs_fold, #lhs_default);
            },
            JoinOptions::Reduce(lhs_reduce) => quote_spanned! {arg0_span=>
                #lhs_borrow.reduce_into(#lhs, #lhs_reduce);
            },
        };

        let rhs_tokens = match rhs_join_options {
            // Spanned to the second argument, like the other rhs arms (this arm previously used
            // the first argument's span, misplacing diagnostics for a right-hand `FoldFrom`).
            JoinOptions::FoldFrom(rhs_from, rhs_fold) => quote_spanned! {arg1_span=>
                #rhs_borrow.fold_into(#rhs, #rhs_fold, #rhs_from);
            },
            JoinOptions::Fold(rhs_default, rhs_fold) => quote_spanned! {arg1_span=>
                #rhs_borrow.fold_into(#rhs, #rhs_fold, #rhs_default);
            },
            JoinOptions::Reduce(rhs_reduce) => quote_spanned! {arg1_span=>
                #rhs_borrow.reduce_into(#rhs, #rhs_reduce);
            },
        };

        // Since both input arguments are stratum blocking then we don't need to keep track of ticks to avoid emitting the same thing twice in the same tick.
        let write_iterator = quote_spanned! {op_span=>
            let mut #lhs_borrow_ident = #context.state_ref(#lhs_joindata_ident).borrow_mut();
            let mut #rhs_borrow_ident = #context.state_ref(#rhs_joindata_ident).borrow_mut();

            let #ident = {
                // Drain both inputs into their per-side tables before probing.
                #lhs_tokens
                #rhs_tokens

                // TODO: start the iterator with the smallest len() table rather than always picking rhs.
                // Walk the rhs table and emit a cloned `(k, (v1, v2))` for every key that is
                // also present in the lhs table.
                #[allow(clippy::clone_on_copy)]
                #[allow(suspicious_double_ref_op)]
                #rhs_borrow
                    .table
                    .iter()
                    .filter_map(|(k, v2)| #lhs_borrow.table.get(k).map(|v1| (k.clone(), (v1.clone(), v2.clone()))))
            };
        };

        // When either side is `'static`, reschedule this subgraph after it runs.
        let write_iterator_after =
            if persistences[0] == Persistence::Static || persistences[1] == Persistence::Static {
                quote_spanned! {op_span=>
                    // TODO: Probably only need to schedule if #*_borrow.len() > 0?
                    #context.schedule_subgraph(#context.current_subgraph(), false);
                }
            } else {
                quote_spanned! {op_span=>}
            };

        Ok(OperatorWriteOutput {
            write_prologue,
            write_iterator,
            write_iterator_after,
        })
    },
};

/// Per-input configuration parsed from one `join_fused` argument by [`parse_argument`].
///
/// Every variant borrows its expressions from the arguments of the parsed call expression.
pub(crate) enum JoinOptions<'a> {
    /// `FoldFrom(from, fold)`: the "from" function followed by the folding function.
    FoldFrom(&'a Expr, &'a Expr),
    /// `Fold(default, fold)`: the default function followed by the folding function.
    Fold(&'a Expr, &'a Expr),
    /// `Reduce(reduce)`: the reducing function.
    Reduce(&'a Expr),
}

pub(crate) fn parse_argument(arg: &Expr) -> Result<JoinOptions, Diagnostic> {
    let Expr::Call(ExprCall {
        attrs: _,
        func,
        paren_token: _,
        args,
    }) = arg
    else {
        return Err(Diagnostic::spanned(
            arg.span(),
            Level::Error,
            format!("Argument must be a function call: {arg:?}"),
        ));
    };

    let mut elems = args.iter();
    let func_name = func.to_token_stream().to_string();

    match func_name.as_str() {
        "Fold" => match (elems.next(), elems.next()) {
            (Some(default), Some(fold)) => Ok(JoinOptions::Fold(default, fold)),
            _ => {
                Err(Diagnostic::spanned(
                        args.span(),
                        Level::Error,
                        format!("Fold requires two arguments, first is the default function, second is the folding function: {func:?}"),
                    ))
            }
        },
        "FoldFrom" => match (elems.next(), elems.next()) {
            (Some(from), Some(fold)) => Ok(JoinOptions::FoldFrom(from, fold)),
            _ => {
                Err(Diagnostic::spanned(
                    args.span(),
                    Level::Error,
                    format!("FoldFrom requires two arguments, first is the From function, second is the folding function: {func:?}"),
                ))
            }
        },
        "Reduce" => match elems.next() {
            Some(reduce) => Ok(JoinOptions::Reduce(reduce)),
            _ => Err(Diagnostic::spanned(
                args.span(),
                Level::Error,
                format!("Reduce requires one argument, the reducing function: {func:?}"),
            )),
        },
        _ => Err(Diagnostic::spanned(
            func.span(),
            Level::Error,
            format!("Unknown summarizing function: {func:?}"),
        )),
    }
}

pub(crate) fn make_joindata(
    wc: &WriteContextArgs,
    persistence: Persistence,
    join_options: &JoinOptions<'_>,
    side: &str,
) -> Result<(Ident, Ident, TokenStream, TokenStream), Diagnostic> {
    let joindata_ident = wc.make_ident(format!("joindata_{}", side));
    let borrow_ident = wc.make_ident(format!("joindata_{}_borrow", side));

    let context = wc.context;
    let hydroflow = wc.hydroflow;
    let root = wc.root;
    let op_span = wc.op_span;

    let join_type = match *join_options {
        JoinOptions::FoldFrom(_, _) => {
            quote_spanned!(op_span=> #root::compiled::pull::HalfJoinStateFoldFrom)
        }
        JoinOptions::Fold(_, _) => {
            quote_spanned!(op_span=> #root::compiled::pull::HalfJoinStateFold)
        }
        JoinOptions::Reduce(_) => {
            quote_spanned!(op_span=> #root::compiled::pull::HalfJoinStateReduce)
        }
    };

    let (prologue, borrow) = match persistence {
        Persistence::Tick => (
            quote_spanned! {op_span=>
                let #joindata_ident = #hydroflow.add_state(std::cell::RefCell::new(
                    #root::util::monotonic_map::MonotonicMap::new_init(
                        #join_type::default()
                    )
                ));
            },
            quote_spanned! {op_span=>
                #borrow_ident.get_mut_clear(#context.current_tick())
            },
        ),
        Persistence::Static => (
            quote_spanned! {op_span=>
                let #joindata_ident = #hydroflow.add_state(std::cell::RefCell::new(
                    #join_type::default()
                ));
            },
            quote_spanned! {op_span=>
                #borrow_ident
            },
        ),
        Persistence::Mutable => {
            return Err(Diagnostic::spanned(
                op_span,
                Level::Error,
                "An implementation of 'mutable does not exist",
            ));
        }
    };
    Ok((joindata_ident, borrow_ident, prologue, borrow))
}

/// Expands the operator's persistence arguments into one persistence per join side.
///
/// * no arguments: both sides are `Persistence::Tick`;
/// * one argument: it applies to both sides;
/// * two arguments: the first is the left side, the second the right side.
///
/// # Panics
///
/// Panics if more than two persistences are supplied.
pub(crate) fn parse_persistences(persistences: &[Persistence]) -> [Persistence; 2] {
    if persistences.len() > 2 {
        panic!("Too many persistences: {persistences:?}");
    }

    let lhs = persistences.first().copied().unwrap_or(Persistence::Tick);
    // A missing second entry falls back to whatever the left side resolved to.
    let rhs = persistences.get(1).copied().unwrap_or(lhs);
    [lhs, rhs]
}