abi_stable_derive/stable_abi/
tl_function.rs

1//! Contains types related to the type layout of function pointers.
2
3use super::*;
4
5use crate::{
6    composite_collections::SmallStartLen as StartLen, fn_pointer_extractor::Function,
7    fn_pointer_extractor::TypeVisitor, lifetimes::LifetimeRange,
8};
9
10use std::marker::PhantomData;
11
12use syn::Type;
13
14///////////////////////////////////////////////////////////////////////////////
15
16/// Associates extra information related to function pointers to a type declaration.
17#[allow(dead_code)]
18pub(crate) struct VisitedFieldMap<'a> {
19    pub(crate) map: Vec<VisitedField<'a>>,
20    pub(crate) fn_ptr_count: usize,
21    priv_: (),
22}
23
24impl<'a> VisitedFieldMap<'a> {
25    pub(crate) fn new(
26        ds: &'a DataStructure<'a>,
27        config: &'a StableAbiOptions<'a>,
28        shared_vars: &mut SharedVars<'a>,
29        ctokens: &'a CommonTokens<'a>,
30    ) -> Self {
31        let arenas = shared_vars.arenas();
32        let mut tv = TypeVisitor::new(arenas, ctokens.as_ref(), ds.generics);
33        if config.allow_type_macros {
34            tv.allow_type_macros();
35        }
36
37        let mut fn_ptr_count = 0;
38
39        let map = ds
40            .variants
41            .iter()
42            .flat_map(|x| &x.fields)
43            .map(|field| {
44                // The type used to get the TypeLayout of the field.
45                // This has all parameter and return types of function pointers removed.
46                // Extracted into the `functions` field of this struct.
47                let mut mutated_ty = config.changed_types[field].unwrap_or(field.ty).clone();
48                let layout_ctor = config.layout_ctor[field];
49                let is_opaque = layout_ctor.is_opaque();
50
51                let is_function = match mutated_ty {
52                    Type::BareFn { .. } => !is_opaque,
53                    _ => false,
54                };
55
56                let visit_info = tv.visit_field(&mut mutated_ty);
57
58                let mutated_ty = arenas.alloc(mutated_ty);
59
60                let field_accessor = config.override_field_accessor[field]
61                    .unwrap_or_else(|| config.kind.field_accessor(config.mod_refl_mode, field));
62
63                let name = config.renamed_fields[field].unwrap_or_else(|| field.pat_ident());
64
65                let comp_field = CompTLField::from_expanded(
66                    name,
67                    visit_info.referenced_lifetimes.iter().cloned(),
68                    field_accessor,
69                    shared_vars.push_type(layout_ctor, mutated_ty),
70                    is_function,
71                    shared_vars,
72                );
73
74                let iterated_functions = if is_opaque {
75                    Vec::new()
76                } else {
77                    visit_info.functions
78                };
79
80                let functions = iterated_functions
81                    .iter()
82                    .enumerate()
83                    .map(|(fn_i, func): (usize, &Function<'_>)| {
84                        let name_span = name.span();
85                        let name_start_len = if is_function || iterated_functions.len() == 1 {
86                            comp_field.name_start_len()
87                        } else {
88                            shared_vars.push_str(&format!("fn_{}", fn_i), Some(name_span))
89                        };
90
91                        shared_vars.combine_err(name_start_len.check_ident_length(name_span));
92
93                        let bound_lifetimes_start_len = shared_vars
94                            .extend_with_idents(",", func.named_bound_lts.iter().cloned());
95
96                        let params_iter = func.params.iter().map(|p| match p.name {
97                            Some(pname) => (pname as &dyn std::fmt::Display, pname.span()),
98                            None => (&"" as &dyn std::fmt::Display, Span::call_site()),
99                        });
100                        let param_names_len = shared_vars.extend_with_display(",", params_iter).len;
101
102                        let param_type_layouts =
103                            TypeLayoutRange::compress_params(&func.params, shared_vars);
104
105                        let paramret_lifetime_range = shared_vars.extend_with_lifetime_indices(
106                            func.params
107                                .iter()
108                                .chain(&func.returns)
109                                .flat_map(|p| p.lifetime_refs.iter().cloned()),
110                        );
111
112                        let return_type_layout = match &func.returns {
113                            Some(ret) => shared_vars.push_type(layout_ctor, ret.ty).to_u10(),
114                            None => !0,
115                        };
116
117                        CompTLFunction {
118                            name: name_start_len,
119                            contiguous_strings_offset: bound_lifetimes_start_len.start,
120                            bound_lifetimes_len: bound_lifetimes_start_len.len,
121                            param_names_len,
122                            param_type_layouts,
123                            paramret_lifetime_range,
124                            return_type_layout,
125                            is_unsafe: func.is_unsafe,
126                        }
127                    })
128                    .collect::<Vec<CompTLFunction>>();
129
130                fn_ptr_count += functions.len();
131
132                VisitedField {
133                    comp_field,
134                    layout_ctor,
135                    functions,
136                    _marker: PhantomData,
137                }
138            })
139            .collect::<Vec<VisitedField<'a>>>();
140
141        shared_vars.combine_err(tv.get_errors());
142
143        Self {
144            map,
145            fn_ptr_count,
146            priv_: (),
147        }
148    }
149}
150
151///////////////////////////////////////////////////////////////////////////////
152
153/// A `Field<'a>` with extra information.
154#[allow(dead_code)]
155pub struct VisitedField<'a> {
156    pub(crate) comp_field: CompTLField,
157    pub(crate) layout_ctor: LayoutConstructor,
158    /// The function pointers from this field.
159    pub(crate) functions: Vec<CompTLFunction>,
160    _marker: PhantomData<&'a ()>,
161}
162
163///////////////////////////////////////////////////////////////////////////////
164
165/// This is how a function pointer is stored,
166/// in which every field is a range into `TLFunctions`.
167#[derive(Copy, Clone, Debug, PartialEq, Eq, Ord, PartialOrd)]
168pub struct CompTLFunction {
169    name: StartLen,
170    contiguous_strings_offset: u16,
171    bound_lifetimes_len: u16,
172    param_names_len: u16,
173    /// Stores `!0` if the return type is `()`.
174    return_type_layout: u16,
175    paramret_lifetime_range: LifetimeRange,
176    param_type_layouts: TypeLayoutRange,
177    is_unsafe: bool,
178}
179
180impl ToTokens for CompTLFunction {
181    fn to_tokens(&self, ts: &mut TokenStream2) {
182        let name = self.name.to_u32();
183
184        let contiguous_strings_offset = self.contiguous_strings_offset;
185        let bound_lifetimes_len = self.bound_lifetimes_len;
186        let param_names_len = self.param_names_len;
187        let return_type_layout = self.return_type_layout;
188        let paramret_lifetime_range = self.paramret_lifetime_range.to_u21();
189        let param_type_layouts = self.param_type_layouts.to_u64();
190        let is_unsafe = if self.is_unsafe {
191            quote!( .set_unsafe() )
192        } else {
193            TokenStream2::new()
194        };
195
196        quote!(
197            __CompTLFunction::new(
198                #name,
199                #contiguous_strings_offset,
200                #bound_lifetimes_len,
201                #param_names_len,
202                #return_type_layout,
203                #paramret_lifetime_range,
204                #param_type_layouts,
205                __TLFunctionQualifiers::NEW
206                    #is_unsafe,
207            )
208        )
209        .to_tokens(ts);
210    }
211}