core_extensions_proc_macros/
splitting_generics.rs

1use crate::{
2    used_proc_macro::{
3        token_stream::IntoIter,
4        Delimiter, Punct, Spacing, Span, TokenStream, TokenTree
5    },
6    parsing_shared::{MacroInvocation, out_parenthesized, parse_paren_args, parse_path_and_args},
7    mmatches,
8};
9
10use core::iter::{Peekable, once};
11
12use alloc::string::ToString;
13
14
15
16
17pub(crate) trait PostGenericsParser {
18    fn consume_token(&mut self, _: &SplitGenerics, tt: TokenTree);
19    fn write_tokens(self, ts: &mut TokenStream);
20}
21
22
23
24pub(crate) struct SplitGenerics {
25    // All of the tokens passed to this
26    input_tokens: IntoIter,
27    // The parsed tokens from the generic parameter list to after the where clause
28    parsing: Peekable<IntoIter>,
29    curr_is_joint: bool,
30    prev_is_joint: bool,
31    curr_token_kind: TokenKind,
32    prev_token_kind: TokenKind,
33    location: ParseLocation,
34    depth: u32,
35    last_span: Span,
36    generics: TokenStream,
37    generics_span: Span,
38    where_clause: TokenStream,
39    where_clause_span: Span,
40    after_where: TokenStream,
41    after_where_span: Span,
42}
43
44
45impl SplitGenerics {
46    pub(crate) fn new<I>(input_tokens: I) -> Self 
47    where
48        I: IntoIterator<IntoIter = IntoIter, Item = TokenTree>
49    {
50        let mut input_tokens = input_tokens.into_iter();
51
52        let parsed_tt = input_tokens.next().expect("skip_generics expected more tokens");
53
54        let parsing = parse_paren_args(&parsed_tt);
55
56        Self::some_consumed(input_tokens, parsing)
57    }
58
59    pub(crate) fn some_consumed(input_tokens: IntoIter, parsing: Peekable<IntoIter>) -> Self {
60        Self {
61            input_tokens,
62            parsing,
63            curr_is_joint: false,
64            prev_is_joint: false,
65            curr_token_kind: TokenKind::Other,
66            prev_token_kind: TokenKind::Other,
67            depth: 0,
68            location: ParseLocation::InGenerics,
69            last_span: Span::call_site(),
70            generics: TokenStream::new(),
71            generics_span: Span::call_site(),
72            where_clause: TokenStream::new(),
73            where_clause_span: Span::call_site(),
74            after_where: TokenStream::new(),
75            after_where_span: Span::call_site(),
76        }
77    }
78
79    #[allow(dead_code)]
80    pub(crate) fn curr_is_joint(&self) -> bool {
81        self.curr_is_joint
82    }
83
84    #[allow(dead_code)]
85    pub(crate) fn prev_is_joint(&self) -> bool {
86        self.prev_is_joint
87    }
88
89    #[allow(dead_code)]
90    pub(crate) fn depth(&self) -> u32 {
91        self.depth
92    }
93
94    #[allow(dead_code)]
95    pub(crate) fn last_span(&self) -> Span {
96        self.last_span
97    }
98}
99
/// Evaluates `$res` (an `Option`). On `Some`, the contained value is assigned to
/// the variable `$tt`. On `None`, this `break`s out of the enclosing loop.
macro_rules! match_process_gen {
    ($res:expr, $tt:ident) => {
        if let Some(next_token) = $res {
            $tt = next_token;
        } else {
            break;
        }
    };
}
108
109impl SplitGenerics {
110    pub(crate) fn split_generics<P>(mut self, callback_macro: MacroInvocation, args: TokenStream,mut parsing_pgen: P) -> TokenStream
111    where
112        P: PostGenericsParser
113    {
114        self.process_generics();
115
116        self.location = ParseLocation::AfterGenerics;
117
118        if self.depth == 0 {
119            while let Some(mut tt) = self.parsing.next() {
120                match_process_gen!(self.process_generic_list(tt), tt);
121
122                if self.depth == 0 {
123                    match_process_gen!(self.process_after_generics(tt), tt);
124                }
125
126                parsing_pgen.consume_token(&self, tt);
127            }
128        }
129
130        self.process_from_where_clause();
131
132        let Self{
133            mut input_tokens, 
134            generics, generics_span,
135            where_clause, where_clause_span,
136            after_where, after_where_span,
137            ..
138        } = self;
139
140        callback_macro.expand_with_extra_args(|out_args| {
141            out_args.extend(args);
142
143            out_parenthesized(generics, generics_span, out_args);
144            
145            parsing_pgen.write_tokens(out_args);
146
147            out_parenthesized(where_clause, where_clause_span, out_args);
148            out_parenthesized(after_where, after_where_span, out_args);
149        })
150    }
151
152    // Processes the generic parameters that start the token stream,
153    // those declare the generic parmeters
154    fn process_generics(&mut self) {
155        if mmatches!(
156            self.parsing.peek(),
157            Some(TokenTree::Punct(punct)) if punct.as_char() == '<' 
158        ) {
159            drop(self.parsing.next());
160            while let Some(mut tt) = self.parsing.next() {
161                match_process_gen!(self.process_generic_list(tt), tt);
162                self.generics.extend(once(tt));
163            }
164            self.generics_span = self.last_span;
165        }
166    }
167
168    fn process_from_where_clause(&mut self) {
169        if self.depth == 0 && mmatches!(self.location, ParseLocation::InWhere) {
170            while let Some(mut tt) = self.parsing.next() {
171                match_process_gen!(self.process_generic_list(tt), tt);
172
173                if self.depth == 0 {
174                    match_process_gen!(self.process_after_generics(tt), tt);
175                }
176
177                self.where_clause.extend(once(tt));
178            }
179        }
180
181        self.where_clause_span = self.last_span;
182
183
184        for tt in &mut self.parsing {
185            self.last_span = tt.span();
186            self.after_where.extend(once(tt));
187        }
188        self.after_where_span = self.last_span;
189    }
190
191    fn process_after_generics(&mut self, tt: TokenTree) -> Option<TokenTree> {
192        match &tt {
193            TokenTree::Ident(ident) if 
194                mmatches!(self.location, ParseLocation::AfterGenerics) &&
195                ident.to_string() == "where" 
196            => {
197                self.curr_token_kind = TokenKind::Where;
198                self.location = ParseLocation::InWhere;
199                None
200            }
201            TokenTree::Punct(punct) if {
202                let c = punct.as_char();
203                c == ';' || c == '=' && punct.spacing() == Spacing::Alone
204            } => {
205                self.where_clause.extend(self.get_trailing_comma());
206                
207                self.after_where.extend(once(tt));
208                self.location = ParseLocation::AfterWhere;
209                
210                None
211            }
212            TokenTree::Group(group) if group.delimiter() == Delimiter::Brace => {
213                self.where_clause.extend(self.get_trailing_comma());
214                
215                self.after_where.extend(once(tt));
216                self.location = ParseLocation::AfterWhere;
217                
218                None
219            }
220            _ => Some(tt),
221        }
222    }
223
224    fn get_trailing_comma(&self) -> Option<TokenTree> {
225        if let (ParseLocation::InWhere, TokenKind::Other) = (self.location, self.prev_token_kind) {
226            let mut p = Punct::new(',', Spacing::Alone);
227            p.set_span(self.last_span);
228            Some(TokenTree::Punct(p))
229        } else {
230            None
231        }
232    }
233
234    // Processes any pair of `<` and `>`
235    fn process_generic_list(&mut self, tt: TokenTree) -> Option<TokenTree> {
236        self.last_span = tt.span();
237        self.prev_is_joint = self.curr_is_joint;
238        self.curr_is_joint = false;
239
240        self.prev_token_kind = self.curr_token_kind;
241        self.curr_token_kind = TokenKind::Other;
242
243        if let TokenTree::Punct(punct) = &tt {
244            let char = punct.as_char();
245            self.curr_is_joint = char == '-' ||
246                punct.spacing() == Spacing::Joint && char != '>' && char != '<';
247
248            if char == ',' {
249                self.curr_token_kind = TokenKind::Comma;
250            }
251
252            if char == '<' {
253                self.depth += 1;
254            } if !self.prev_is_joint && char == '>' {
255                if self.depth == 0 {
256                    if mmatches!(self.location, ParseLocation::InGenerics) {
257                        return None;
258                    } 
259                } else {
260                    self.depth -= 1;
261                }
262            }
263        }
264
265        Some(tt)
266    }
267}
268
269
/// The section of the input that `SplitGenerics` is currently parsing.
#[derive(Clone, Copy)]
enum ParseLocation {
    /// Inside the leading `<...>` generic parameter list.
    InGenerics,
    /// After the generic parameters, before any `where` keyword.
    AfterGenerics,
    /// Inside the `where` clause.
    InWhere,
    /// After the `;`, lone `=` or `{...}` that ended the item header.
    AfterWhere,
}
277
278
/// Coarse classification of a token, as needed to decide whether the `where`
/// clause needs a trailing comma.
#[derive(Clone, Copy)]
enum TokenKind {
    /// The `where` keyword that starts the `where` clause.
    Where,
    /// A `,` punctuation token.
    Comma,
    /// Any other token.
    Other,
}
285
286
287