dfir_lang/
parse.rs

1//! AST for surface syntax, modelled on [`syn`]'s ASTs.
2#![allow(clippy::allow_attributes, missing_docs, reason = "internal use")]
3
4use std::fmt::Debug;
5use std::hash::Hash;
6
7use proc_macro2::{Span, TokenStream};
8use quote::ToTokens;
9use syn::parse::{Parse, ParseStream};
10use syn::punctuated::Punctuated;
11use syn::token::{Brace, Bracket, Paren};
12use syn::{
13    AngleBracketedGenericArguments, Expr, ExprPath, GenericArgument, Ident, ItemUse, LitInt, Path,
14    PathArguments, PathSegment, Token, braced, bracketed, parenthesized,
15};
16
17use crate::process_singletons::preprocess_singletons;
18
19pub struct DfirCode {
20    pub statements: Vec<DfirStatement>,
21}
22impl Parse for DfirCode {
23    fn parse(input: ParseStream) -> syn::Result<Self> {
24        let mut statements = Vec::new();
25        while !input.is_empty() {
26            statements.push(input.parse()?);
27        }
28        Ok(DfirCode { statements })
29    }
30}
31impl ToTokens for DfirCode {
32    fn to_tokens(&self, tokens: &mut TokenStream) {
33        for statement in self.statements.iter() {
34            statement.to_tokens(tokens);
35        }
36    }
37}
38
39pub enum DfirStatement {
40    Use(ItemUse),
41    Named(NamedStatement),
42    Pipeline(PipelineStatement),
43    Loop(LoopStatement),
44}
45impl Parse for DfirStatement {
46    fn parse(input: ParseStream) -> syn::Result<Self> {
47        let lookahead1 = input.lookahead1();
48        if lookahead1.peek(Token![use]) {
49            Ok(Self::Use(ItemUse::parse(input)?))
50        } else if lookahead1.peek(Paren) || lookahead1.peek(Bracket) || lookahead1.peek(Token![mod])
51        {
52            Ok(Self::Pipeline(PipelineStatement::parse(input)?))
53        } else if lookahead1.peek(Token![loop]) {
54            Ok(Self::Loop(LoopStatement::parse(input)?))
55        } else if lookahead1.peek(Ident) {
56            let fork = input.fork();
57            let _: Path = fork.parse()?;
58            let lookahead2 = fork.lookahead1();
59            if lookahead2.peek(Token![=]) {
60                Ok(Self::Named(NamedStatement::parse(input)?))
61            } else if lookahead2.peek(Token![->])
62                || lookahead2.peek(Paren)
63                || lookahead2.peek(Bracket)
64            {
65                Ok(Self::Pipeline(PipelineStatement::parse(input)?))
66            } else {
67                Err(lookahead2.error())
68            }
69        } else {
70            Err(lookahead1.error())
71        }
72    }
73}
74impl ToTokens for DfirStatement {
75    fn to_tokens(&self, tokens: &mut TokenStream) {
76        match self {
77            Self::Use(x) => x.to_tokens(tokens),
78            Self::Named(x) => x.to_tokens(tokens),
79            Self::Pipeline(x) => x.to_tokens(tokens),
80            Self::Loop(x) => x.to_tokens(tokens),
81        }
82    }
83}
84
85pub struct NamedStatement {
86    pub name: Ident,
87    pub equals: Token![=],
88    pub pipeline: Pipeline,
89    pub semi_token: Token![;],
90}
91impl Parse for NamedStatement {
92    fn parse(input: ParseStream) -> syn::Result<Self> {
93        let name = input.parse()?;
94        let equals = input.parse()?;
95        let pipeline = input.parse()?;
96        let semi_token = input.parse()?;
97        Ok(Self {
98            name,
99            equals,
100            pipeline,
101            semi_token,
102        })
103    }
104}
105impl ToTokens for NamedStatement {
106    fn to_tokens(&self, tokens: &mut TokenStream) {
107        self.name.to_tokens(tokens);
108        self.equals.to_tokens(tokens);
109        self.pipeline.to_tokens(tokens);
110        self.semi_token.to_tokens(tokens);
111    }
112}
113
114pub struct PipelineStatement {
115    pub pipeline: Pipeline,
116    pub semi_token: Token![;],
117}
118impl Parse for PipelineStatement {
119    fn parse(input: ParseStream) -> syn::Result<Self> {
120        let pipeline = input.parse()?;
121        let semi_token = input.parse()?;
122        Ok(Self {
123            pipeline,
124            semi_token,
125        })
126    }
127}
128impl ToTokens for PipelineStatement {
129    fn to_tokens(&self, tokens: &mut TokenStream) {
130        self.pipeline.to_tokens(tokens);
131        self.semi_token.to_tokens(tokens);
132    }
133}
134
135#[derive(Clone, Debug)]
136pub enum Pipeline {
137    Paren(Ported<PipelineParen>),
138    Name(Ported<Ident>),
139    Link(PipelineLink),
140    Operator(Operator),
141    ModuleBoundary(Ported<Token![mod]>),
142}
143impl Pipeline {
144    fn parse_one(input: ParseStream) -> syn::Result<Self> {
145        let lookahead1 = input.lookahead1();
146
147        // Leading indexing
148        if lookahead1.peek(Bracket) {
149            let inn_idx = input.parse()?;
150            let lookahead2 = input.lookahead1();
151            // Indexed paren
152            if lookahead2.peek(Paren) {
153                Ok(Self::Paren(Ported::parse_rest(Some(inn_idx), input)?))
154            }
155            // Indexed name
156            else if lookahead2.peek(Ident) {
157                Ok(Self::Name(Ported::parse_rest(Some(inn_idx), input)?))
158            }
159            // Indexed module boundary
160            else if lookahead2.peek(Token![mod]) {
161                Ok(Self::ModuleBoundary(Ported::parse_rest(
162                    Some(inn_idx),
163                    input,
164                )?))
165            }
166            // Emit lookahead expected tokens errors.
167            else {
168                Err(lookahead2.error())
169            }
170        // module input/output
171        } else if lookahead1.peek(Token![mod]) {
172            Ok(Self::ModuleBoundary(input.parse()?))
173        // Ident or macro-style expression
174        } else if lookahead1.peek(Ident) {
175            let speculative = input.fork();
176            let _ident: Ident = speculative.parse()?;
177
178            // If has paren or generic next, it's an operator
179            if speculative.peek(Paren)
180                || speculative.peek(Token![<])
181                || speculative.peek(Token![::])
182            {
183                Ok(Self::Operator(input.parse()?))
184            }
185            // Otherwise it's a variable name
186            else {
187                Ok(Self::Name(input.parse()?))
188            }
189        }
190        // Paren group
191        else if lookahead1.peek(Paren) {
192            Ok(Self::Paren(input.parse()?))
193        }
194        // Emit lookahead expected tokens errors.
195        else {
196            Err(lookahead1.error())
197        }
198    }
199}
200impl Parse for Pipeline {
201    fn parse(input: ParseStream) -> syn::Result<Self> {
202        let lhs = Pipeline::parse_one(input)?;
203        if input.is_empty() || input.peek(Token![;]) {
204            Ok(lhs)
205        } else {
206            let arrow = input.parse()?;
207            let rhs = input.parse()?;
208            let lhs = Box::new(lhs);
209            Ok(Self::Link(PipelineLink { lhs, arrow, rhs }))
210        }
211    }
212}
213impl ToTokens for Pipeline {
214    fn to_tokens(&self, tokens: &mut TokenStream) {
215        match self {
216            Self::Paren(x) => x.to_tokens(tokens),
217            Self::Link(x) => x.to_tokens(tokens),
218            Self::Name(x) => x.to_tokens(tokens),
219            Self::Operator(x) => x.to_tokens(tokens),
220            Self::ModuleBoundary(x) => x.to_tokens(tokens),
221        }
222    }
223}
224
225pub struct LoopStatement {
226    pub loop_token: Token![loop],
227    pub ident: Option<Ident>,
228    pub brace_token: Brace,
229    pub statements: Vec<DfirStatement>,
230    pub semi_token: Token![;],
231}
232impl Parse for LoopStatement {
233    fn parse(input: ParseStream) -> syn::Result<Self> {
234        let loop_token = input.parse()?;
235        let ident = input.parse()?;
236        let content;
237        let brace_token = braced!(content in input);
238        let mut statements = Vec::new();
239        while !content.is_empty() {
240            statements.push(content.parse()?);
241        }
242        let semi_token = input.parse()?;
243        Ok(Self {
244            loop_token,
245            ident,
246            brace_token,
247            statements,
248            semi_token,
249        })
250    }
251}
252impl ToTokens for LoopStatement {
253    fn to_tokens(&self, tokens: &mut TokenStream) {
254        self.loop_token.to_tokens(tokens);
255        self.ident.to_tokens(tokens);
256        self.brace_token.surround(tokens, |tokens| {
257            for statement in self.statements.iter() {
258                statement.to_tokens(tokens);
259            }
260        });
261        self.semi_token.to_tokens(tokens);
262    }
263}
264
265#[derive(Clone, Debug)]
266pub struct Ported<Inner> {
267    pub inn: Option<Indexing>,
268    pub inner: Inner,
269    pub out: Option<Indexing>,
270}
271impl<Inner> Ported<Inner>
272where
273    Inner: Parse,
274{
275    /// The caller will often parse the first port (`inn`) as part of determining what to parse
276    /// next, so this will do the rest after that.
277    fn parse_rest(inn: Option<Indexing>, input: ParseStream) -> syn::Result<Self> {
278        let inner = input.parse()?;
279        let out = input.call(Indexing::parse_opt)?;
280        Ok(Self { inn, inner, out })
281    }
282}
283impl<Inner> Parse for Ported<Inner>
284where
285    Inner: Parse,
286{
287    fn parse(input: ParseStream) -> syn::Result<Self> {
288        let inn = input.call(Indexing::parse_opt)?;
289        Self::parse_rest(inn, input)
290    }
291}
292impl<Inner> ToTokens for Ported<Inner>
293where
294    Inner: ToTokens,
295{
296    fn to_tokens(&self, tokens: &mut TokenStream) {
297        self.inn.to_tokens(tokens);
298        self.inner.to_tokens(tokens);
299        self.out.to_tokens(tokens);
300    }
301}
302
303#[derive(Clone, Debug)]
304pub struct PipelineParen {
305    pub paren_token: Paren,
306    pub pipeline: Box<Pipeline>,
307}
308impl Parse for PipelineParen {
309    fn parse(input: ParseStream) -> syn::Result<Self> {
310        let content;
311        let paren_token = parenthesized!(content in input);
312        let pipeline = content.parse()?;
313        Ok(Self {
314            paren_token,
315            pipeline,
316        })
317    }
318}
319impl ToTokens for PipelineParen {
320    fn to_tokens(&self, tokens: &mut TokenStream) {
321        self.paren_token.surround(tokens, |tokens| {
322            self.pipeline.to_tokens(tokens);
323        });
324    }
325}
326
327#[derive(Clone, Debug)]
328pub struct PipelineLink {
329    pub lhs: Box<Pipeline>,
330    pub arrow: Token![->],
331    pub rhs: Box<Pipeline>,
332}
333impl Parse for PipelineLink {
334    fn parse(input: ParseStream) -> syn::Result<Self> {
335        let lhs = input.parse()?;
336        let arrow = input.parse()?;
337        let rhs = input.parse()?;
338
339        Ok(Self { lhs, arrow, rhs })
340    }
341}
342impl ToTokens for PipelineLink {
343    fn to_tokens(&self, tokens: &mut TokenStream) {
344        self.lhs.to_tokens(tokens);
345        self.arrow.to_tokens(tokens);
346        self.rhs.to_tokens(tokens);
347    }
348}
349
350#[derive(Clone, Debug)]
351pub struct Indexing {
352    pub bracket_token: Bracket,
353    pub index: PortIndex,
354}
355impl Indexing {
356    fn parse_opt(input: ParseStream) -> syn::Result<Option<Self>> {
357        input.peek(Bracket).then(|| input.parse()).transpose()
358    }
359}
360impl Parse for Indexing {
361    fn parse(input: ParseStream) -> syn::Result<Self> {
362        let content;
363        let bracket_token = bracketed!(content in input);
364        let index = content.parse()?;
365        Ok(Self {
366            bracket_token,
367            index,
368        })
369    }
370}
371impl ToTokens for Indexing {
372    fn to_tokens(&self, tokens: &mut TokenStream) {
373        self.bracket_token.surround(tokens, |tokens| {
374            self.index.to_tokens(tokens);
375        });
376    }
377}
378
379/// Port can either be an int or a name (path).
380#[derive(Clone, Debug)]
381pub enum PortIndex {
382    Int(IndexInt),
383    Path(ExprPath),
384}
385impl Parse for PortIndex {
386    fn parse(input: ParseStream) -> syn::Result<Self> {
387        let lookahead = input.lookahead1();
388        if lookahead.peek(LitInt) {
389            input.parse().map(Self::Int)
390        } else {
391            input.parse().map(Self::Path)
392        }
393    }
394}
395impl ToTokens for PortIndex {
396    fn to_tokens(&self, tokens: &mut TokenStream) {
397        match self {
398            PortIndex::Int(index_int) => index_int.to_tokens(tokens),
399            PortIndex::Path(expr_path) => expr_path.to_tokens(tokens),
400        }
401    }
402}
403
404struct TypeHintRemover;
405impl syn::visit_mut::VisitMut for TypeHintRemover {
406    fn visit_expr_mut(&mut self, expr: &mut Expr) {
407        if let Expr::Call(expr_call) = expr &&
408            let Expr::Path(path) = expr_call.func.as_ref() &&
409                // if it is a call of the form `::...::*_type_hint(xyz)`,
410                // typically `::stageleft::...`, replace it with `xyz`
411                path
412                    .path
413                    .segments
414                    .last()
415                    .unwrap()
416                    .ident
417                    .to_string()
418                    .ends_with("_type_hint")
419        {
420            *expr = expr_call.args.first().unwrap().clone();
421        }
422
423        syn::visit_mut::visit_expr_mut(self, expr);
424    }
425}
426
427#[derive(Clone)]
428pub struct Operator {
429    pub path: Path,
430    pub paren_token: Paren,
431    pub args_raw: TokenStream,
432    pub args: Punctuated<Expr, Token![,]>,
433    pub singletons_referenced: Vec<Ident>,
434}
435
436impl Operator {
437    pub fn name(&self) -> Path {
438        Path {
439            leading_colon: self.path.leading_colon,
440            segments: self
441                .path
442                .segments
443                .iter()
444                .map(|seg| PathSegment {
445                    ident: seg.ident.clone(),
446                    arguments: PathArguments::None,
447                })
448                .collect(),
449        }
450    }
451
452    pub fn name_string(&self) -> String {
453        self.name().to_token_stream().to_string()
454    }
455
456    pub fn type_arguments(&self) -> Option<&Punctuated<GenericArgument, Token![,]>> {
457        let end = self.path.segments.last()?;
458        if let PathArguments::AngleBracketed(type_args) = &end.arguments {
459            Some(&type_args.args)
460        } else {
461            None
462        }
463    }
464
465    pub fn args(&self) -> &Punctuated<Expr, Token![,]> {
466        &self.args
467    }
468
469    /// Output the operator as a formatted string using `prettyplease`.
470    pub fn to_pretty_string(&self) -> String {
471        // TODO(mingwei): preserve #args_raw instead of just args?
472        let mut file: syn::File = syn::parse_quote! {
473            fn main() {
474                #self
475            }
476        };
477
478        syn::visit_mut::visit_file_mut(&mut TypeHintRemover, &mut file);
479        let str = prettyplease::unparse(&file);
480        str.trim_start()
481            .trim_start_matches("fn main()")
482            .trim_start()
483            .trim_start_matches('{')
484            .trim_start()
485            .trim_end()
486            .trim_end_matches('}')
487            .trim_end()
488            .replace("\n    ", "\n") // Remove extra leading indent
489    }
490}
491impl Parse for Operator {
492    fn parse(input: ParseStream) -> syn::Result<Self> {
493        let path: Path = input.parse()?;
494        if let Some(path_seg) = path.segments.iter().find(|path_seg| {
495            matches!(
496                &path_seg.arguments,
497                PathArguments::AngleBracketed(AngleBracketedGenericArguments {
498                    colon2_token: None,
499                    ..
500                })
501            )
502        }) {
503            return Err(syn::Error::new_spanned(
504                path_seg,
505                "Missing `::` before `<...>` generic arguments",
506            ));
507        }
508
509        let content;
510        let paren_token = parenthesized!(content in input);
511        let args_raw: TokenStream = content.parse()?;
512        let mut singletons_referenced = Vec::new();
513        let args = parse_terminated(preprocess_singletons(
514            args_raw.clone(),
515            &mut singletons_referenced,
516        ))?;
517
518        Ok(Self {
519            path,
520            paren_token,
521            args_raw,
522            args,
523            singletons_referenced,
524        })
525    }
526}
527
528impl ToTokens for Operator {
529    fn to_tokens(&self, tokens: &mut TokenStream) {
530        self.path.to_tokens(tokens);
531        self.paren_token.surround(tokens, |tokens| {
532            self.args.to_tokens(tokens);
533        });
534    }
535}
536
537impl Debug for Operator {
538    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
539        f.debug_struct("Operator")
540            .field("path", &self.path.to_token_stream().to_string())
541            .field(
542                "args",
543                &self
544                    .args
545                    .iter()
546                    .map(|a| a.to_token_stream().to_string())
547                    .collect::<Vec<_>>(),
548            )
549            .finish()
550    }
551}
552
553#[derive(Clone, Copy, Debug)]
554pub struct IndexInt {
555    pub value: isize,
556    pub span: Span,
557}
558impl Parse for IndexInt {
559    fn parse(input: ParseStream) -> syn::Result<Self> {
560        let lit_int: LitInt = input.parse()?;
561        let value = lit_int.base10_parse()?;
562        Ok(Self {
563            value,
564            span: lit_int.span(),
565        })
566    }
567}
568impl ToTokens for IndexInt {
569    fn to_tokens(&self, tokens: &mut TokenStream) {
570        let lit_int = LitInt::new(&self.value.to_string(), self.span);
571        lit_int.to_tokens(tokens)
572    }
573}
574impl Hash for IndexInt {
575    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
576        self.value.hash(state);
577    }
578}
579impl PartialEq for IndexInt {
580    fn eq(&self, other: &Self) -> bool {
581        self.value == other.value
582    }
583}
584impl Eq for IndexInt {}
585impl PartialOrd for IndexInt {
586    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
587        Some(self.cmp(other))
588    }
589}
590impl Ord for IndexInt {
591    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
592        self.value.cmp(&other.value)
593    }
594}
595
596pub fn parse_terminated<T, P>(tokens: TokenStream) -> syn::Result<Punctuated<T, P>>
597where
598    T: Parse,
599    P: Parse,
600{
601    struct ParseTerminated<T, P>(pub Punctuated<T, P>);
602    impl<T, P> Parse for ParseTerminated<T, P>
603    where
604        T: Parse,
605        P: Parse,
606    {
607        fn parse(input: ParseStream) -> syn::Result<Self> {
608            Ok(Self(Punctuated::parse_terminated(input)?))
609        }
610    }
611
612    Ok(syn::parse2::<ParseTerminated<T, P>>(tokens)?.0)
613}
614
#[cfg(test)]
mod test {
    use syn::parse_quote;

    use super::*;

    /// A multi-line closure argument is formatted by `prettyplease` without the dummy
    /// `fn main` wrapper or its extra indentation.
    #[test]
    fn test_operator_to_pretty_string() {
        let op: Operator = parse_quote! {
            demux(|(msg, addr), var_args!(clients, msgs, errs)|
                match msg {
                    Message::ConnectRequest => clients.give(addr),
                    Message::ChatMsg {..} => msgs.give(msg),
                    _ => errs.give(msg),
                }
            )
        };
        let expected = r"
demux(|(msg, addr), var_args!(clients, msgs, errs)| match msg {
    Message::ConnectRequest => clients.give(addr),
    Message::ChatMsg { .. } => msgs.give(msg),
    _ => errs.give(msg),
})
"
        .trim();
        let actual = op.to_pretty_string();
        assert_eq!(expected, actual);
    }
}