openrr_internal_codegen/
rpc.rs
1use std::path::Path;
4
5use anyhow::Result;
6use fs_err as fs;
7use heck::ToSnakeCase;
8use proc_macro2::TokenStream;
9use quote::{format_ident, quote};
10use syn::{
11 visit_mut::{self, VisitMut},
12 Ident, ItemTrait,
13};
14
15use super::*;
16
17pub(crate) fn gen(workspace_root: &Path) -> Result<()> {
18 const FULLY_IGNORE: &[&str] = &["SetCompleteCondition"];
19 const IGNORE: &[&str] = &["JointTrajectoryClient", "SetCompleteCondition", "Gamepad"];
20
21 let out_dir = &workspace_root.join("openrr-remote/src/gen");
22 fs::create_dir_all(out_dir)?;
23 let mut items = TokenStream::new();
24 let mut traits = vec![];
25
26 let mut pb_traits = vec![];
27 let pb_file = fs::read_to_string(workspace_root.join("openrr-remote/src/generated/arci.rs"))?;
28 CollectTrait(&mut pb_traits).visit_file_mut(&mut syn::parse_file(&pb_file)?);
29
30 let (arci_traits, _arci_structs, _arci_enums) = arci_types(workspace_root)?;
31 for item in arci_traits {
32 let name = &&*item.ident.to_string();
33 if FULLY_IGNORE.contains(name) {
34 continue;
35 }
36 traits.push(item.ident.clone());
37
38 let trait_name = &item.ident;
39 items.extend(gen_remote_types(trait_name));
40
41 if IGNORE.contains(name) {
42 continue;
43 }
44
45 items.extend(gen_client_impl(trait_name, &item));
46 items.extend(gen_server_impl(trait_name, &item, &pb_traits));
47 }
48
49 let items = quote! {
50 use arci::{
51 BaseVelocity,
52 Error,
53 Isometry2,
54 Isometry3,
55 Scan2D,
56 WaitFuture,
57 };
58 use super::*;
59 #items
60 };
61
62 write(&out_dir.join("impls.rs"), items)?;
63 Ok(())
64}
65
66fn gen_remote_types(trait_name: &Ident) -> TokenStream {
67 let client_name = format_ident!("Remote{trait_name}Sender");
68 let client_pb_ty = format_ident!("{trait_name}Client");
69 let client_pb_mod = format_ident!("{}_client", trait_name.to_string().to_snake_case());
70 let server_name = format_ident!("Remote{trait_name}Receiver");
71 let server_pb_ty = format_ident!("{trait_name}Server");
72 let server_pb_mod = format_ident!("{}_server", trait_name.to_string().to_snake_case());
73
74 quote! {
75 #[derive(Debug, Clone)]
76 pub struct #client_name {
77 pub(crate) client: pb::#client_pb_mod::#client_pb_ty<tonic::transport::Channel>,
78 }
79
80 impl #client_name {
81 pub async fn connect<D>(dst: D) -> Result<Self, arci::Error>
83 where
84 D: TryInto<tonic::transport::Endpoint>,
85 D::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
86 {
87 let client = pb::#client_pb_mod::#client_pb_ty::connect(dst)
88 .await
89 .map_err(|e| arci::Error::Connection {
90 message: e.to_string(),
91 })?;
92 Ok(Self { client })
93 }
94
95 pub fn new(channel: tonic::transport::Channel) -> Self {
97 Self {
98 client: pb::#client_pb_mod::#client_pb_ty::new(channel),
99 }
100 }
101 }
102
103 #[derive(Debug)]
104 pub struct #server_name<T> {
105 pub(crate) inner: T,
106 }
107
108 impl<T> #server_name<T>
109 where
110 T: arci::#trait_name + 'static,
111 {
112 pub fn new(inner: T) -> Self {
114 Self { inner }
115 }
116
117 pub fn into_service(self) -> pb::#server_pb_mod::#server_pb_ty<Self> {
119 pb::#server_pb_mod::#server_pb_ty::new(self)
120 }
121
122 pub async fn serve(self, addr: SocketAddr) -> Result<(), arci::Error> {
123 tonic::transport::Server::builder()
124 .add_service(self.into_service())
125 .serve(addr)
126 .await
127 .map_err(|e| arci::Error::Connection {
128 message: e.to_string(),
129 })?;
130 Ok(())
131 }
132 }
133 }
134}
135
136fn gen_client_impl(trait_name: &Ident, item: &ItemTrait) -> TokenStream {
137 let client_name = format_ident!("Remote{trait_name}Sender");
138
139 let methods = item.items.iter().map(|method| match method {
140 syn::TraitItem::Fn(method) => {
141 let sig = &method.sig;
142 let name = &sig.ident;
143 let args: Vec<_> = sig
144 .inputs
145 .iter()
146 .filter_map(|arg| match arg {
147 syn::FnArg::Receiver(_) => None,
148 syn::FnArg::Typed(arg) => {
149 let pat = &arg.pat;
150 Some(
151 if matches!(&*arg.ty, syn::Type::Reference(..)) && !is_str(&arg.ty) {
152 quote! { (*#pat) }
153 } else {
154 quote! { #pat }
155 },
156 )
157 }
158 })
159 .collect();
160 let args = match args.len() {
161 0 => quote! { () },
162 1 => quote! { #(#args)*.into() },
163 _ => quote! { (#(#args),*).into() },
164 };
165 let call = match &sig.output {
166 syn::ReturnType::Type(_, ty) => {
167 let path = get_ty_path(is_result(ty).unwrap());
168 if path.is_some_and(|p| p.segments.last().unwrap().ident == "WaitFuture") {
169 quote! {
170 Ok(wait_from_handle(tokio::spawn(async move {
171 client.#name(args).await
172 })))
173 }
174 } else {
175 quote! {
176 Ok(block_in_place(client.#name(args))
177 .map_err(|e| arci::Error::Other(e.into()))?
178 .into_inner()
179 .into())
180 }
181 }
182 }
183 syn::ReturnType::Default => unreachable!(),
184 };
185 quote! {
186 #sig {
187 let mut client = self.client.clone();
188 let args = tonic::Request::new(#args);
189 #call
190 }
191 }
192 }
193 _ => quote! {},
194 });
195
196 quote! {
197 impl arci::#trait_name for #client_name {
198 #(#methods)*
199 }
200 }
201}
202
203fn gen_server_impl(trait_name: &Ident, item: &ItemTrait, pb_traits: &[ItemTrait]) -> TokenStream {
204 const USE_TRY_INTO: &[&str] = &["SystemTime", "Duration"];
205
206 let server_name = format_ident!("Remote{trait_name}Receiver");
207 let server_pb_mod = format_ident!("{}_server", trait_name.to_string().to_snake_case());
208 let pb_trait = pb_traits.iter().find(|t| t.ident == *trait_name).unwrap();
209
210 let methods = item.items.iter().map(|method| match method {
211 syn::TraitItem::Fn(method) => {
212 struct ReplacePath;
213 impl VisitMut for ReplacePath {
214 fn visit_path_mut(&mut self, path: &mut syn::Path) {
215 if path.segments[0].ident == "super" {
216 path.segments[0].ident = format_ident!("pb");
217 }
218 visit_mut::visit_path_mut(self, path);
219 }
220 }
221
222 let name = &method.sig.ident;
223 let arg_len = method.sig.inputs.len() - 1;
224 let args: Vec<_> = method
225 .sig
226 .inputs
227 .iter()
228 .filter_map(|arg| match arg {
229 syn::FnArg::Receiver(_) => None,
230 syn::FnArg::Typed(arg) => {
231 let pat = &arg.pat;
232 let mut into = quote! { .into() };
233 if let Some(path) = get_ty_path(&arg.ty) {
234 if USE_TRY_INTO
235 .contains(&&*path.segments.last().unwrap().ident.to_string())
236 {
237 into = quote! { .try_into().unwrap() }
238 }
239 }
240 Some(match arg_len {
241 0 => unreachable!(),
242 1 => {
243 if is_str(&arg.ty) {
244 quote! { &request }
245 } else if matches!(&*arg.ty, syn::Type::Reference(..)) {
246 quote! { &request #into }
247 } else {
248 quote! { request #into }
249 }
250 }
251 _ => {
252 if is_str(&arg.ty) {
253 quote! { &request.#pat }
254 } else if matches!(&*arg.ty, syn::Type::Reference(..)) {
255 quote! { &request.#pat.unwrap()#into }
256 } else {
257 quote! { request.#pat.unwrap()#into }
258 }
259 }
260 })
261 }
262 })
263 .collect();
264 let mut pb_method = pb_trait
265 .items
266 .iter()
267 .find_map(|m| {
268 if let syn::TraitItem::Fn(m) = m {
269 if m.sig.ident == *name {
270 return Some(m.clone());
271 }
272 }
273 None
274 })
275 .unwrap();
276 ReplacePath.visit_signature_mut(&mut pb_method.sig);
277 let sig = &pb_method.sig;
278 let call = match &method.sig.output {
279 syn::ReturnType::Type(_, ty) => {
280 let path = get_ty_path(is_result(ty).unwrap());
281 if path.is_some_and(|p| p.segments.last().unwrap().ident == "WaitFuture") {
282 quote! {
283 let res = arci::#trait_name::#name(&self.inner, #(#args),*)
284 .map_err(|e| tonic::Status::unknown(e.to_string()))?
285 .await
286 .map_err(|e| tonic::Status::unknown(e.to_string()))?
287 .into();
288 }
289 } else {
290 quote! {
291 let res = arci::#trait_name::#name(&self.inner, #(#args),*)
292 .map_err(|e| tonic::Status::unknown(e.to_string()))?
293 .into();
294 }
295 }
296 }
297 syn::ReturnType::Default => unreachable!(),
298 };
299 quote! {
300 #sig {
301 let request = request.into_inner();
302 #call
303 Ok(tonic::Response::new(res))
304 }
305 }
306 }
307 _ => quote! {},
308 });
309
310 quote! {
311 #[tonic::async_trait]
312 impl<T> pb::#server_pb_mod::#trait_name for #server_name<T>
313 where
314 T: arci::#trait_name + 'static,
315 {
316 #(#methods)*
317 }
318 }
319}
320
321struct CollectTrait<'a>(&'a mut Vec<ItemTrait>);
322
323impl VisitMut for CollectTrait<'_> {
324 fn visit_item_trait_mut(&mut self, i: &mut ItemTrait) {
325 self.0.push(i.clone());
326 }
327}