abi_stable_derive/
fn_pointer_extractor.rs

1//! Contains visitor type for
2//! extracting function pointers and the referenced lifetimes of the fields of a type declaration.
3
4use std::{collections::HashSet, mem};
5
6use as_derive_utils::{spanned_err, syn_err};
7
8use core_extensions::SelfOps;
9
10use syn::{
11    spanned::Spanned,
12    visit_mut::{self, VisitMut},
13    Generics, Ident, Lifetime, Type, TypeBareFn, TypeReference,
14};
15
16use proc_macro2::Span;
17
18use quote::ToTokens;
19
20use crate::*;
21use crate::{
22    common_tokens::FnPointerTokens,
23    ignored_wrapper::Ignored,
24    lifetimes::{LifetimeCounters, LifetimeIndex},
25    utils::{LinearResult, SynResultExt},
26};
27
/// Information about all the function pointers in a type declaration.
#[derive(Clone, Debug, PartialEq, Hash)]
pub(crate) struct FnInfo<'a> {
    /// The generics of the struct this function pointer type is used inside of.
    parent_generics: &'a Generics,

    /// The identifiers for all the lifetimes of the
    /// struct this function pointer type is used inside of.
    ///
    /// The position of a lifetime in this list is its `LifetimeIndex::Param` index.
    env_lifetimes: Vec<&'a Ident>,

    /// The index of the first lifetime declared by the function pointers.
    ///
    /// This is the number of lifetimes declared by the struct/enum definition
    /// the function pointers are used inside of,
    /// since those lifetimes occupy the lower lifetime indices.
    initial_bound_lifetime: usize,

    /// The function pointer types found in the field currently being visited
    /// (`TypeVisitor::visit_field` takes them out after visiting each field).
    pub functions: Vec<Function<'a>>,
}
44
/// A function pointer in a type declaration.
#[derive(Clone, Debug, PartialEq, Hash)]
pub(crate) struct Function<'a> {
    /// The `fn` token of the function pointer type.
    pub(crate) fn_token: syn::Token!(fn),
    /// The span of the entire function pointer type, used as a fallback for error messages.
    pub(crate) func_span: Ignored<Span>,

    /// The index of the first lifetime the function declares, if there are any.
    ///
    /// Lifetimes below this index are the ones of the enclosing type declaration.
    pub(crate) first_bound_lt: usize,

    /// The index of the first unnamed lifetime of the function, if there are any.
    ///
    /// This comes right after the named lifetimes declared in `for<...>`.
    pub(crate) first_unnamed_bound_lt: usize,

    /// The named lifetimes for this function pointer type,
    /// the ones declared within `for<'a,'b,'c>`.
    pub(crate) named_bound_lts: Vec<&'a Ident>,
    /// A set version of the `named_bound_lts` field.
    pub(crate) named_bound_lt_set: Ignored<HashSet<&'a Ident>>,
    /// The amount of lifetimes declared by the function pointer,
    /// counting both the named lifetimes and the ones created for
    /// elided lifetimes in the parameters.
    pub(crate) bound_lts_count: usize,

    /// Whether the function pointer type is `unsafe`.
    pub(crate) is_unsafe: bool,

    /// The Span for the first time that each bound lifetime appears in the type definition,
    /// indexed relative to `first_bound_lt`.
    ///
    /// This is `None` for named lifetimes that have not been used yet.
    pub(crate) bound_lt_spans: Ignored<Vec<Option<Span>>>,

    /// The parameters of this function pointer, including name and type.
    pub(crate) params: Vec<FnParamRet<'a>>,
    /// What this function pointer returns, including name and type.
    ///
    /// None if the return type was omitted (the implicit `()` return type).
    pub(crate) returns: Option<FnParamRet<'a>>,
}
77
/// A parameter or the return type of a function pointer type.
#[derive(Clone, Debug, PartialEq, Hash)]
pub(crate) struct FnParamRet<'a> {
    /// The name of the parameter/return type.
    ///
    /// This is None if the parameter doesn't have a name (always None for the return type).
    pub(crate) name: Option<&'a Ident>,
    /// The lifetimes this type references (including static).
    pub(crate) lifetime_refs: Vec<LifetimeIndex>,
    /// The type of the parameter/return type.
    pub(crate) ty: &'a Type,
    /// Whether this is a parameter or a return type.
    pub(crate) param_or_ret: ParamOrReturn,
}
91
/// The information returned from visiting a field (see `TypeVisitor::visit_field`).
pub(crate) struct VisitFieldRet<'a> {
    /// The lifetimes from the type declaration (and `'static`) that the field references.
    pub(crate) referenced_lifetimes: Vec<LifetimeIndex>,
    /// The function pointer types in the field.
    pub(crate) functions: Vec<Function<'a>>,
}
99
100/////////////
101
102#[allow(dead_code)]
103impl<'a> TypeVisitor<'a> {
104    #[inline(never)]
105    pub fn new(arenas: &'a Arenas, ctokens: &'a FnPointerTokens, generics: &'a Generics) -> Self {
106        TypeVisitor {
107            refs: ImmutableRefs {
108                arenas,
109                ctokens,
110                env_generics: generics,
111            },
112            vars: Vars {
113                allow_type_macros: false,
114                referenced_lifetimes: Vec::default(),
115                fn_info: FnInfo {
116                    parent_generics: generics,
117                    env_lifetimes: generics.lifetimes().map(|lt| &lt.lifetime.ident).collect(),
118                    initial_bound_lifetime: generics.lifetimes().count(),
119                    functions: Vec::new(),
120                },
121                errors: LinearResult::ok(()),
122            },
123        }
124    }
125
126    pub fn allow_type_macros(&mut self) {
127        self.vars.allow_type_macros = true;
128    }
129
130    /// Gets the arena this references.
131    pub fn arenas(&self) -> &'a Arenas {
132        self.refs.arenas
133    }
134    /// Gets the CommonTokens this references.
135    pub fn ctokens(&self) -> &'a FnPointerTokens {
136        self.refs.ctokens
137    }
138    /// Gets the generic parameters this references.
139    pub fn env_generics(&self) -> &'a Generics {
140        self.refs.env_generics
141    }
142
143    /// Visit a field type,
144    /// returning the function pointer types and referenced lifetimes.
145    pub fn visit_field(&mut self, ty: &mut Type) -> VisitFieldRet<'a> {
146        self.visit_type_mut(ty);
147        VisitFieldRet {
148            referenced_lifetimes: mem::take(&mut self.vars.referenced_lifetimes),
149            functions: mem::take(&mut self.vars.fn_info.functions),
150        }
151    }
152
153    pub fn get_errors(&mut self) -> Result<(), syn::Error> {
154        self.vars.errors.take()
155    }
156
157    pub fn into_fn_info(self) -> FnInfo<'a> {
158        self.vars.fn_info
159    }
160}
161
162/////////////
163
/// Whether this is a parameter or a return type/value.
///
/// This determines how elided lifetimes are handled:
/// in a parameter they create a new bound lifetime,
/// in the return type they refer to the single bound lifetime.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub(crate) enum ParamOrReturn {
    /// A function pointer parameter.
    Param,
    /// A function pointer return type.
    Return,
}
170
/// A type which visits an entire type definition a field at a time,
/// extracting function pointers and lifetimes for each field.
pub(crate) struct TypeVisitor<'a> {
    /// Immutable references shared with other data structures.
    refs: ImmutableRefs<'a>,
    /// Variables which are mutated when visiting.
    vars: Vars<'a>,
}
179
/// Some immutable references used when visiting field types.
#[allow(dead_code)]
#[derive(Copy, Clone)]
struct ImmutableRefs<'a> {
    /// Arena that types and identifiers extracted from function pointers are moved into.
    arenas: &'a Arenas,
    /// Tokens used for comparisons and replacements (eg: `'static`, `'_`, `"C"`).
    ctokens: &'a FnPointerTokens,
    /// Generics provided by the environment (eg: the struct this type is used inside of).
    env_generics: &'a Generics,
}
189
/// Variables which are mutated when visiting.
struct Vars<'a> {
    /// Whether type macros are allowed (if false, encountering one is an error).
    allow_type_macros: bool,
    /// What lifetimes in env_lifetimes (and `'static`) are referenced in the type being visited.
    /// For TLField.
    referenced_lifetimes: Vec<LifetimeIndex>,
    /// Information about the function pointers found so far.
    fn_info: FnInfo<'a>,
    /// The errors accumulated while visiting.
    errors: LinearResult<()>,
}
199
/// Used to visit a function pointer type.
struct FnVisitor<'a, 'b> {
    /// Immutable references shared with the `TypeVisitor` that created this.
    refs: ImmutableRefs<'a>,
    /// The mutable state of the `TypeVisitor` that created this.
    vars: &'b mut Vars<'a>,

