1use proc_macro2::{Span, TokenStream, TokenTree};
2use quote::{quote, quote_spanned, ToTokens};
3use syn::parse::{Parse, ParseStream, Parser};
4use syn::{braced, Attribute, Ident, Path, Signature, Visibility};
5
6type AttributeArgs = syn::punctuated::Punctuated<syn::Meta, syn::Token![,]>;
8
/// Runtime flavor selected with `flavor = "..."`.
#[derive(Clone, Copy, PartialEq)]
enum RuntimeFlavor {
    /// `current_thread`
    CurrentThread,
    /// `multi_thread`
    Threaded,
    /// `local`
    Local,
}

impl RuntimeFlavor {
    /// Maps a flavor name to its variant.
    ///
    /// Known legacy names get a dedicated hint pointing at the current name.
    /// Anything else gets a generic message listing the valid flavors.
    fn from_str(s: &str) -> Result<RuntimeFlavor, String> {
        let hint = match s {
            "current_thread" => return Ok(Self::CurrentThread),
            "multi_thread" => return Ok(Self::Threaded),
            "local" => return Ok(Self::Local),
            "single_thread" => "The single threaded runtime flavor is called `current_thread`.",
            "basic_scheduler" => {
                "The `basic_scheduler` runtime flavor has been renamed to `current_thread`."
            }
            "threaded_scheduler" => {
                "The `threaded_scheduler` runtime flavor has been renamed to `multi_thread`."
            }
            unknown => {
                return Err(format!("No such runtime flavor `{unknown}`. The runtime flavors are `current_thread`, `local`, and `multi_thread`."));
            }
        };
        Err(hint.to_owned())
    }
}
29
30#[derive(Clone, Copy, PartialEq)]
31enum UnhandledPanic {
32 Ignore,
33 ShutdownRuntime,
34}
35
36impl UnhandledPanic {
37 fn from_str(s: &str) -> Result<UnhandledPanic, String> {
38 match s {
39 "ignore" => Ok(UnhandledPanic::Ignore),
40 "shutdown_runtime" => Ok(UnhandledPanic::ShutdownRuntime),
41 _ => Err(format!("No such unhandled panic behavior `{s}`. The unhandled panic behaviors are `ignore` and `shutdown_runtime`.")),
42 }
43 }
44
45 fn into_tokens(self, crate_path: &TokenStream) -> TokenStream {
46 match self {
47 UnhandledPanic::Ignore => quote! { #crate_path::runtime::UnhandledPanic::Ignore },
48 UnhandledPanic::ShutdownRuntime => {
49 quote! { #crate_path::runtime::UnhandledPanic::ShutdownRuntime }
50 }
51 }
52 }
53}
54
55struct FinalConfig {
56 flavor: RuntimeFlavor,
57 worker_threads: Option<usize>,
58 start_paused: Option<bool>,
59 crate_name: Option<Path>,
60 unhandled_panic: Option<UnhandledPanic>,
61}
62
63const DEFAULT_ERROR_CONFIG: FinalConfig = FinalConfig {
65 flavor: RuntimeFlavor::CurrentThread,
66 worker_threads: None,
67 start_paused: None,
68 crate_name: None,
69 unhandled_panic: None,
70};
71
72struct Configuration {
73 rt_multi_thread_available: bool,
74 default_flavor: RuntimeFlavor,
75 flavor: Option<RuntimeFlavor>,
76 worker_threads: Option<(usize, Span)>,
77 start_paused: Option<(bool, Span)>,
78 is_test: bool,
79 crate_name: Option<Path>,
80 unhandled_panic: Option<(UnhandledPanic, Span)>,
81}
82
83impl Configuration {
84 fn new(is_test: bool, rt_multi_thread: bool) -> Self {
85 Configuration {
86 rt_multi_thread_available: rt_multi_thread,
87 default_flavor: match is_test {
88 true => RuntimeFlavor::CurrentThread,
89 false => RuntimeFlavor::Threaded,
90 },
91 flavor: None,
92 worker_threads: None,
93 start_paused: None,
94 is_test,
95 crate_name: None,
96 unhandled_panic: None,
97 }
98 }
99
100 fn set_flavor(&mut self, runtime: syn::Lit, span: Span) -> Result<(), syn::Error> {
101 if self.flavor.is_some() {
102 return Err(syn::Error::new(span, "`flavor` set multiple times."));
103 }
104
105 let runtime_str = parse_string(runtime, span, "flavor")?;
106 let runtime =
107 RuntimeFlavor::from_str(&runtime_str).map_err(|err| syn::Error::new(span, err))?;
108 self.flavor = Some(runtime);
109 Ok(())
110 }
111
112 fn set_worker_threads(
113 &mut self,
114 worker_threads: syn::Lit,
115 span: Span,
116 ) -> Result<(), syn::Error> {
117 if self.worker_threads.is_some() {
118 return Err(syn::Error::new(
119 span,
120 "`worker_threads` set multiple times.",
121 ));
122 }
123
124 let worker_threads = parse_int(worker_threads, span, "worker_threads")?;
125 if worker_threads == 0 {
126 return Err(syn::Error::new(span, "`worker_threads` may not be 0."));
127 }
128 self.worker_threads = Some((worker_threads, span));
129 Ok(())
130 }
131
132 fn set_start_paused(&mut self, start_paused: syn::Lit, span: Span) -> Result<(), syn::Error> {
133 if self.start_paused.is_some() {
134 return Err(syn::Error::new(span, "`start_paused` set multiple times."));
135 }
136
137 let start_paused = parse_bool(start_paused, span, "start_paused")?;
138 self.start_paused = Some((start_paused, span));
139 Ok(())
140 }
141
142 fn set_crate_name(&mut self, name: syn::Lit, span: Span) -> Result<(), syn::Error> {
143 if self.crate_name.is_some() {
144 return Err(syn::Error::new(span, "`crate` set multiple times."));
145 }
146 let name_path = parse_path(name, span, "crate")?;
147 self.crate_name = Some(name_path);
148 Ok(())
149 }
150
151 fn set_unhandled_panic(
152 &mut self,
153 unhandled_panic: syn::Lit,
154 span: Span,
155 ) -> Result<(), syn::Error> {
156 if self.unhandled_panic.is_some() {
157 return Err(syn::Error::new(
158 span,
159 "`unhandled_panic` set multiple times.",
160 ));
161 }
162
163 let unhandled_panic = parse_string(unhandled_panic, span, "unhandled_panic")?;
164 let unhandled_panic =
165 UnhandledPanic::from_str(&unhandled_panic).map_err(|err| syn::Error::new(span, err))?;
166 self.unhandled_panic = Some((unhandled_panic, span));
167 Ok(())
168 }
169
170 fn macro_name(&self) -> &'static str {
171 if self.is_test {
172 "tokio::test"
173 } else {
174 "tokio::main"
175 }
176 }
177
178 fn build(&self) -> Result<FinalConfig, syn::Error> {
179 use RuntimeFlavor as F;
180
181 let flavor = self.flavor.unwrap_or(self.default_flavor);
182
183 let worker_threads = match (flavor, self.worker_threads) {
184 (F::CurrentThread | F::Local, Some((_, worker_threads_span))) => {
185 let msg = format!(
186 "The `worker_threads` option requires the `multi_thread` runtime flavor. Use `#[{}(flavor = \"multi_thread\")]`",
187 self.macro_name(),
188 );
189 return Err(syn::Error::new(worker_threads_span, msg));
190 }
191 (F::CurrentThread | F::Local, None) => None,
192 (F::Threaded, worker_threads) if self.rt_multi_thread_available => {
193 worker_threads.map(|(val, _span)| val)
194 }
195 (F::Threaded, _) => {
196 let msg = if self.flavor.is_none() {
197 "The default runtime flavor is `multi_thread`, but the `rt-multi-thread` feature is disabled."
198 } else {
199 "The runtime flavor `multi_thread` requires the `rt-multi-thread` feature."
200 };
201 return Err(syn::Error::new(Span::call_site(), msg));
202 }
203 };
204
205 let start_paused = match (flavor, self.start_paused) {
206 (F::Threaded, Some((_, start_paused_span))) => {
207 let msg = format!(
208 "The `start_paused` option requires the `current_thread` runtime flavor. Use `#[{}(flavor = \"current_thread\")]`",
209 self.macro_name(),
210 );
211 return Err(syn::Error::new(start_paused_span, msg));
212 }
213 (F::CurrentThread | F::Local, Some((start_paused, _))) => Some(start_paused),
214 (_, None) => None,
215 };
216
217 let unhandled_panic = match (flavor, self.unhandled_panic) {
218 (F::Threaded, Some((_, unhandled_panic_span))) => {
219 let msg = format!(
220 "The `unhandled_panic` option requires the `current_thread` runtime flavor. Use `#[{}(flavor = \"current_thread\")]`",
221 self.macro_name(),
222 );
223 return Err(syn::Error::new(unhandled_panic_span, msg));
224 }
225 (F::CurrentThread | F::Local, Some((unhandled_panic, _))) => Some(unhandled_panic),
226 (_, None) => None,
227 };
228
229 Ok(FinalConfig {
230 crate_name: self.crate_name.clone(),
231 flavor,
232 worker_threads,
233 start_paused,
234 unhandled_panic,
235 })
236 }
237}
238
239fn parse_int(int: syn::Lit, span: Span, field: &str) -> Result<usize, syn::Error> {
240 match int {
241 syn::Lit::Int(lit) => match lit.base10_parse::<usize>() {
242 Ok(value) => Ok(value),
243 Err(e) => Err(syn::Error::new(
244 span,
245 format!("Failed to parse value of `{field}` as integer: {e}"),
246 )),
247 },
248 _ => Err(syn::Error::new(
249 span,
250 format!("Failed to parse value of `{field}` as integer."),
251 )),
252 }
253}
254
255fn parse_string(int: syn::Lit, span: Span, field: &str) -> Result<String, syn::Error> {
256 match int {
257 syn::Lit::Str(s) => Ok(s.value()),
258 syn::Lit::Verbatim(s) => Ok(s.to_string()),
259 _ => Err(syn::Error::new(
260 span,
261 format!("Failed to parse value of `{field}` as string."),
262 )),
263 }
264}
265
266fn parse_path(lit: syn::Lit, span: Span, field: &str) -> Result<Path, syn::Error> {
267 match lit {
268 syn::Lit::Str(s) => {
269 let err = syn::Error::new(
270 span,
271 format!(
272 "Failed to parse value of `{}` as path: \"{}\"",
273 field,
274 s.value()
275 ),
276 );
277 s.parse::<syn::Path>().map_err(|_| err.clone())
278 }
279 _ => Err(syn::Error::new(
280 span,
281 format!("Failed to parse value of `{field}` as path."),
282 )),
283 }
284}
285
286fn parse_bool(bool: syn::Lit, span: Span, field: &str) -> Result<bool, syn::Error> {
287 match bool {
288 syn::Lit::Bool(b) => Ok(b.value),
289 _ => Err(syn::Error::new(
290 span,
291 format!("Failed to parse value of `{field}` as bool."),
292 )),
293 }
294}
295
296fn build_config(
297 input: &ItemFn,
298 args: AttributeArgs,
299 is_test: bool,
300 rt_multi_thread: bool,
301) -> Result<FinalConfig, syn::Error> {
302 if input.sig.asyncness.is_none() {
303 let msg = "the `async` keyword is missing from the function declaration";
304 return Err(syn::Error::new_spanned(input.sig.fn_token, msg));
305 }
306
307 let mut config = Configuration::new(is_test, rt_multi_thread);
308 let macro_name = config.macro_name();
309
310 for arg in args {
311 match arg {
312 syn::Meta::NameValue(namevalue) => {
313 let ident = namevalue
314 .path
315 .get_ident()
316 .ok_or_else(|| {
317 syn::Error::new_spanned(&namevalue, "Must have specified ident")
318 })?
319 .to_string()
320 .to_lowercase();
321 let lit = match &namevalue.value {
322 syn::Expr::Lit(syn::ExprLit { lit, .. }) => lit,
323 expr => return Err(syn::Error::new_spanned(expr, "Must be a literal")),
324 };
325 match ident.as_str() {
326 "worker_threads" => {
327 config.set_worker_threads(lit.clone(), syn::spanned::Spanned::span(lit))?;
328 }
329 "flavor" => {
330 config.set_flavor(lit.clone(), syn::spanned::Spanned::span(lit))?;
331 }
332 "start_paused" => {
333 config.set_start_paused(lit.clone(), syn::spanned::Spanned::span(lit))?;
334 }
335 "core_threads" => {
336 let msg = "Attribute `core_threads` is renamed to `worker_threads`";
337 return Err(syn::Error::new_spanned(namevalue, msg));
338 }
339 "crate" => {
340 config.set_crate_name(lit.clone(), syn::spanned::Spanned::span(lit))?;
341 }
342 "unhandled_panic" => {
343 config
344 .set_unhandled_panic(lit.clone(), syn::spanned::Spanned::span(lit))?;
345 }
346 name => {
347 let msg = format!(
348 "Unknown attribute {name} is specified; expected one of: `flavor`, `worker_threads`, `start_paused`, `crate`, `unhandled_panic`",
349 );
350 return Err(syn::Error::new_spanned(namevalue, msg));
351 }
352 }
353 }
354 syn::Meta::Path(path) => {
355 let name = path
356 .get_ident()
357 .ok_or_else(|| syn::Error::new_spanned(&path, "Must have specified ident"))?
358 .to_string()
359 .to_lowercase();
360 let msg = match name.as_str() {
361 "threaded_scheduler" | "multi_thread" => {
362 format!(
363 "Set the runtime flavor with #[{macro_name}(flavor = \"multi_thread\")]."
364 )
365 }
366 "basic_scheduler" | "current_thread" | "single_threaded" => {
367 format!(
368 "Set the runtime flavor with #[{macro_name}(flavor = \"current_thread\")]."
369 )
370 }
371 "flavor" | "worker_threads" | "start_paused" | "crate" | "unhandled_panic" => {
372 format!("The `{name}` attribute requires an argument.")
373 }
374 name => {
375 format!("Unknown attribute {name} is specified; expected one of: `flavor`, `worker_threads`, `start_paused`, `crate`, `unhandled_panic`.")
376 }
377 };
378 return Err(syn::Error::new_spanned(path, msg));
379 }
380 other => {
381 return Err(syn::Error::new_spanned(
382 other,
383 "Unknown attribute inside the macro",
384 ));
385 }
386 }
387 }
388
389 config.build()
390}
391
/// Expands the parsed function into a synchronous function that runs its
/// original body on a Tokio runtime configured by `config`.
///
/// The `async` keyword is stripped from the signature and the original body is
/// bound to a local `body` future. A trailing block then builds the runtime
/// (`new_current_thread()` / `new_multi_thread()` plus any configured builder
/// calls) and returns `block_on(body)`.
///
/// When `is_test` is true, `#[::core::prelude::v1::test]` is added, and the
/// future is pinned and coerced to `Pin<&mut dyn Future<Output = ...>>`.
fn parse_knobs(mut input: ItemFn, is_test: bool, config: FinalConfig) -> TokenStream {
    input.sig.asyncness = None;

    // Spans of the first and last token of the final body statement. Both fall
    // back to `Span::call_site()` when the body is empty. The generated runtime
    // code below is spanned to these, presumably so that diagnostics (e.g. a
    // return-type mismatch) point at the user's last statement.
    let (last_stmt_start_span, last_stmt_end_span) = {
        let mut last_stmt = input.stmts.last().cloned().unwrap_or_default().into_iter();

        let start = last_stmt.next().map_or_else(Span::call_site, |t| t.span());
        // With a single token `last()` is `None`, so `end` falls back to `start`.
        let end = last_stmt.last().map_or(start, |t| t.span());
        (start, end)
    };

    // Path used to reach the runtime in generated code: the `crate = "..."`
    // override if one was given, otherwise `tokio`.
    let crate_path = config
        .crate_name
        .map(ToTokens::into_token_stream)
        .unwrap_or_else(|| Ident::new("tokio", last_stmt_start_span).into_token_stream());

    // `current_thread` and `local` both start from the current-thread builder.
    // They differ only in the build call chosen below.
    let mut rt = match config.flavor {
        RuntimeFlavor::CurrentThread | RuntimeFlavor::Local => {
            quote_spanned! {last_stmt_start_span=>
                #crate_path::runtime::Builder::new_current_thread()
            }
        }
        RuntimeFlavor::Threaded => quote_spanned! {last_stmt_start_span=>
            #crate_path::runtime::Builder::new_multi_thread()
        },
    };

    // Parallel lists. `checks[i]` is a cfg predicate that must hold for the
    // runtime code to be compiled in. `errors[i]` is the `compile_error!`
    // message emitted when it does not hold.
    let mut checks = vec![];
    let mut errors = vec![];

    // The `local` flavor uses `build_local(..)` and is gated on the
    // `tokio_unstable` cfg. Every other flavor uses plain `build()`.
    let build = if let RuntimeFlavor::Local = config.flavor {
        checks.push(quote! { tokio_unstable });
        errors.push("The local runtime flavor is only available when `tokio_unstable` is set.");
        quote_spanned! {last_stmt_start_span=> build_local(Default::default())}
    } else {
        quote_spanned! {last_stmt_start_span=> build()}
    };

    // Chain the optional builder settings onto `rt`.
    if let Some(v) = config.worker_threads {
        rt = quote_spanned! {last_stmt_start_span=> #rt.worker_threads(#v) };
    }
    if let Some(v) = config.start_paused {
        rt = quote_spanned! {last_stmt_start_span=> #rt.start_paused(#v) };
    }
    if let Some(v) = config.unhandled_panic {
        let unhandled_panic = v.into_tokens(&crate_path);
        rt = quote_spanned! {last_stmt_start_span=> #rt.unhandled_panic(#unhandled_panic) };
    }

    // `#[tokio::test]` also marks the function with the standard test attribute.
    let generated_attrs = if is_test {
        quote! {
            #[::core::prelude::v1::test]
        }
    } else {
        quote! {}
    };

    // One `#[cfg(not(<check>))] compile_error!(<error>);` per entry in `checks`.
    let do_checks: TokenStream = checks
        .iter()
        .zip(&errors)
        .map(|(check, error)| {
            quote! {
                #[cfg(not(#check))]
                compile_error!(#error);
            }
        })
        .collect();

    // Name of the local future binding introduced in `body` below.
    let body_ident = quote! { body };
    // When all checks hold, build the runtime and return `block_on(body)`.
    // Otherwise the `compile_error!`s from `do_checks` fire; the `panic!` block
    // presumably exists only so the function body still has a diverging tail.
    // NOTE(review): the explicit `return` looks deliberate (see the
    // `clippy::needless_return` allow). Keep it unless the reason is confirmed.
    let last_block = quote_spanned! {last_stmt_end_span=>
        #do_checks

        #[cfg(all(#(#checks),*))]
        #[allow(clippy::expect_used, clippy::diverging_sub_expression, clippy::needless_return, clippy::unwrap_in_result)]
        {
            return #rt
                .enable_all()
                .#build
                .expect("Failed building the Runtime")
                .block_on(#body_ident);
        }

        #[cfg(not(all(#(#checks),*)))]
        {
            panic!("fell through checks")
        }
    };

    // `{ <original statements> }`, reusing the original brace span.
    let body = input.body();

    // Tests pin the future on the stack and erase its type behind
    // `Pin<&mut dyn Future>`. This is presumably to limit `block_on`
    // instantiations across many tests (NOTE(review): confirm). `main` keeps
    // the concrete future type.
    let body = if is_test {
        let output_type = match &input.sig.output {
            // A function without an explicit return type returns `()`.
            syn::ReturnType::Default => quote! { () },
            syn::ReturnType::Type(_, ret_type) => quote! { #ret_type },
        };
        quote! {
            let body = async #body;
            #crate_path::pin!(body);
            let body: ::core::pin::Pin<&mut dyn ::core::future::Future<Output = #output_type>> = body;
        }
    } else {
        quote! {
            let body = async #body;
        }
    };

    input.into_tokens(generated_attrs, body, last_block)
}
518
519fn token_stream_with_error(mut tokens: TokenStream, error: syn::Error) -> TokenStream {
520 tokens.extend(error.into_compile_error());
521 tokens
522}
523
524pub(crate) fn main(args: TokenStream, item: TokenStream, rt_multi_thread: bool) -> TokenStream {
525 let input: ItemFn = match syn::parse2(item.clone()) {
529 Ok(it) => it,
530 Err(e) => return token_stream_with_error(item, e),
531 };
532
533 let config = if input.sig.ident == "main" && !input.sig.inputs.is_empty() {
534 let msg = "the main function cannot accept arguments";
535 Err(syn::Error::new_spanned(&input.sig.ident, msg))
536 } else {
537 AttributeArgs::parse_terminated
538 .parse2(args)
539 .and_then(|args| build_config(&input, args, false, rt_multi_thread))
540 };
541
542 match config {
543 Ok(config) => parse_knobs(input, false, config),
544 Err(e) => token_stream_with_error(parse_knobs(input, false, DEFAULT_ERROR_CONFIG), e),
545 }
546}
547
548fn is_test_attribute(attr: &Attribute) -> bool {
553 let path = match &attr.meta {
554 syn::Meta::Path(path) => path,
555 _ => return false,
556 };
557 let candidates = [
558 ["core", "prelude", "*", "test"],
559 ["std", "prelude", "*", "test"],
560 ];
561 if path.leading_colon.is_none()
562 && path.segments.len() == 1
563 && path.segments[0].arguments.is_none()
564 && path.segments[0].ident == "test"
565 {
566 return true;
567 } else if path.segments.len() != candidates[0].len() {
568 return false;
569 }
570 candidates.into_iter().any(|segments| {
571 path.segments.iter().zip(segments).all(|(segment, path)| {
572 segment.arguments.is_none() && (path == "*" || segment.ident == path)
573 })
574 })
575}
576
577pub(crate) fn test(args: TokenStream, item: TokenStream, rt_multi_thread: bool) -> TokenStream {
578 let input: ItemFn = match syn::parse2(item.clone()) {
582 Ok(it) => it,
583 Err(e) => return token_stream_with_error(item, e),
584 };
585 let config = if let Some(attr) = input.attrs().find(|attr| is_test_attribute(attr)) {
586 let msg = "second test attribute is supplied, consider removing or changing the order of your test attributes";
587 Err(syn::Error::new_spanned(attr, msg))
588 } else {
589 AttributeArgs::parse_terminated
590 .parse2(args)
591 .and_then(|args| build_config(&input, args, true, rt_multi_thread))
592 };
593
594 match config {
595 Ok(config) => parse_knobs(input, true, config),
596 Err(e) => token_stream_with_error(parse_knobs(input, true, DEFAULT_ERROR_CONFIG), e),
597 }
598}
599
600struct ItemFn {
601 outer_attrs: Vec<Attribute>,
602 vis: Visibility,
603 sig: Signature,
604 brace_token: syn::token::Brace,
605 inner_attrs: Vec<Attribute>,
606 stmts: Vec<proc_macro2::TokenStream>,
607}
608
609impl ItemFn {
610 fn attrs(&self) -> impl Iterator<Item = &Attribute> {
612 self.outer_attrs.iter().chain(self.inner_attrs.iter())
613 }
614
615 fn body(&self) -> Body<'_> {
618 Body {
619 brace_token: self.brace_token,
620 stmts: &self.stmts,
621 }
622 }
623
624 fn into_tokens(
626 self,
627 generated_attrs: proc_macro2::TokenStream,
628 body: proc_macro2::TokenStream,
629 last_block: proc_macro2::TokenStream,
630 ) -> TokenStream {
631 let mut tokens = proc_macro2::TokenStream::new();
632 for attr in self.outer_attrs {
634 attr.to_tokens(&mut tokens);
635 }
636
637 for mut attr in self.inner_attrs {
641 attr.style = syn::AttrStyle::Outer;
642 attr.to_tokens(&mut tokens);
643 }
644
645 generated_attrs.to_tokens(&mut tokens);
647
648 self.vis.to_tokens(&mut tokens);
649 self.sig.to_tokens(&mut tokens);
650
651 self.brace_token.surround(&mut tokens, |tokens| {
652 body.to_tokens(tokens);
653 last_block.to_tokens(tokens);
654 });
655
656 tokens
657 }
658}
659
660impl Parse for ItemFn {
661 #[inline]
662 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
663 let outer_attrs = input.call(Attribute::parse_outer)?;
671 let vis: Visibility = input.parse()?;
672 let sig: Signature = input.parse()?;
673
674 let content;
675 let brace_token = braced!(content in input);
676 let inner_attrs = Attribute::parse_inner(&content)?;
677
678 let mut buf = proc_macro2::TokenStream::new();
679 let mut stmts = Vec::new();
680
681 while !content.is_empty() {
682 if let Some(semi) = content.parse::<Option<syn::Token![;]>>()? {
683 semi.to_tokens(&mut buf);
684 stmts.push(buf);
685 buf = proc_macro2::TokenStream::new();
686 continue;
687 }
688
689 buf.extend([content.parse::<TokenTree>()?]);
692 }
693
694 if !buf.is_empty() {
695 stmts.push(buf);
696 }
697
698 Ok(Self {
699 outer_attrs,
700 vis,
701 sig,
702 brace_token,
703 inner_attrs,
704 stmts,
705 })
706 }
707}
708
709struct Body<'a> {
710 brace_token: syn::token::Brace,
711 stmts: &'a [TokenStream],
713}
714
715impl ToTokens for Body<'_> {
716 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
717 self.brace_token.surround(tokens, |tokens| {
718 for stmt in self.stmts {
719 stmt.to_tokens(tokens);
720 }
721 });
722 }
723}