openrr_internal_codegen/
rpc.rs

1//! Codegen for openrr-remote
2
3use std::path::Path;
4
5use anyhow::Result;
6use fs_err as fs;
7use heck::ToSnakeCase;
8use proc_macro2::TokenStream;
9use quote::{format_ident, quote};
10use syn::{
11    visit_mut::{self, VisitMut},
12    Ident, ItemTrait,
13};
14
15use super::*;
16
17pub(crate) fn gen(workspace_root: &Path) -> Result<()> {
18    const FULLY_IGNORE: &[&str] = &["SetCompleteCondition"];
19    const IGNORE: &[&str] = &["JointTrajectoryClient", "SetCompleteCondition", "Gamepad"];
20
21    let out_dir = &workspace_root.join("openrr-remote/src/gen");
22    fs::create_dir_all(out_dir)?;
23    let mut items = TokenStream::new();
24    let mut traits = vec![];
25
26    let mut pb_traits = vec![];
27    let pb_file = fs::read_to_string(workspace_root.join("openrr-remote/src/generated/arci.rs"))?;
28    CollectTrait(&mut pb_traits).visit_file_mut(&mut syn::parse_file(&pb_file)?);
29
30    let (arci_traits, _arci_structs, _arci_enums) = arci_types(workspace_root)?;
31    for item in arci_traits {
32        let name = &&*item.ident.to_string();
33        if FULLY_IGNORE.contains(name) {
34            continue;
35        }
36        traits.push(item.ident.clone());
37
38        let trait_name = &item.ident;
39        items.extend(gen_remote_types(trait_name));
40
41        if IGNORE.contains(name) {
42            continue;
43        }
44
45        items.extend(gen_client_impl(trait_name, &item));
46        items.extend(gen_server_impl(trait_name, &item, &pb_traits));
47    }
48
49    let items = quote! {
50        use arci::{
51            BaseVelocity,
52            Error,
53            Isometry2,
54            Isometry3,
55            Scan2D,
56            WaitFuture,
57        };
58        use super::*;
59        #items
60    };
61
62    write(&out_dir.join("impls.rs"), items)?;
63    Ok(())
64}
65
66fn gen_remote_types(trait_name: &Ident) -> TokenStream {
67    let client_name = format_ident!("Remote{trait_name}Sender");
68    let client_pb_ty = format_ident!("{trait_name}Client");
69    let client_pb_mod = format_ident!("{}_client", trait_name.to_string().to_snake_case());
70    let server_name = format_ident!("Remote{trait_name}Receiver");
71    let server_pb_ty = format_ident!("{trait_name}Server");
72    let server_pb_mod = format_ident!("{}_server", trait_name.to_string().to_snake_case());
73
74    quote! {
75        #[derive(Debug, Clone)]
76        pub struct #client_name {
77            pub(crate) client: pb::#client_pb_mod::#client_pb_ty<tonic::transport::Channel>,
78        }
79
80        impl #client_name {
81            /// Attempt to create a new sender by connecting to a given endpoint.
82            pub async fn connect<D>(dst: D) -> Result<Self, arci::Error>
83            where
84                D: TryInto<tonic::transport::Endpoint>,
85                D::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
86            {
87                let client = pb::#client_pb_mod::#client_pb_ty::connect(dst)
88                    .await
89                    .map_err(|e| arci::Error::Connection {
90                        message: e.to_string(),
91                    })?;
92                Ok(Self { client })
93            }
94
95            /// Create a new sender.
96            pub fn new(channel: tonic::transport::Channel) -> Self {
97                Self {
98                    client: pb::#client_pb_mod::#client_pb_ty::new(channel),
99                }
100            }
101        }
102
103        #[derive(Debug)]
104        pub struct #server_name<T> {
105            pub(crate) inner: T,
106        }
107
108        impl<T> #server_name<T>
109        where
110            T: arci::#trait_name + 'static,
111        {
112            /// Create a new receiver.
113            pub fn new(inner: T) -> Self {
114                Self { inner }
115            }
116
117            /// Convert this receiver into a tower service.
118            pub fn into_service(self) -> pb::#server_pb_mod::#server_pb_ty<Self> {
119                pb::#server_pb_mod::#server_pb_ty::new(self)
120            }
121
122            pub async fn serve(self, addr: SocketAddr) -> Result<(), arci::Error> {
123                tonic::transport::Server::builder()
124                    .add_service(self.into_service())
125                    .serve(addr)
126                    .await
127                    .map_err(|e| arci::Error::Connection {
128                        message: e.to_string(),
129                    })?;
130                Ok(())
131            }
132        }
133    }
134}
135
136fn gen_client_impl(trait_name: &Ident, item: &ItemTrait) -> TokenStream {
137    let client_name = format_ident!("Remote{trait_name}Sender");
138
139    let methods = item.items.iter().map(|method| match method {
140        syn::TraitItem::Fn(method) => {
141            let sig = &method.sig;
142            let name = &sig.ident;
143            let args: Vec<_> = sig
144                .inputs
145                .iter()
146                .filter_map(|arg| match arg {
147                    syn::FnArg::Receiver(_) => None,
148                    syn::FnArg::Typed(arg) => {
149                        let pat = &arg.pat;
150                        Some(
151                            if matches!(&*arg.ty, syn::Type::Reference(..)) && !is_str(&arg.ty) {
152                                quote! { (*#pat) }
153                            } else {
154                                quote! { #pat }
155                            },
156                        )
157                    }
158                })
159                .collect();
160            let args = match args.len() {
161                0 => quote! { () },
162                1 => quote! { #(#args)*.into() },
163                _ => quote! { (#(#args),*).into() },
164            };
165            let call = match &sig.output {
166                syn::ReturnType::Type(_, ty) => {
167                    let path = get_ty_path(is_result(ty).unwrap());
168                    if path.is_some_and(|p| p.segments.last().unwrap().ident == "WaitFuture") {
169                        quote! {
170                            Ok(wait_from_handle(tokio::spawn(async move {
171                                client.#name(args).await
172                            })))
173                        }
174                    } else {
175                        quote! {
176                            Ok(block_in_place(client.#name(args))
177                                .map_err(|e| arci::Error::Other(e.into()))?
178                                .into_inner()
179                                .into())
180                        }
181                    }
182                }
183                syn::ReturnType::Default => unreachable!(),
184            };
185            quote! {
186                #sig {
187                    let mut client = self.client.clone();
188                    let args = tonic::Request::new(#args);
189                    #call
190                }
191            }
192        }
193        _ => quote! {},
194    });
195
196    quote! {
197        impl arci::#trait_name for #client_name {
198            #(#methods)*
199        }
200    }
201}
202
203fn gen_server_impl(trait_name: &Ident, item: &ItemTrait, pb_traits: &[ItemTrait]) -> TokenStream {
204    const USE_TRY_INTO: &[&str] = &["SystemTime", "Duration"];
205
206    let server_name = format_ident!("Remote{trait_name}Receiver");
207    let server_pb_mod = format_ident!("{}_server", trait_name.to_string().to_snake_case());
208    let pb_trait = pb_traits.iter().find(|t| t.ident == *trait_name).unwrap();
209
210    let methods = item.items.iter().map(|method| match method {
211        syn::TraitItem::Fn(method) => {
212            struct ReplacePath;
213            impl VisitMut for ReplacePath {
214                fn visit_path_mut(&mut self, path: &mut syn::Path) {
215                    if path.segments[0].ident == "super" {
216                        path.segments[0].ident = format_ident!("pb");
217                    }
218                    visit_mut::visit_path_mut(self, path);
219                }
220            }
221
222            let name = &method.sig.ident;
223            let arg_len = method.sig.inputs.len() - 1;
224            let args: Vec<_> = method
225                .sig
226                .inputs
227                .iter()
228                .filter_map(|arg| match arg {
229                    syn::FnArg::Receiver(_) => None,
230                    syn::FnArg::Typed(arg) => {
231                        let pat = &arg.pat;
232                        let mut into = quote! { .into() };
233                        if let Some(path) = get_ty_path(&arg.ty) {
234                            if USE_TRY_INTO
235                                .contains(&&*path.segments.last().unwrap().ident.to_string())
236                            {
237                                into = quote! { .try_into().unwrap() }
238                            }
239                        }
240                        Some(match arg_len {
241                            0 => unreachable!(),
242                            1 => {
243                                if is_str(&arg.ty) {
244                                    quote! { &request }
245                                } else if matches!(&*arg.ty, syn::Type::Reference(..)) {
246                                    quote! { &request #into }
247                                } else {
248                                    quote! { request #into }
249                                }
250                            }
251                            _ => {
252                                if is_str(&arg.ty) {
253                                    quote! { &request.#pat }
254                                } else if matches!(&*arg.ty, syn::Type::Reference(..)) {
255                                    quote! { &request.#pat.unwrap()#into }
256                                } else {
257                                    quote! { request.#pat.unwrap()#into }
258                                }
259                            }
260                        })
261                    }
262                })
263                .collect();
264            let mut pb_method = pb_trait
265                .items
266                .iter()
267                .find_map(|m| {
268                    if let syn::TraitItem::Fn(m) = m {
269                        if m.sig.ident == *name {
270                            return Some(m.clone());
271                        }
272                    }
273                    None
274                })
275                .unwrap();
276            ReplacePath.visit_signature_mut(&mut pb_method.sig);
277            let sig = &pb_method.sig;
278            let call = match &method.sig.output {
279                syn::ReturnType::Type(_, ty) => {
280                    let path = get_ty_path(is_result(ty).unwrap());
281                    if path.is_some_and(|p| p.segments.last().unwrap().ident == "WaitFuture") {
282                        quote! {
283                            let res = arci::#trait_name::#name(&self.inner, #(#args),*)
284                                .map_err(|e| tonic::Status::unknown(e.to_string()))?
285                                .await
286                                .map_err(|e| tonic::Status::unknown(e.to_string()))?
287                                .into();
288                        }
289                    } else {
290                        quote! {
291                            let res = arci::#trait_name::#name(&self.inner, #(#args),*)
292                                .map_err(|e| tonic::Status::unknown(e.to_string()))?
293                                .into();
294                        }
295                    }
296                }
297                syn::ReturnType::Default => unreachable!(),
298            };
299            quote! {
300                #sig {
301                    let request = request.into_inner();
302                    #call
303                    Ok(tonic::Response::new(res))
304                }
305            }
306        }
307        _ => quote! {},
308    });
309
310    quote! {
311        #[tonic::async_trait]
312        impl<T> pb::#server_pb_mod::#trait_name for #server_name<T>
313        where
314            T: arci::#trait_name + 'static,
315        {
316            #(#methods)*
317        }
318    }
319}
320
321struct CollectTrait<'a>(&'a mut Vec<ItemTrait>);
322
323impl VisitMut for CollectTrait<'_> {
324    fn visit_item_trait_mut(&mut self, i: &mut ItemTrait) {
325        self.0.push(i.clone());
326    }
327}