dfir_lang/
parse.rs

1//! AST for surface syntax, modelled on [`syn`]'s ASTs.
2#![allow(clippy::allow_attributes, missing_docs, reason = "internal use")]
3
4use std::fmt::Debug;
5use std::hash::Hash;
6
7use proc_macro2::{Span, TokenStream};
8use quote::ToTokens;
9use syn::parse::{Parse, ParseStream};
10use syn::punctuated::Punctuated;
11use syn::token::{Brace, Bracket, Paren};
12use syn::{
13    AngleBracketedGenericArguments, Expr, ExprPath, GenericArgument, Ident, ItemUse, LitInt, Path,
14    PathArguments, PathSegment, Token, braced, bracketed, parenthesized,
15};
16
17use crate::process_singletons::preprocess_singletons;
18
19pub struct DfirCode {
20    pub statements: Vec<DfirStatement>,
21}
22impl Parse for DfirCode {
23    fn parse(input: ParseStream) -> syn::Result<Self> {
24        let mut statements = Vec::new();
25        while !input.is_empty() {
26            statements.push(input.parse()?);
27        }
28        Ok(DfirCode { statements })
29    }
30}
31impl ToTokens for DfirCode {
32    fn to_tokens(&self, tokens: &mut TokenStream) {
33        for statement in self.statements.iter() {
34            statement.to_tokens(tokens);
35        }
36    }
37}
38
39pub enum DfirStatement {
40    Use(ItemUse),
41    Named(NamedStatement),
42    Pipeline(PipelineStatement),
43    Loop(LoopStatement),
44}
45impl Parse for DfirStatement {
46    fn parse(input: ParseStream) -> syn::Result<Self> {
47        let lookahead1 = input.lookahead1();
48        if lookahead1.peek(Token![use]) {
49            Ok(Self::Use(ItemUse::parse(input)?))
50        } else if lookahead1.peek(Paren) || lookahead1.peek(Bracket) || lookahead1.peek(Token![mod])
51        {
52            Ok(Self::Pipeline(PipelineStatement::parse(input)?))
53        } else if lookahead1.peek(Token![loop]) {
54            Ok(Self::Loop(LoopStatement::parse(input)?))
55        } else if lookahead1.peek(Ident) {
56            let fork = input.fork();
57            let _: Path = fork.parse()?;
58            let lookahead2 = fork.lookahead1();
59            if lookahead2.peek(Token![=]) {
60                Ok(Self::Named(NamedStatement::parse(input)?))
61            } else if lookahead2.peek(Token![->])
62                || lookahead2.peek(Paren)
63                || lookahead2.peek(Bracket)
64            {
65                Ok(Self::Pipeline(PipelineStatement::parse(input)?))
66            } else {
67                Err(lookahead2.error())
68            }
69        } else {
70            Err(lookahead1.error())
71        }
72    }
73}
74impl ToTokens for DfirStatement {
75    fn to_tokens(&self, tokens: &mut TokenStream) {
76        match self {
77            Self::Use(x) => x.to_tokens(tokens),
78            Self::Named(x) => x.to_tokens(tokens),
79            Self::Pipeline(x) => x.to_tokens(tokens),
80            Self::Loop(x) => x.to_tokens(tokens),
81        }
82    }
83}
84
85pub struct NamedStatement {
86    pub name: Ident,
87    pub equals: Token![=],
88    pub pipeline: Pipeline,
89    pub semi_token: Token![;],
90}
91impl Parse for NamedStatement {
92    fn parse(input: ParseStream) -> syn::Result<Self> {
93        let name = input.parse()?;
94        let equals = input.parse()?;
95        let pipeline = input.parse()?;
96        let semi_token = input.parse()?;
97        Ok(Self {
98            name,
99            equals,
100            pipeline,
101            semi_token,
102        })
103    }
104}
105impl ToTokens for NamedStatement {
106    fn to_tokens(&self, tokens: &mut TokenStream) {
107        self.name.to_tokens(tokens);
108        self.equals.to_tokens(tokens);
109        self.pipeline.to_tokens(tokens);
110        self.semi_token.to_tokens(tokens);
111    }
112}
113
114pub struct PipelineStatement {
115    pub pipeline: Pipeline,
116    pub semi_token: Token![;],
117}
118impl Parse for PipelineStatement {
119    fn parse(input: ParseStream) -> syn::Result<Self> {
120        let pipeline = input.parse()?;
121        let semi_token = input.parse()?;
122        Ok(Self {
123            pipeline,
124            semi_token,
125        })
126    }
127}
128impl ToTokens for PipelineStatement {
129    fn to_tokens(&self, tokens: &mut TokenStream) {
130        self.pipeline.to_tokens(tokens);
131        self.semi_token.to_tokens(tokens);
132    }
133}
134
135#[derive(Clone, Debug)]
136pub enum Pipeline {
137    Paren(Ported<PipelineParen>),
138    Name(Ported<Ident>),
139    Link(PipelineLink),
140    Operator(Operator),
141    ModuleBoundary(Ported<Token![mod]>),
142}
143impl Pipeline {
144    fn parse_one(input: ParseStream) -> syn::Result<Self> {
145        let lookahead1 = input.lookahead1();
146
147        // Leading indexing
148        if lookahead1.peek(Bracket) {
149            let inn_idx = input.parse()?;
150            let lookahead2 = input.lookahead1();
151            // Indexed paren
152            if lookahead2.peek(Paren) {
153                Ok(Self::Paren(Ported::parse_rest(Some(inn_idx), input)?))
154            }
155            // Indexed name
156            else if lookahead2.peek(Ident) {
157                Ok(Self::Name(Ported::parse_rest(Some(inn_idx), input)?))
158            }
159            // Indexed module boundary
160            else if lookahead2.peek(Token![mod]) {
161                Ok(Self::ModuleBoundary(Ported::parse_rest(
162                    Some(inn_idx),
163                    input,
164                )?))
165            }
166            // Emit lookahead expected tokens errors.
167            else {
168                Err(lookahead2.error())
169            }
170        // module input/output
171        } else if lookahead1.peek(Token![mod]) {
172            Ok(Self::ModuleBoundary(input.parse()?))
173        // Ident or macro-style expression
174        } else if lookahead1.peek(Ident) {
175            let speculative = input.fork();
176            let _ident: Ident = speculative.parse()?;
177
178            // If has paren or generic next, it's an operator
179            if speculative.peek(Paren)
180                || speculative.peek(Token![<])
181                || speculative.peek(Token![::])
182            {
183                Ok(Self::Operator(input.parse()?))
184            }
185            // Otherwise it's a variable name
186            else {
187                Ok(Self::Name(input.parse()?))
188            }
189        }
190        // Paren group
191        else if lookahead1.peek(Paren) {
192            Ok(Self::Paren(input.parse()?))
193        }
194        // Emit lookahead expected tokens errors.
195        else {
196            Err(lookahead1.error())
197        }
198    }
199}
200impl Parse for Pipeline {
201    fn parse(input: ParseStream) -> syn::Result<Self> {
202        let lhs = Pipeline::parse_one(input)?;
203        if input.is_empty() || input.peek(Token![;]) {
204            Ok(lhs)
205        } else {
206            let arrow = input.parse()?;
207            let rhs = input.parse()?;
208            let lhs = Box::new(lhs);
209            Ok(Self::Link(PipelineLink { lhs, arrow, rhs }))
210        }
211    }
212}
213impl ToTokens for Pipeline {
214    fn to_tokens(&self, tokens: &mut TokenStream) {
215        match self {
216            Self::Paren(x) => x.to_tokens(tokens),
217            Self::Link(x) => x.to_tokens(tokens),
218            Self::Name(x) => x.to_tokens(tokens),
219            Self::Operator(x) => x.to_tokens(tokens),
220            Self::ModuleBoundary(x) => x.to_tokens(tokens),
221        }
222    }
223}
224
225pub struct LoopStatement {
226    pub loop_token: Token![loop],
227    pub ident: Option<Ident>,
228    pub brace_token: Brace,
229    pub statements: Vec<DfirStatement>,
230    pub semi_token: Token![;],
231}
232impl Parse for LoopStatement {
233    fn parse(input: ParseStream) -> syn::Result<Self> {
234        let loop_token = input.parse()?;
235        let ident = input.parse()?;
236        let content;
237        let brace_token = braced!(content in input);
238        let mut statements = Vec::new();
239        while !content.is_empty() {
240            statements.push(content.parse()?);
241        }
242        let semi_token = input.parse()?;
243        Ok(Self {
244            loop_token,
245            ident,
246            brace_token,
247            statements,
248            semi_token,
249        })
250    }
251}
252impl ToTokens for LoopStatement {
253    fn to_tokens(&self, tokens: &mut TokenStream) {
254        self.loop_token.to_tokens(tokens);
255        self.ident.to_tokens(tokens);
256        self.brace_token.surround(tokens, |tokens| {
257            for statement in self.statements.iter() {
258                statement.to_tokens(tokens);
259            }
260        });
261        self.semi_token.to_tokens(tokens);
262    }
263}
264
265#[derive(Clone, Debug)]
266pub struct Ported<Inner> {
267    pub inn: Option<Indexing>,
268    pub inner: Inner,
269    pub out: Option<Indexing>,
270}
271impl<Inner> Ported<Inner>
272where
273    Inner: Parse,
274{
275    /// The caller will often parse the first port (`inn`) as part of determining what to parse
276    /// next, so this will do the rest after that.
277    fn parse_rest(inn: Option<Indexing>, input: ParseStream) -> syn::Result<Self> {
278        let inner = input.parse()?;
279        let out = input.call(Indexing::parse_opt)?;
280        Ok(Self { inn, inner, out })
281    }
282}
283impl<Inner> Parse for Ported<Inner>
284where
285    Inner: Parse,
286{
287    fn parse(input: ParseStream) -> syn::Result<Self> {
288        let inn = input.call(Indexing::parse_opt)?;
289        Self::parse_rest(inn, input)
290    }
291}
292impl<Inner> ToTokens for Ported<Inner>
293where
294    Inner: ToTokens,
295{
296    fn to_tokens(&self, tokens: &mut TokenStream) {
297        self.inn.to_tokens(tokens);
298        self.inner.to_tokens(tokens);
299        self.out.to_tokens(tokens);
300    }
301}
302
303#[derive(Clone, Debug)]
304pub struct PipelineParen {
305    pub paren_token: Paren,
306    pub pipeline: Box<Pipeline>,
307}
308impl Parse for PipelineParen {
309    fn parse(input: ParseStream) -> syn::Result<Self> {
310        let content;
311        let paren_token = parenthesized!(content in input);
312        let pipeline = content.parse()?;
313        Ok(Self {
314            paren_token,
315            pipeline,
316        })
317    }
318}
319impl ToTokens for PipelineParen {
320    fn to_tokens(&self, tokens: &mut TokenStream) {
321        self.paren_token.surround(tokens, |tokens| {
322            self.pipeline.to_tokens(tokens);
323        });
324    }
325}
326
327#[derive(Clone, Debug)]
328pub struct PipelineLink {
329    pub lhs: Box<Pipeline>,
330    pub arrow: Token![->],
331    pub rhs: Box<Pipeline>,
332}
333impl Parse for PipelineLink {
334    fn parse(input: ParseStream) -> syn::Result<Self> {
335        let lhs = input.parse()?;
336        let arrow = input.parse()?;
337        let rhs = input.parse()?;
338
339        Ok(Self { lhs, arrow, rhs })
340    }
341}
342impl ToTokens for PipelineLink {
343    fn to_tokens(&self, tokens: &mut TokenStream) {
344        self.lhs.to_tokens(tokens);
345        self.arrow.to_tokens(tokens);
346        self.rhs.to_tokens(tokens);
347    }
348}
349
350#[derive(Clone, Debug)]
351pub struct Indexing {
352    pub bracket_token: Bracket,
353    pub index: PortIndex,
354}
355impl Indexing {
356    fn parse_opt(input: ParseStream) -> syn::Result<Option<Self>> {
357        input.peek(Bracket).then(|| input.parse()).transpose()
358    }
359}
360impl Parse for Indexing {
361    fn parse(input: ParseStream) -> syn::Result<Self> {
362        let content;
363        let bracket_token = bracketed!(content in input);
364        let index = content.parse()?;
365        Ok(Self {
366            bracket_token,
367            index,
368        })
369    }
370}
371impl ToTokens for Indexing {
372    fn to_tokens(&self, tokens: &mut TokenStream) {
373        self.bracket_token.surround(tokens, |tokens| {
374            self.index.to_tokens(tokens);
375        });
376    }
377}
378
379/// Port can either be an int or a name (path).
380#[derive(Clone, Debug)]
381pub enum PortIndex {
382    Int(IndexInt),
383    Path(ExprPath),
384}
385impl Parse for PortIndex {
386    fn parse(input: ParseStream) -> syn::Result<Self> {
387        let lookahead = input.lookahead1();
388        if lookahead.peek(LitInt) {
389            input.parse().map(Self::Int)
390        } else {
391            input.parse().map(Self::Path)
392        }
393    }
394}
395impl ToTokens for PortIndex {
396    fn to_tokens(&self, tokens: &mut TokenStream) {
397        match self {
398            PortIndex::Int(index_int) => index_int.to_tokens(tokens),
399            PortIndex::Path(expr_path) => expr_path.to_tokens(tokens),
400        }
401    }
402}
403
404struct TypeHintRemover;
405impl syn::visit_mut::VisitMut for TypeHintRemover {
406    fn visit_expr_mut(&mut self, expr: &mut Expr) {
407        if let Expr::Call(expr_call) = expr {
408            if let Expr::Path(path) = expr_call.func.as_ref() {
409                // if it is a call of the form `::...::*_type_hint(xyz)`,
410                // typically `::stageleft::...`, replace it with `xyz`
411                if path
412                    .path
413                    .segments
414                    .last()
415                    .unwrap()
416                    .ident
417                    .to_string()
418                    .ends_with("_type_hint")
419                {
420                    *expr = expr_call.args.first().unwrap().clone();
421                }
422            }
423        }
424
425        syn::visit_mut::visit_expr_mut(self, expr);
426    }
427}
428
429#[derive(Clone)]
430pub struct Operator {
431    pub path: Path,
432    pub paren_token: Paren,
433    pub args_raw: TokenStream,
434    pub args: Punctuated<Expr, Token![,]>,
435    pub singletons_referenced: Vec<Ident>,
436}
437
438impl Operator {
439    pub fn name(&self) -> Path {
440        Path {
441            leading_colon: self.path.leading_colon,
442            segments: self
443                .path
444                .segments
445                .iter()
446                .map(|seg| PathSegment {
447                    ident: seg.ident.clone(),
448                    arguments: PathArguments::None,
449                })
450                .collect(),
451        }
452    }
453
454    pub fn name_string(&self) -> String {
455        self.name().to_token_stream().to_string()
456    }
457
458    pub fn type_arguments(&self) -> Option<&Punctuated<GenericArgument, Token![,]>> {
459        let end = self.path.segments.last()?;
460        if let PathArguments::AngleBracketed(type_args) = &end.arguments {
461            Some(&type_args.args)
462        } else {
463            None
464        }
465    }
466
467    pub fn args(&self) -> &Punctuated<Expr, Token![,]> {
468        &self.args
469    }
470
471    /// Output the operator as a formatted string using `prettyplease`.
472    pub fn to_pretty_string(&self) -> String {
473        // TODO(mingwei): preserve #args_raw instead of just args?
474        let mut file: syn::File = syn::parse_quote! {
475            fn main() {
476                #self
477            }
478        };
479
480        syn::visit_mut::visit_file_mut(&mut TypeHintRemover, &mut file);
481        let str = prettyplease::unparse(&file);
482        str.trim_start()
483            .trim_start_matches("fn main()")
484            .trim_start()
485            .trim_start_matches('{')
486            .trim_start()
487            .trim_end()
488            .trim_end_matches('}')
489            .trim_end()
490            .replace("\n    ", "\n") // Remove extra leading indent
491    }
492}
493impl Parse for Operator {
494    fn parse(input: ParseStream) -> syn::Result<Self> {
495        let path: Path = input.parse()?;
496        if let Some(path_seg) = path.segments.iter().find(|path_seg| {
497            matches!(
498                &path_seg.arguments,
499                PathArguments::AngleBracketed(AngleBracketedGenericArguments {
500                    colon2_token: None,
501                    ..
502                })
503            )
504        }) {
505            return Err(syn::Error::new_spanned(
506                path_seg,
507                "Missing `::` before `<...>` generic arguments",
508            ));
509        }
510
511        let content;
512        let paren_token = parenthesized!(content in input);
513        let args_raw: TokenStream = content.parse()?;
514        let mut singletons_referenced = Vec::new();
515        let args = parse_terminated(preprocess_singletons(
516            args_raw.clone(),
517            &mut singletons_referenced,
518        ))?;
519
520        Ok(Self {
521            path,
522            paren_token,
523            args_raw,
524            args,
525            singletons_referenced,
526        })
527    }
528}
529
530impl ToTokens for Operator {
531    fn to_tokens(&self, tokens: &mut TokenStream) {
532        self.path.to_tokens(tokens);
533        self.paren_token.surround(tokens, |tokens| {
534            self.args.to_tokens(tokens);
535        });
536    }
537}
538
539impl Debug for Operator {
540    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
541        f.debug_struct("Operator")
542            .field("path", &self.path.to_token_stream().to_string())
543            .field(
544                "args",
545                &self
546                    .args
547                    .iter()
548                    .map(|a| a.to_token_stream().to_string())
549                    .collect::<Vec<_>>(),
550            )
551            .finish()
552    }
553}
554
555#[derive(Clone, Copy, Debug)]
556pub struct IndexInt {
557    pub value: isize,
558    pub span: Span,
559}
560impl Parse for IndexInt {
561    fn parse(input: ParseStream) -> syn::Result<Self> {
562        let lit_int: LitInt = input.parse()?;
563        let value = lit_int.base10_parse()?;
564        Ok(Self {
565            value,
566            span: lit_int.span(),
567        })
568    }
569}
570impl ToTokens for IndexInt {
571    fn to_tokens(&self, tokens: &mut TokenStream) {
572        let lit_int = LitInt::new(&self.value.to_string(), self.span);
573        lit_int.to_tokens(tokens)
574    }
575}
576impl Hash for IndexInt {
577    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
578        self.value.hash(state);
579    }
580}
581impl PartialEq for IndexInt {
582    fn eq(&self, other: &Self) -> bool {
583        self.value == other.value
584    }
585}
586impl Eq for IndexInt {}
587impl PartialOrd for IndexInt {
588    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
589        Some(self.cmp(other))
590    }
591}
592impl Ord for IndexInt {
593    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
594        self.value.cmp(&other.value)
595    }
596}
597
598pub fn parse_terminated<T, P>(tokens: TokenStream) -> syn::Result<Punctuated<T, P>>
599where
600    T: Parse,
601    P: Parse,
602{
603    struct ParseTerminated<T, P>(pub Punctuated<T, P>);
604    impl<T, P> Parse for ParseTerminated<T, P>
605    where
606        T: Parse,
607        P: Parse,
608    {
609        fn parse(input: ParseStream) -> syn::Result<Self> {
610            Ok(Self(Punctuated::parse_terminated(input)?))
611        }
612    }
613
614    Ok(syn::parse2::<ParseTerminated<T, P>>(tokens)?.0)
615}
616
#[cfg(test)]
mod test {
    use syn::parse_quote;

    use super::*;

    /// `to_pretty_string` should reformat a multi-line operator invocation without the
    /// `fn main()` wrapper or its extra level of indentation.
    #[test]
    fn test_operator_to_pretty_string() {
        let operator: Operator = parse_quote! {
            demux(|(msg, addr), var_args!(clients, msgs, errs)|
                match msg {
                    Message::ConnectRequest => clients.give(addr),
                    Message::ChatMsg {..} => msgs.give(msg),
                    _ => errs.give(msg),
                }
            )
        };
        let expected = r"
demux(|(msg, addr), var_args!(clients, msgs, errs)| match msg {
    Message::ConnectRequest => clients.give(addr),
    Message::ChatMsg { .. } => msgs.give(msg),
    _ => errs.give(msg),
})
"
        .trim();
        assert_eq!(expected, operator.to_pretty_string());
    }
}