    /// How many times each lifetime is referenced inside the function pointer type.
    lifetime_counts: LifetimeCounters,

    /// The current function pointer type that is being visited.
    current: Function<'a>,
    /// The lifetime indices inside a parameter/return type that is currently being visited.
    param_ret: FnParamRetLifetimes,
}
212
/// The lifetime indices inside a parameter/return type.
struct FnParamRetLifetimes {
    /// The span of the parameter/return type, used for error messages.
    span: Span,
    /// The lifetimes this type references (including static).
    lifetime_refs: Vec<LifetimeIndex>,
    /// Whether this is a parameter or return type.
    param_or_ret: ParamOrReturn,
}
221
222/////////////
223
224impl FnParamRetLifetimes {
225    fn new(param_or_ret: ParamOrReturn, span: Option<Span>) -> Self {
226        Self {
227            span: span.unwrap_or_else(Span::call_site),
228            lifetime_refs: Vec::new(),
229            param_or_ret,
230        }
231    }
232}
233
234/////////////
235
236impl<'a> Vars<'a> {
237    /// Registers a lifetime index,
238    /// selecting those that come from the type declaration itself.
239    pub fn add_referenced_env_lifetime(&mut self, ind: LifetimeIndex) {
240        let is_env_lt = match (ind, ind.to_param()) {
241            (LifetimeIndex::STATIC, _) => true,
242            (_, Some(index)) => index < self.fn_info.env_lifetimes.len(),
243            _ => false,
244        };
245        if is_env_lt {
246            self.referenced_lifetimes.push(ind);
247        }
248    }
249}
250
251/////////////
252
253impl<'a> VisitMut for TypeVisitor<'a> {
254    /// Visits a function pointer type within a field type.
255    #[inline(never)]
256    fn visit_type_bare_fn_mut(&mut self, func: &mut TypeBareFn) {
257        let ctokens = self.refs.ctokens;
258        let arenas = self.refs.arenas;
259
260        let func_span = func.span();
261
262        let is_unsafe = func.unsafety.is_some();
263
264        let abi = func.abi.as_ref().map(|x| x.name.as_ref());
265        const ABI_ERR: &str = "must write `extern \"C\" fn` for function pointer types.";
266        match abi {
267            Some(Some(abi)) if *abi == ctokens.c_abi_lit => {}
268            Some(Some(abi)) => {
269                self.vars
270                    .errors
271                    .push_err(spanned_err!(abi, "Abi not supported for function pointers",));
272                return;
273            }
274            Some(None) => {}
275            None => {
276                self.vars.errors.push_err(spanned_err!(
277                    func,
278                    "The default abi is not supported for function pointers,you {}`.",
279                    ABI_ERR
280                ));
281                return;
282            }
283        }
284
285        let named_bound_lts: Vec<&'a Ident> = func
286            .lifetimes
287            .take() // Option<BoundLifetimes>
288            .into_iter()
289            .flat_map(|lt| lt.lifetimes)
290            .map(|lt| arenas.alloc(lt.lifetime.ident))
291            .collect::<Vec<&'a Ident>>();
292
293        let named_bound_lt_set = named_bound_lts.iter().cloned().collect();
294
295        let first_bound_lt = self.vars.fn_info.initial_bound_lifetime;
296        let bound_lts_count = named_bound_lts.len();
297        let mut current_function = FnVisitor {
298            refs: self.refs,
299            vars: &mut self.vars,
300            lifetime_counts: LifetimeCounters::new(),
301            current: Function {
302                fn_token: func.fn_token,
303                func_span: Ignored::new(func_span),
304                first_bound_lt,
305                first_unnamed_bound_lt: first_bound_lt + named_bound_lts.len(),
306                bound_lts_count,
307                named_bound_lts,
308                named_bound_lt_set: Ignored::new(named_bound_lt_set),
309                bound_lt_spans: Ignored::new(vec![None; bound_lts_count]),
310                is_unsafe,
311                params: Vec::new(),
312                returns: None,
313            },
314            param_ret: FnParamRetLifetimes::new(ParamOrReturn::Param, None),
315        };
316
317        // Visits a parameter or return type within a function pointer type.
318        fn visit_param_ret<'a, 'b>(
319            this: &mut FnVisitor<'a, 'b>,
320            name: Option<&'a Ident>,
321            ty: &'a mut Type,
322            param_or_ret: ParamOrReturn,
323        ) {
324            let ty_span = Some(ty.span());
325
326            this.param_ret = FnParamRetLifetimes::new(param_or_ret, ty_span);
327
328            this.visit_type_mut(ty);
329
330            let param_ret = mem::replace(
331                &mut this.param_ret,
332                FnParamRetLifetimes::new(param_or_ret, ty_span),
333            );
334
335            let param_ret = FnParamRet {
336                name,
337                lifetime_refs: param_ret.lifetime_refs,
338                ty,
339                param_or_ret: param_ret.param_or_ret,
340            };
341
342            match param_or_ret {
343                ParamOrReturn::Param => this.current.params.push(param_ret),
344                ParamOrReturn::Return => this.current.returns = Some(param_ret),
345            }
346        }
347
348        for (i, param) in func.inputs.iter_mut().enumerate() {
349            let arg_name = extract_fn_arg_name(i, param, arenas);
350            let ty = arenas.alloc_mut(param.ty.clone());
351            visit_param_ret(&mut current_function, arg_name, ty, ParamOrReturn::Param);
352        }
353
354        let tmp = match mem::replace(&mut func.output, syn::ReturnType::Default) {
355            syn::ReturnType::Default => None,
356            syn::ReturnType::Type(_, ty) => Some(arenas.alloc_mut(*ty)),
357        };
358        if let Some(ty) = tmp {
359            visit_param_ret(&mut current_function, None, ty, ParamOrReturn::Return);
360        }
361
362        let mut current = current_function.current;
363        current.anonimize_lifetimes(&current_function.lifetime_counts, &mut self.vars.errors);
364        while func.inputs.pop().is_some() {}
365        self.vars.fn_info.functions.push(current);
366    }
367
368    /// Visits a lifetime within a field type,
369    /// pushing it to the list of referenced lifetimes.
370    #[inline(never)]
371    fn visit_lifetime_mut(&mut self, lt: &mut Lifetime) {
372        let ctokens = self.refs.ctokens;
373        let lt = &lt.ident;
374        if *lt == ctokens.static_ {
375            LifetimeIndex::STATIC
376        } else {
377            let env_lifetimes = self.vars.fn_info.env_lifetimes.iter();
378            let found_lt = env_lifetimes
379                .enumerate()
380                .position(|(_, lt_ident)| *lt_ident == lt);
381            match found_lt {
382                Some(index) => LifetimeIndex::Param(index as _),
383                None => {
384                    self.vars
385                        .errors
386                        .push_err(spanned_err!(lt, "unknown lifetime"));
387                    LifetimeIndex::NONE
388                }
389            }
390        }
391        .piped(|lt| self.vars.add_referenced_env_lifetime(lt))
392    }
393
394    fn visit_type_macro_mut(&mut self, i: &mut syn::TypeMacro) {
395        if !self.vars.allow_type_macros {
396            push_type_macro_err(&mut self.vars.errors, i);
397        }
398    }
399}
400
401/////////////
402
403impl<'a, 'b> FnVisitor<'a, 'b> {
404    /// This function does these things:
405    ///
406    /// - Adds the lifetime as a referenced lifetime.
407    ///
408    /// - If `lt` is `Some('someident)` returns `Some('_)`.
409    ///
410    #[inline(never)]
411    fn setup_lifetime(&mut self, lt: Option<&Ident>, span: Span) -> Option<&'a Ident> {
412        let ctokens = self.refs.ctokens;
413        let mut ret: Option<&'a Ident> = None;
414        if lt == Some(&ctokens.static_) {
415            LifetimeIndex::STATIC
416        } else if lt.is_none() || lt == Some(&ctokens.underscore) {
417            match self.param_ret.param_or_ret {
418                ParamOrReturn::Param => self.new_bound_lifetime(span),
419                ParamOrReturn::Return => match self.current.bound_lts_count {
420                    0 => {
421                        self.vars.errors.push_err(syn_err!(
422                            span,
423                            "attempted to use an elided lifetime  in the \
424                                 return type when there are no lifetimes \
425                                 used in any parameter",
426                        ));
427                        LifetimeIndex::NONE
428                    }
429                    1 => LifetimeIndex::Param(self.vars.fn_info.initial_bound_lifetime as _),
430                    _ => {
431                        self.vars.errors.push_err(syn_err!(
432                            span,
433                            "attempted to use an elided lifetime in the \
434                                 return type when there are multiple lifetimes used \
435                                 in parameters.",
436                        ));
437                        LifetimeIndex::NONE
438                    }
439                },
440            }
441        } else {
442            let lt = lt.expect("BUG");
443            let env_lts = self.vars.fn_info.env_lifetimes.iter();
444            let fn_lts = self.current.named_bound_lts.iter();
445            let found_lt = env_lts.chain(fn_lts).position(|ident| *ident == lt);
446            match found_lt {
447                Some(index) => {
448                    if let Some(index) = index.checked_sub(self.current.first_bound_lt) {
449                        self.current.bound_lt_spans[index].get_or_insert(span);
450                    }
451                    ret = Some(&ctokens.underscore);
452                    LifetimeIndex::Param(index as _)
453                }
454                None => {
455                    self.vars
456                        .errors
457                        .push_err(spanned_err!(lt, "unknown lifetime"));
458                    LifetimeIndex::NONE
459                }
460            }
461        }
462        .piped(|li| {
463            self.param_ret.lifetime_refs.push(li);
464            self.lifetime_counts.increment(li);
465        });
466        ret
467    }
468
469    /// Adds a bound lifetime to the `extern "C" fn()` and returns an index to it
470    fn new_bound_lifetime(&mut self, span: Span) -> LifetimeIndex {
471        let index = self.vars.fn_info.initial_bound_lifetime + self.current.bound_lts_count;
472        self.current.bound_lt_spans.push(Some(span));
473        self.current.bound_lts_count += 1;
474        LifetimeIndex::Param(index as _)
475    }
476}
477
478impl<'a, 'b> VisitMut for FnVisitor<'a, 'b> {
479    #[inline(never)]
480    fn visit_type_bare_fn_mut(&mut self, func: &mut TypeBareFn) {
481        self.vars.errors.push_err(syn_err!(
482            self.param_ret.span,
483            "\n\
484             This library does not currently support nested function pointers.\n\
485             To use the function pointer as a parameter define a wrapper type:\n\
486             \t#[derive(StableAbi)]\n\
487             \t#[repr(transparent)] \n\
488             \tpub struct CallbackParam{{   \n\
489             \t\tpub func:{func}\n\
490             \t}}\n\
491             \n\
492             ",
493            func = func.to_token_stream()
494        ))
495    }
496
497    /// Visits references inside the function pointer type,
498    /// uneliding their lifetime parameter,
499    /// and pushing the lifetime to the list of lifetime indices.
500    fn visit_type_reference_mut(&mut self, ref_: &mut TypeReference) {
501        let _ctokens = self.refs.ctokens;
502        let lt = ref_.lifetime.as_ref().map(|x| &x.ident);
503        if let Some(ident) = self.setup_lifetime(lt, ref_.and_token.span()).cloned() {
504            if let Some(lt) = &mut ref_.lifetime {
505                lt.ident = ident
506            }
507        }
508
509        // Visits the `Foo` type in a `&'a Foo`.
510        visit_mut::visit_type_mut(self, &mut ref_.elem)
511    }
512
513    /// Visits a lifetime inside the function pointer type,
514    /// and pushing the lifetime to the list of lifetime indices.
515    fn visit_lifetime_mut(&mut self, lt: &mut Lifetime) {
516        if let Some(ident) = self.setup_lifetime(Some(&lt.ident), lt.apostrophe.span()) {
517            lt.ident = ident.clone();
518        }
519    }
520
521    fn visit_type_macro_mut(&mut self, i: &mut syn::TypeMacro) {
522        if !self.vars.allow_type_macros {
523            push_type_macro_err(&mut self.vars.errors, i);
524        }
525    }
526}
527
528/////////////
529
530fn extract_fn_arg_name<'a>(
531    _index: usize,
532    arg: &mut syn::BareFnArg,
533    arenas: &'a Arenas,
534) -> Option<&'a Ident> {
535    match arg.name.take() {
536        Some((name, _)) => Some(arenas.alloc(name)),
537        None => None,
538    }
539}
540
541/////////////
542
543impl<'a> Function<'a> {
544    /// Turns lifetimes in the function parameters that aren't
545    /// used in the return type or used only once into LifeimeIndex::ANONYMOUS,
546    fn anonimize_lifetimes(
547        &mut self,
548        lifetime_counts: &LifetimeCounters,
549        errors: &mut Result<(), syn::Error>,
550    ) {
551        let first_bound_lt = self.first_bound_lt;
552
553        let mut current_lt = first_bound_lt;
554
555        let asigned_lts = (0..self.bound_lts_count)
556            .map(|i| {
557                let lt_i: usize = first_bound_lt + i;
558
559                if lifetime_counts.get(LifetimeIndex::Param(lt_i)) <= 1 {
560                    LifetimeIndex::ANONYMOUS
561                } else {
562                    if current_lt == LifetimeIndex::MAX_LIFETIME_PARAM + 1 {
563                        errors.push_err(syn_err!(
564                            self.bound_lt_spans[i].unwrap_or(*self.func_span),
565                            "Cannot have more than {} non-static lifetimes \
566                             (except for lifetimes only used once inside \
567                             function pointer types)",
568                            LifetimeIndex::MAX_LIFETIME_PARAM + 1
569                        ));
570                    }
571
572                    let ret = LifetimeIndex::Param(current_lt);
573                    current_lt += 1;
574                    ret
575                }
576            })
577            .collect::<Vec<LifetimeIndex>>();
578
579        for params in &mut self.params {
580            for p_lt in &mut params.lifetime_refs {
581                let param = match p_lt.to_param() {
582                    Some(param) => (param).wrapping_sub(first_bound_lt),
583                    None => continue,
584                };
585
586                if let Some(assigned) = asigned_lts.get(param) {
587                    *p_lt = *assigned;
588                }
589            }
590        }
591    }
592}
593
594fn push_type_macro_err(res: &mut Result<(), syn::Error>, i: &syn::TypeMacro) {
595    res.push_err(spanned_err!(
596        i,
597        "\
598Cannot currently use type macros safely.
599
600To enable use of type macros use the `#[sabi(unsafe_allow_type_macros)]` attribute.
601
602The reason this is unsafe to enable them is because StableAbi cannot currently 
603analize the lifetimes within macros,
604which means that if any lifetime argument inside the macro invocation changes
605it won't be checked by the runtime type checker.
606
607"
608    ));
609}