abi_stable_derive/
export_root_module_impl.rs

1//! The implementation of the `#[export_root_module]` attribute.
2
3use super::*;
4
5use as_derive_utils::return_spanned_err;
6
7use syn::Ident;
8
9use proc_macro2::Span;
10
11use abi_stable_shared::mangled_root_module_loader_name;
12
13#[doc(hidden)]
14pub fn export_root_module_attr(_attr: TokenStream1, item: TokenStream1) -> TokenStream1 {
15    parse_or_compile_err(item, export_root_module_inner).into()
16}
17
/// Test helper that parses `item` as a function and expands it.
///
/// The string is parsed as an `ItemFn` and passed to
/// `export_root_module_inner`. A parse failure is returned as-is, and so is
/// any error produced by the expansion itself.
#[cfg(test)]
fn export_root_module_str(item: &str) -> Result<TokenStream2, syn::Error> {
    let func: ItemFn = syn::parse_str(item)?;
    export_root_module_inner(func)
}
22
23fn export_root_module_inner(mut input: ItemFn) -> Result<TokenStream2, syn::Error> {
24    let vis = &input.vis;
25
26    let unsafe_no_layout_constant_path =
27        syn::parse_str::<syn::Path>("unsafe_no_layout_constant").expect("BUG");
28
29    let mut found_unsafe_no_layout_constant = false;
30    input.attrs.retain(|attr| {
31        let is_it = attr.path == unsafe_no_layout_constant_path;
32        found_unsafe_no_layout_constant = found_unsafe_no_layout_constant || is_it;
33        !is_it
34    });
35    let check_ty_layout_variant = Ident::new(
36        if found_unsafe_no_layout_constant {
37            "No"
38        } else {
39            "Yes"
40        },
41        Span::call_site(),
42    );
43
44    let ret_ty = match &input.sig.output {
45        syn::ReturnType::Default => {
46            return_spanned_err!(input.sig.ident, "The return type can't be `()`")
47        }
48        syn::ReturnType::Type(_, ty) => ty,
49    };
50
51    let original_fn_ident = &input.sig.ident;
52
53    let export_name = Ident::new(&mangled_root_module_loader_name(), Span::call_site());
54
55    Ok(quote!(
56        #input
57
58        #[no_mangle]
59        #vis static #export_name: ::abi_stable::library::LibHeader = {
60
61            pub extern "C" fn _sabi_erased_module()-> ::abi_stable::library::RootModuleResult {
62                ::abi_stable::library::__call_root_module_loader(#original_fn_ident)
63            }
64
65            type __SABI_Module = <#ret_ty as ::abi_stable::library::IntoRootModuleResult>::Module;
66            unsafe{
67                ::abi_stable::library::LibHeader::from_constructor::<__SABI_Module>(
68                    _sabi_erased_module,
69                    ::abi_stable::library::CheckTypeLayout::#check_ty_layout_variant,
70                )
71            }
72        };
73    ))
74}
75
#[cfg(test)]
mod tests {
    use super::*;

    /// Expands `item` and returns the generated tokens with all whitespace
    /// removed, so assertions do not depend on token spacing.
    fn expand_without_whitespace(item: &str) -> String {
        export_root_module_str(item)
            .unwrap_or_else(|e| panic!("expansion of {:?} failed: {}", item, e))
            .to_string()
            .chars()
            .filter(|c| !c.is_whitespace())
            .collect()
    }

    #[test]
    fn test_output() {
        let list = vec![
            (
                r##"
                    pub fn hello()->RString{}
                "##,
                "CheckTypeLayout::Yes",
            ),
            (
                r##"
                    #[unsafe_no_layout_constant]
                    pub fn hello()->RString{}
                "##,
                "CheckTypeLayout::No",
            ),
            (
                r##"
                    #[hello]
                    #[unsafe_no_layout_constant]
                    pub fn hello()->RString{}
                "##,
                "CheckTypeLayout::No",
            ),
            (
                r##"
                    #[hello]
                    #[unsafe_no_layout_constant]
                    #[hello]
                    pub fn hello()->RString{}
                "##,
                "CheckTypeLayout::No",
            ),
            (
                r##"
                    #[unsafe_no_layout_constant]
                    #[hello]
                    pub fn hello()->RString{}
                "##,
                "CheckTypeLayout::No",
            ),
        ];

        for (item, expected_const) in list {
            let str_out = expand_without_whitespace(item);

            assert!(
                str_out.contains(expected_const),
                "expected {:?} in the expansion of {:?}, got:\n{}",
                expected_const,
                item,
                str_out,
            );

            // The marker attribute is consumed by the macro and must not be
            // re-emitted on the function.
            assert!(
                !str_out.contains("unsafe_no_layout_constant"),
                "`unsafe_no_layout_constant` leaked into the expansion of {:?}:\n{}",
                item,
                str_out,
            );

            // Every unrelated attribute must be preserved, once per occurrence.
            assert_eq!(
                str_out.matches("#[hello]").count(),
                item.matches("#[hello]").count(),
                "`#[hello]` attributes were not all preserved for {:?}:\n{}",
                item,
                str_out,
            );
        }
    }

    #[test]
    fn unit_return_type_is_rejected() {
        assert!(export_root_module_str("pub fn hello(){}").is_err());
    }
}
133}