core_extensions_proc_macros/
splitting_generics.rs1use crate::{
2 used_proc_macro::{
3 token_stream::IntoIter,
4 Delimiter, Punct, Spacing, Span, TokenStream, TokenTree
5 },
6 parsing_shared::{MacroInvocation, out_parenthesized, parse_paren_args},
7 mmatches,
8};
9
10use core::iter::{Peekable, once};
11
12use alloc::string::ToString;
13
14
15
16
17pub(crate) trait PostGenericsParser {
18 fn consume_token(&mut self, _: &SplitGenerics, tt: TokenTree);
19 fn write_tokens(self, ts: &mut TokenStream);
20}
21
22
23
24pub(crate) struct SplitGenerics {
25 parsing: Peekable<IntoIter>,
27 curr_is_joint: bool,
28 prev_is_joint: bool,
29 curr_token_kind: TokenKind,
30 prev_token_kind: TokenKind,
31 location: ParseLocation,
32 depth: u32,
33 last_span: Span,
34 generics: TokenStream,
35 generics_span: Span,
36 where_clause: TokenStream,
37 where_clause_span: Span,
38 after_where: TokenStream,
39 after_where_span: Span,
40}
41
42
43impl SplitGenerics {
44 pub(crate) fn new<I>(input_tokens: I) -> Self
45 where
46 I: IntoIterator<IntoIter = IntoIter, Item = TokenTree>
47 {
48 let mut input_tokens = input_tokens.into_iter();
49
50 let parsed_tt = input_tokens.next().expect("skip_generics expected more tokens");
51
52 let parsing = parse_paren_args(&parsed_tt);
53
54 Self::some_consumed(parsing)
55 }
56
57 pub(crate) fn some_consumed(parsing: Peekable<IntoIter>) -> Self {
58 Self {
59 parsing,
60 curr_is_joint: false,
61 prev_is_joint: false,
62 curr_token_kind: TokenKind::Other,
63 prev_token_kind: TokenKind::Other,
64 depth: 0,
65 location: ParseLocation::InGenerics,
66 last_span: Span::call_site(),
67 generics: TokenStream::new(),
68 generics_span: Span::call_site(),
69 where_clause: TokenStream::new(),
70 where_clause_span: Span::call_site(),
71 after_where: TokenStream::new(),
72 after_where_span: Span::call_site(),
73 }
74 }
75
76 #[allow(dead_code)]
77 pub(crate) fn curr_is_joint(&self) -> bool {
78 self.curr_is_joint
79 }
80
81 #[allow(dead_code)]
82 pub(crate) fn prev_is_joint(&self) -> bool {
83 self.prev_is_joint
84 }
85
86 #[allow(dead_code)]
87 pub(crate) fn depth(&self) -> u32 {
88 self.depth
89 }
90
91 #[allow(dead_code)]
92 pub(crate) fn last_span(&self) -> Span {
93 self.last_span
94 }
95}
96
/// Stores the value inside the `Option` expression `$res` into the loop
/// variable `$tt`, or `break`s out of the enclosing loop when `$res` is `None`.
///
/// Must be invoked inside a loop.
macro_rules! match_process_gen {
    ($res:expr, $tt:ident) => {
        if let Some(next_tt) = $res {
            $tt = next_tt;
        } else {
            break;
        }
    };
}
105
106impl SplitGenerics {
107 pub(crate) fn split_generics<P>(mut self, callback_macro: MacroInvocation, args: TokenStream,mut parsing_pgen: P) -> TokenStream
108 where
109 P: PostGenericsParser
110 {
111 self.process_generics();
112
113 self.location = ParseLocation::AfterGenerics;
114
115 if self.depth == 0 {
116 while let Some(mut tt) = self.parsing.next() {
117 match_process_gen!(self.process_generic_list(tt), tt);
118
119 if self.depth == 0 {
120 match_process_gen!(self.process_after_generics(tt), tt);
121 }
122
123 parsing_pgen.consume_token(&self, tt);
124 }
125 }
126
127 self.process_from_where_clause();
128
129 let Self{
130 generics, generics_span,
131 where_clause, where_clause_span,
132 after_where, after_where_span,
133 ..
134 } = self;
135
136 callback_macro.expand_with_extra_args(|out_args| {
137 out_args.extend(args);
138
139 out_parenthesized(generics, generics_span, out_args);
140
141 parsing_pgen.write_tokens(out_args);
142
143 out_parenthesized(where_clause, where_clause_span, out_args);
144 out_parenthesized(after_where, after_where_span, out_args);
145 })
146 }
147
148 fn process_generics(&mut self) {
151 if mmatches!(
152 self.parsing.peek(),
153 Some(TokenTree::Punct(punct)) if punct.as_char() == '<'
154 ) {
155 drop(self.parsing.next());
156 while let Some(mut tt) = self.parsing.next() {
157 match_process_gen!(self.process_generic_list(tt), tt);
158 self.generics.extend(once(tt));
159 }
160 self.generics_span = self.last_span;
161 }
162 }
163
164 fn process_from_where_clause(&mut self) {
165 if self.depth == 0 && mmatches!(self.location, ParseLocation::InWhere) {
166 while let Some(mut tt) = self.parsing.next() {
167 match_process_gen!(self.process_generic_list(tt), tt);
168
169 if self.depth == 0 {
170 match_process_gen!(self.process_after_generics(tt), tt);
171 }
172
173 self.where_clause.extend(once(tt));
174 }
175 }
176
177 self.where_clause_span = self.last_span;
178
179
180 for tt in &mut self.parsing {
181 self.last_span = tt.span();
182 self.after_where.extend(once(tt));
183 }
184 self.after_where_span = self.last_span;
185 }
186
187 fn process_after_generics(&mut self, tt: TokenTree) -> Option<TokenTree> {
188 match &tt {
189 TokenTree::Ident(ident) if
190 mmatches!(self.location, ParseLocation::AfterGenerics) &&
191 ident.to_string() == "where"
192 => {
193 self.curr_token_kind = TokenKind::Where;
194 self.location = ParseLocation::InWhere;
195 None
196 }
197 TokenTree::Punct(punct) if {
198 let c = punct.as_char();
199 c == ';' || c == '=' && punct.spacing() == Spacing::Alone
200 } => {
201 self.where_clause.extend(self.get_trailing_comma());
202
203 self.after_where.extend(once(tt));
204 self.location = ParseLocation::AfterWhere;
205
206 None
207 }
208 TokenTree::Group(group) if group.delimiter() == Delimiter::Brace => {
209 self.where_clause.extend(self.get_trailing_comma());
210
211 self.after_where.extend(once(tt));
212 self.location = ParseLocation::AfterWhere;
213
214 None
215 }
216 _ => Some(tt),
217 }
218 }
219
220 fn get_trailing_comma(&self) -> Option<TokenTree> {
221 if let (ParseLocation::InWhere, TokenKind::Other) = (self.location, self.prev_token_kind) {
222 let mut p = Punct::new(',', Spacing::Alone);
223 p.set_span(self.last_span);
224 Some(TokenTree::Punct(p))
225 } else {
226 None
227 }
228 }
229
230 fn process_generic_list(&mut self, tt: TokenTree) -> Option<TokenTree> {
232 self.last_span = tt.span();
233 self.prev_is_joint = self.curr_is_joint;
234 self.curr_is_joint = false;
235
236 self.prev_token_kind = self.curr_token_kind;
237 self.curr_token_kind = TokenKind::Other;
238
239 if let TokenTree::Punct(punct) = &tt {
240 let char = punct.as_char();
241 self.curr_is_joint = char == '-' ||
242 punct.spacing() == Spacing::Joint && char != '>' && char != '<';
243
244 if char == ',' {
245 self.curr_token_kind = TokenKind::Comma;
246 }
247
248 if char == '<' {
249 self.depth += 1;
250 } if !self.prev_is_joint && char == '>' {
251 if self.depth == 0 {
252 if mmatches!(self.location, ParseLocation::InGenerics) {
253 return None;
254 }
255 } else {
256 self.depth -= 1;
257 }
258 }
259 }
260
261 Some(tt)
262 }
263}
264
265
266#[derive(Copy, Clone)]
267enum ParseLocation {
268 InGenerics,
269 AfterGenerics,
270 InWhere,
271 AfterWhere,
272}
273
274
275#[derive(Copy, Clone)]
276enum TokenKind{
277 Where,
278 Comma,
279 Other,
280}
281
282
283