// core_extensions_proc_macros/splitting_generics.rs

1use crate::{
2    used_proc_macro::{
3        token_stream::IntoIter,
4        Delimiter, Punct, Spacing, Span, TokenStream, TokenTree
5    },
6    parsing_shared::{MacroInvocation, out_parenthesized, parse_paren_args},
7    mmatches,
8};
9
10use core::iter::{Peekable, once};
11
12use alloc::string::ToString;
13
14
15
16
17pub(crate) trait PostGenericsParser {
18    fn consume_token(&mut self, _: &SplitGenerics, tt: TokenTree);
19    fn write_tokens(self, ts: &mut TokenStream);
20}
21
22
23
24pub(crate) struct SplitGenerics {
25    // The parsed tokens from the generic parameter list to after the where clause
26    parsing: Peekable<IntoIter>,
27    curr_is_joint: bool,
28    prev_is_joint: bool,
29    curr_token_kind: TokenKind,
30    prev_token_kind: TokenKind,
31    location: ParseLocation,
32    depth: u32,
33    last_span: Span,
34    generics: TokenStream,
35    generics_span: Span,
36    where_clause: TokenStream,
37    where_clause_span: Span,
38    after_where: TokenStream,
39    after_where_span: Span,
40}
41
42
43impl SplitGenerics {
44    pub(crate) fn new<I>(input_tokens: I) -> Self 
45    where
46        I: IntoIterator<IntoIter = IntoIter, Item = TokenTree>
47    {
48        let mut input_tokens = input_tokens.into_iter();
49
50        let parsed_tt = input_tokens.next().expect("skip_generics expected more tokens");
51
52        let parsing = parse_paren_args(&parsed_tt);
53
54        Self::some_consumed(parsing)
55    }
56
57    pub(crate) fn some_consumed(parsing: Peekable<IntoIter>) -> Self {
58        Self {
59            parsing,
60            curr_is_joint: false,
61            prev_is_joint: false,
62            curr_token_kind: TokenKind::Other,
63            prev_token_kind: TokenKind::Other,
64            depth: 0,
65            location: ParseLocation::InGenerics,
66            last_span: Span::call_site(),
67            generics: TokenStream::new(),
68            generics_span: Span::call_site(),
69            where_clause: TokenStream::new(),
70            where_clause_span: Span::call_site(),
71            after_where: TokenStream::new(),
72            after_where_span: Span::call_site(),
73        }
74    }
75
76    #[allow(dead_code)]
77    pub(crate) fn curr_is_joint(&self) -> bool {
78        self.curr_is_joint
79    }
80
81    #[allow(dead_code)]
82    pub(crate) fn prev_is_joint(&self) -> bool {
83        self.prev_is_joint
84    }
85
86    #[allow(dead_code)]
87    pub(crate) fn depth(&self) -> u32 {
88        self.depth
89    }
90
91    #[allow(dead_code)]
92    pub(crate) fn last_span(&self) -> Span {
93        self.last_span
94    }
95}
96
/// Stores the `Some` payload of `$res` into the variable `$tt`, or `break`s
/// out of the enclosing loop when `$res` is `None`.
///
/// Must be invoked inside a loop.
macro_rules! match_process_gen {
    ($res:expr, $tt:ident) => {
        if let Some(processed) = $res {
            $tt = processed;
        } else {
            break;
        }
    };
}
105
106impl SplitGenerics {
107    pub(crate) fn split_generics<P>(mut self, callback_macro: MacroInvocation, args: TokenStream,mut parsing_pgen: P) -> TokenStream
108    where
109        P: PostGenericsParser
110    {
111        self.process_generics();
112
113        self.location = ParseLocation::AfterGenerics;
114
115        if self.depth == 0 {
116            while let Some(mut tt) = self.parsing.next() {
117                match_process_gen!(self.process_generic_list(tt), tt);
118
119                if self.depth == 0 {
120                    match_process_gen!(self.process_after_generics(tt), tt);
121                }
122
123                parsing_pgen.consume_token(&self, tt);
124            }
125        }
126
127        self.process_from_where_clause();
128
129        let Self{
130            generics, generics_span,
131            where_clause, where_clause_span,
132            after_where, after_where_span,
133            ..
134        } = self;
135
136        callback_macro.expand_with_extra_args(|out_args| {
137            out_args.extend(args);
138
139            out_parenthesized(generics, generics_span, out_args);
140            
141            parsing_pgen.write_tokens(out_args);
142
143            out_parenthesized(where_clause, where_clause_span, out_args);
144            out_parenthesized(after_where, after_where_span, out_args);
145        })
146    }
147
148    // Processes the generic parameters that start the token stream,
149    // those declare the generic parmeters
150    fn process_generics(&mut self) {
151        if mmatches!(
152            self.parsing.peek(),
153            Some(TokenTree::Punct(punct)) if punct.as_char() == '<' 
154        ) {
155            drop(self.parsing.next());
156            while let Some(mut tt) = self.parsing.next() {
157                match_process_gen!(self.process_generic_list(tt), tt);
158                self.generics.extend(once(tt));
159            }
160            self.generics_span = self.last_span;
161        }
162    }
163
164    fn process_from_where_clause(&mut self) {
165        if self.depth == 0 && mmatches!(self.location, ParseLocation::InWhere) {
166            while let Some(mut tt) = self.parsing.next() {
167                match_process_gen!(self.process_generic_list(tt), tt);
168
169                if self.depth == 0 {
170                    match_process_gen!(self.process_after_generics(tt), tt);
171                }
172
173                self.where_clause.extend(once(tt));
174            }
175        }
176
177        self.where_clause_span = self.last_span;
178
179
180        for tt in &mut self.parsing {
181            self.last_span = tt.span();
182            self.after_where.extend(once(tt));
183        }
184        self.after_where_span = self.last_span;
185    }
186
187    fn process_after_generics(&mut self, tt: TokenTree) -> Option<TokenTree> {
188        match &tt {
189            TokenTree::Ident(ident) if 
190                mmatches!(self.location, ParseLocation::AfterGenerics) &&
191                ident.to_string() == "where" 
192            => {
193                self.curr_token_kind = TokenKind::Where;
194                self.location = ParseLocation::InWhere;
195                None
196            }
197            TokenTree::Punct(punct) if {
198                let c = punct.as_char();
199                c == ';' || c == '=' && punct.spacing() == Spacing::Alone
200            } => {
201                self.where_clause.extend(self.get_trailing_comma());
202                
203                self.after_where.extend(once(tt));
204                self.location = ParseLocation::AfterWhere;
205                
206                None
207            }
208            TokenTree::Group(group) if group.delimiter() == Delimiter::Brace => {
209                self.where_clause.extend(self.get_trailing_comma());
210                
211                self.after_where.extend(once(tt));
212                self.location = ParseLocation::AfterWhere;
213                
214                None
215            }
216            _ => Some(tt),
217        }
218    }
219
220    fn get_trailing_comma(&self) -> Option<TokenTree> {
221        if let (ParseLocation::InWhere, TokenKind::Other) = (self.location, self.prev_token_kind) {
222            let mut p = Punct::new(',', Spacing::Alone);
223            p.set_span(self.last_span);
224            Some(TokenTree::Punct(p))
225        } else {
226            None
227        }
228    }
229
230    // Processes any pair of `<` and `>`
231    fn process_generic_list(&mut self, tt: TokenTree) -> Option<TokenTree> {
232        self.last_span = tt.span();
233        self.prev_is_joint = self.curr_is_joint;
234        self.curr_is_joint = false;
235
236        self.prev_token_kind = self.curr_token_kind;
237        self.curr_token_kind = TokenKind::Other;
238
239        if let TokenTree::Punct(punct) = &tt {
240            let char = punct.as_char();
241            self.curr_is_joint = char == '-' ||
242                punct.spacing() == Spacing::Joint && char != '>' && char != '<';
243
244            if char == ',' {
245                self.curr_token_kind = TokenKind::Comma;
246            }
247
248            if char == '<' {
249                self.depth += 1;
250            } if !self.prev_is_joint && char == '>' {
251                if self.depth == 0 {
252                    if mmatches!(self.location, ParseLocation::InGenerics) {
253                        return None;
254                    } 
255                } else {
256                    self.depth -= 1;
257                }
258            }
259        }
260
261        Some(tt)
262    }
263}
264
265
/// Which part of the input `SplitGenerics` is currently parsing.
#[derive(Clone, Copy)]
enum ParseLocation {
    /// Initial state: the leading `<...>` generic parameter list.
    InGenerics,
    /// Past the generic parameter list, before any `where` clause.
    AfterGenerics,
    /// Inside the `where` clause.
    InWhere,
    /// Past the `;`, lone `=`, or `{...}` group that ends the generics/`where` section.
    AfterWhere,
}
273
274
/// Coarse classification of a processed token, used to decide whether the
/// `where` clause needs a trailing comma.
#[derive(Clone, Copy)]
enum TokenKind {
    /// The `where` keyword.
    Where,
    /// A `,` punct.
    Comma,
    /// Any other token.
    Other,
}
281
282
283