// core_extensions_proc_macros/splitting_generics.rs
1use crate::{
2 used_proc_macro::{
3 token_stream::IntoIter,
4 Delimiter, Punct, Spacing, Span, TokenStream, TokenTree
5 },
6 parsing_shared::{MacroInvocation, out_parenthesized, parse_paren_args, parse_path_and_args},
7 mmatches,
8};
9
10use core::iter::{Peekable, once};
11
12use alloc::string::ToString;
13
14
15
16
17pub(crate) trait PostGenericsParser {
18 fn consume_token(&mut self, _: &SplitGenerics, tt: TokenTree);
19 fn write_tokens(self, ts: &mut TokenStream);
20}
21
22
23
24pub(crate) struct SplitGenerics {
25 input_tokens: IntoIter,
27 parsing: Peekable<IntoIter>,
29 curr_is_joint: bool,
30 prev_is_joint: bool,
31 curr_token_kind: TokenKind,
32 prev_token_kind: TokenKind,
33 location: ParseLocation,
34 depth: u32,
35 last_span: Span,
36 generics: TokenStream,
37 generics_span: Span,
38 where_clause: TokenStream,
39 where_clause_span: Span,
40 after_where: TokenStream,
41 after_where_span: Span,
42}
43
44
45impl SplitGenerics {
46 pub(crate) fn new<I>(input_tokens: I) -> Self
47 where
48 I: IntoIterator<IntoIter = IntoIter, Item = TokenTree>
49 {
50 let mut input_tokens = input_tokens.into_iter();
51
52 let parsed_tt = input_tokens.next().expect("skip_generics expected more tokens");
53
54 let parsing = parse_paren_args(&parsed_tt);
55
56 Self::some_consumed(input_tokens, parsing)
57 }
58
59 pub(crate) fn some_consumed(input_tokens: IntoIter, parsing: Peekable<IntoIter>) -> Self {
60 Self {
61 input_tokens,
62 parsing,
63 curr_is_joint: false,
64 prev_is_joint: false,
65 curr_token_kind: TokenKind::Other,
66 prev_token_kind: TokenKind::Other,
67 depth: 0,
68 location: ParseLocation::InGenerics,
69 last_span: Span::call_site(),
70 generics: TokenStream::new(),
71 generics_span: Span::call_site(),
72 where_clause: TokenStream::new(),
73 where_clause_span: Span::call_site(),
74 after_where: TokenStream::new(),
75 after_where_span: Span::call_site(),
76 }
77 }
78
79 #[allow(dead_code)]
80 pub(crate) fn curr_is_joint(&self) -> bool {
81 self.curr_is_joint
82 }
83
84 #[allow(dead_code)]
85 pub(crate) fn prev_is_joint(&self) -> bool {
86 self.prev_is_joint
87 }
88
89 #[allow(dead_code)]
90 pub(crate) fn depth(&self) -> u32 {
91 self.depth
92 }
93
94 #[allow(dead_code)]
95 pub(crate) fn last_span(&self) -> Span {
96 self.last_span
97 }
98}
99
/// Assigns the value inside `Some` to the `$tt` variable, or `break`s out of the
/// enclosing loop when `$res` is `None`.
///
/// Only usable inside a loop.
macro_rules! match_process_gen {
    ($res:expr, $tt:ident) => {
        if let Some(produced) = $res {
            $tt = produced;
        } else {
            break;
        }
    };
}
108
109impl SplitGenerics {
110 pub(crate) fn split_generics<P>(mut self, callback_macro: MacroInvocation, args: TokenStream,mut parsing_pgen: P) -> TokenStream
111 where
112 P: PostGenericsParser
113 {
114 self.process_generics();
115
116 self.location = ParseLocation::AfterGenerics;
117
118 if self.depth == 0 {
119 while let Some(mut tt) = self.parsing.next() {
120 match_process_gen!(self.process_generic_list(tt), tt);
121
122 if self.depth == 0 {
123 match_process_gen!(self.process_after_generics(tt), tt);
124 }
125
126 parsing_pgen.consume_token(&self, tt);
127 }
128 }
129
130 self.process_from_where_clause();
131
132 let Self{
133 mut input_tokens,
134 generics, generics_span,
135 where_clause, where_clause_span,
136 after_where, after_where_span,
137 ..
138 } = self;
139
140 callback_macro.expand_with_extra_args(|out_args| {
141 out_args.extend(args);
142
143 out_parenthesized(generics, generics_span, out_args);
144
145 parsing_pgen.write_tokens(out_args);
146
147 out_parenthesized(where_clause, where_clause_span, out_args);
148 out_parenthesized(after_where, after_where_span, out_args);
149 })
150 }
151
152 fn process_generics(&mut self) {
155 if mmatches!(
156 self.parsing.peek(),
157 Some(TokenTree::Punct(punct)) if punct.as_char() == '<'
158 ) {
159 drop(self.parsing.next());
160 while let Some(mut tt) = self.parsing.next() {
161 match_process_gen!(self.process_generic_list(tt), tt);
162 self.generics.extend(once(tt));
163 }
164 self.generics_span = self.last_span;
165 }
166 }
167
168 fn process_from_where_clause(&mut self) {
169 if self.depth == 0 && mmatches!(self.location, ParseLocation::InWhere) {
170 while let Some(mut tt) = self.parsing.next() {
171 match_process_gen!(self.process_generic_list(tt), tt);
172
173 if self.depth == 0 {
174 match_process_gen!(self.process_after_generics(tt), tt);
175 }
176
177 self.where_clause.extend(once(tt));
178 }
179 }
180
181 self.where_clause_span = self.last_span;
182
183
184 for tt in &mut self.parsing {
185 self.last_span = tt.span();
186 self.after_where.extend(once(tt));
187 }
188 self.after_where_span = self.last_span;
189 }
190
191 fn process_after_generics(&mut self, tt: TokenTree) -> Option<TokenTree> {
192 match &tt {
193 TokenTree::Ident(ident) if
194 mmatches!(self.location, ParseLocation::AfterGenerics) &&
195 ident.to_string() == "where"
196 => {
197 self.curr_token_kind = TokenKind::Where;
198 self.location = ParseLocation::InWhere;
199 None
200 }
201 TokenTree::Punct(punct) if {
202 let c = punct.as_char();
203 c == ';' || c == '=' && punct.spacing() == Spacing::Alone
204 } => {
205 self.where_clause.extend(self.get_trailing_comma());
206
207 self.after_where.extend(once(tt));
208 self.location = ParseLocation::AfterWhere;
209
210 None
211 }
212 TokenTree::Group(group) if group.delimiter() == Delimiter::Brace => {
213 self.where_clause.extend(self.get_trailing_comma());
214
215 self.after_where.extend(once(tt));
216 self.location = ParseLocation::AfterWhere;
217
218 None
219 }
220 _ => Some(tt),
221 }
222 }
223
224 fn get_trailing_comma(&self) -> Option<TokenTree> {
225 if let (ParseLocation::InWhere, TokenKind::Other) = (self.location, self.prev_token_kind) {
226 let mut p = Punct::new(',', Spacing::Alone);
227 p.set_span(self.last_span);
228 Some(TokenTree::Punct(p))
229 } else {
230 None
231 }
232 }
233
234 fn process_generic_list(&mut self, tt: TokenTree) -> Option<TokenTree> {
236 self.last_span = tt.span();
237 self.prev_is_joint = self.curr_is_joint;
238 self.curr_is_joint = false;
239
240 self.prev_token_kind = self.curr_token_kind;
241 self.curr_token_kind = TokenKind::Other;
242
243 if let TokenTree::Punct(punct) = &tt {
244 let char = punct.as_char();
245 self.curr_is_joint = char == '-' ||
246 punct.spacing() == Spacing::Joint && char != '>' && char != '<';
247
248 if char == ',' {
249 self.curr_token_kind = TokenKind::Comma;
250 }
251
252 if char == '<' {
253 self.depth += 1;
254 } if !self.prev_is_joint && char == '>' {
255 if self.depth == 0 {
256 if mmatches!(self.location, ParseLocation::InGenerics) {
257 return None;
258 }
259 } else {
260 self.depth -= 1;
261 }
262 }
263 }
264
265 Some(tt)
266 }
267}
268
269
/// Which section of the item header `SplitGenerics` is currently parsing.
#[derive(Clone, Copy)]
enum ParseLocation {
    /// Inside the `<...>` generic parameter list (the initial state).
    InGenerics,
    /// After the generic parameters, before any `where` clause.
    AfterGenerics,
    /// Inside a `where` clause.
    InWhere,
    /// After the `;`, lone `=`, or `{...}` group that ends the header.
    AfterWhere,
}
277
278
/// Coarse classification of a token, used to decide whether the `where` predicates
/// need a trailing comma.
#[derive(Clone, Copy)]
enum TokenKind {
    /// The `where` keyword.
    Where,
    /// A `,` punct.
    Comma,
    /// Any other token.
    Other,
}
285
286
287