1use std::{iter, mem};
2
3use anyhow::{bail, Context, Result};
4use toml::{value, Value};
5use toml_query::{delete::TomlValueDeleteExt, insert::TomlValueInsertExt, read::TomlValueReadExt};
6use tracing::debug;
7
/// Characters that end one script and start the next. They only act as
/// separators outside of `[...]`, `{...}`, and quoted strings.
const SEPARATORS: &[char] = &['\n', ';'];
9
10pub fn overwrite(doc: &mut Value, scripts: &str) -> Result<()> {
39 let scripts = parse_scripts(scripts)?;
40
41 for script in scripts {
42 let query = &script.query;
43 let old = doc.read_mut(query)?;
44 let exists = old.is_some();
45 let is_structure = matches!(&old, Some(r) if r.is_table() || r.is_array());
46 match script.operation {
47 Operation::Set(value) => {
48 debug!(?query, ?value, "executing insert operation");
52 doc.insert(query, value)?;
53 }
54 Operation::Delete => {
55 if !exists {
56 debug!(
57 ?query,
58 "delete operation was not executed because value did not exist"
59 );
60 continue;
61 }
62
63 let old = old.unwrap();
64 if is_structure {
67 if old.is_array() {
68 *old = Value::Array(vec![]);
69 } else {
70 *old = Value::Table(value::Map::new());
71 }
72 }
73
74 debug!(?query, ?is_structure, ?old, "executing delete operation");
75 doc.delete(query)?;
76 }
77 }
78 }
79
80 Ok(())
81}
82
83pub fn overwrite_str(doc: &str, scripts: &str) -> Result<String> {
88 let mut doc: toml::Value = toml::from_str(doc)?;
89 overwrite(&mut doc, scripts)?;
90 Ok(toml::to_string(&doc)?)
91}
92
/// One parsed `query = value` assignment.
#[derive(Debug)]
struct Script {
    /// Target path, already rewritten by `convert_query` into the form handed
    /// to `toml_query` (e.g. `a[0][1].b` becomes `a.[0].[1].b`).
    query: String,
    /// What to do at `query`.
    operation: Operation,
}
98
/// The action a [`Script`] performs on the document.
#[derive(Debug)]
enum Operation {
    /// `query = value`: insert or replace the value at the query.
    Set(Value),
    /// `query =` with an empty value: remove the value at the query.
    Delete,
}
104
/// Copies one quoted TOML string (or quoted key) verbatim into `buf`.
///
/// The caller has already consumed the opening quote `start` (`"` or `'`) from
/// `chars`. It is pushed to `buf` first. Characters are then copied from
/// `chars` up to and including the matching closing quote.
///
/// Escapes follow TOML:
/// - In basic strings (`"`), a backslash escapes the character after it. Both
///   `\"` and `\\` are copied through without ending the string.
/// - In literal strings (`'`), a backslash is an ordinary character.
///
/// Returns `true` if the closing quote was found. Returns `false` if `chars`
/// ran out first; the partial literal is still left in `buf`.
fn parse_string_literal(
    buf: &mut String,
    start: char,
    chars: &mut iter::Peekable<impl Iterator<Item = (usize, char)>>,
) -> bool {
    buf.push(start);
    while let Some((_, ch)) = chars.next() {
        buf.push(ch);
        if ch == start {
            return true;
        }
        if ch == '\\' && start == '"' {
            // Always consume the escaped character, whatever it is. Otherwise
            // an escaped backslash (`\\`) would make the following closing
            // quote look escaped.
            if let Some((_, escaped)) = chars.next() {
                buf.push(escaped);
            }
        }
    }
    false
}
128
129fn parse_scripts(s: &str) -> Result<Vec<Script>> {
130 fn push_script(
131 cur_query: &mut Option<String>,
132 buf: &mut String,
133 scripts: &mut Vec<Script>,
134 i: usize,
135 ) -> Result<()> {
136 let query = cur_query.take().unwrap();
137 let value = mem::take(buf);
138 let value = value.trim();
139 let operation = if value.is_empty() {
140 Operation::Delete
141 } else {
142 let value: Value = toml::from_str(&format!(r#"a = {value}"#))
143 .with_context(|| format!("invalid script syntax at {}: {value}", i + 1))?;
144 Operation::Set(value["a"].clone())
145 };
146
147 scripts.push(Script {
148 query: convert_query(&query)?,
149 operation,
150 });
151 Ok(())
152 }
153
154 let mut scripts = vec![];
155
156 let mut chars = s.char_indices().peekable();
157 let mut cur_query = None;
158 let mut in_bracket = 0;
159 let mut in_brace = 0;
160 let mut buf = String::new();
161 while let Some((i, ch)) = chars.next() {
162 match ch {
163 '"' | '\'' => {
164 let end = parse_string_literal(&mut buf, ch, &mut chars);
165 if !end {
166 debug!(?buf, ?cur_query, "unexpected eof, expected `{ch}`");
167 bail!("unexpected eof, expected `{ch}`");
168 }
169 }
170 '[' => {
171 buf.push(ch);
172 in_bracket += 1;
173 }
174 ']' => {
175 buf.push(ch);
176 in_bracket -= 1;
177 }
178 '{' => {
179 buf.push(ch);
180 in_brace += 1;
181 }
182 '}' => {
183 buf.push(ch);
184 in_brace -= 1;
185 }
186 '=' if in_bracket <= 0 && in_brace <= 0 => {
187 if cur_query.is_some() {
188 debug!(?buf, ?i, "expected separator, found `=`");
189 bail!("expected separator, found `=`");
190 }
191 let query = mem::take(&mut buf);
192 cur_query = Some(query);
193 debug!(?cur_query);
194 }
195 _ if in_bracket <= 0 && in_brace <= 0 && SEPARATORS.contains(&ch) => {
196 if cur_query.is_none() {
197 if buf.trim().is_empty() {
198 buf.clear();
199 continue;
200 } else {
201 debug!(?buf, ?i, "expected `=`, found separator");
202 bail!("expected `=`, found separator");
203 }
204 }
205 push_script(&mut cur_query, &mut buf, &mut scripts, i)?;
206 }
207 _ if ch.is_whitespace() => {}
208 _ => {
209 buf.push(ch);
210 }
211 }
212 }
213 if cur_query.is_none() {
214 if !buf.trim().is_empty() {
215 debug!(?buf, ?cur_query, ?in_bracket, ?in_brace, "unexpected eof");
216 bail!("unexpected eof");
217 }
218 } else {
219 push_script(&mut cur_query, &mut buf, &mut scripts, s.len() - 1)?;
220 }
221
222 Ok(scripts)
223}
224
225fn convert_query(s: &str) -> Result<String> {
226 let mut out = String::with_capacity(s.len());
227 let mut chars = s.char_indices().peekable();
228 while let Some((_, ch)) = chars.next() {
229 match ch {
230 '"' | '\'' => {
231 let end = parse_string_literal(&mut out, ch, &mut chars);
232 assert!(end);
233 }
234 '[' => {
235 if !out.ends_with('.') {
236 out.push('.');
237 }
238 out.push(ch);
239 }
240 _ => out.push(ch),
241 }
242 }
243
244 Ok(out)
245}
246
#[cfg(test)]
mod tests {
    use std::{fs, path::PathBuf};

    use assert_approx_eq::assert_approx_eq;

    use super::*;

    #[test]
    fn test_parse_scripts() {
        let f = |s: &str| parse_scripts(s).unwrap();
        assert!(f("").is_empty());
        assert!(f("\n").is_empty());
        assert!(f(";").is_empty());
        assert!(f(";;").is_empty());

        assert!(parse_scripts("a").is_err());
        assert!(parse_scripts("a\n").is_err());
        assert!(parse_scripts("a;").is_err());
        assert!(parse_scripts(";a").is_err());
        assert!(parse_scripts("a=b").is_err());
        assert!(parse_scripts(r#"a="b"=0"#).is_err());
        assert!(parse_scripts(r#"a=""""#).is_err());

        assert!(matches!(
            &f(r#"a="b""#)[0],
            Script { query, operation: Operation::Set(Value::String(s)), .. }
            if query == "a" && s == "b"
        ));
        assert!(matches!(
            &f(r#"a.b="c=d""#)[0],
            Script { query, operation: Operation::Set(Value::String(s)), .. }
            if query == "a.b" && s == "c=d"
        ));
        assert!(matches!(
            &f(r#"a="""#)[0],
            Script { query, operation: Operation::Set(Value::String(s)), .. }
            if query == "a" && s.is_empty()
        ));
        assert!(matches!(
            &f("a=\"\\\"\"")[0],
            Script { query, operation: Operation::Set(Value::String(s)), .. }
            if query == "a" && s == "\""
        ));
        assert!(matches!(
            &f("a=")[0],
            Script { query, operation: Operation::Delete, .. }
            if query == "a"
        ));

        assert!(matches!(
            &f(r#"a[0]="""#)[0],
            Script { query, operation: Operation::Set(Value::String(s)), .. }
            if query == "a.[0]" && s.is_empty()
        ));
        assert!(matches!(
            &f(r#"a.[0]="""#)[0],
            Script { query, operation: Operation::Set(Value::String(s)), .. }
            if query == "a.[0]" && s.is_empty()
        ));
        assert!(matches!(
            &f(r#"a[0][1].b[2]="""#)[0],
            Script { query, operation: Operation::Set(Value::String(s)), .. }
            if query == "a.[0].[1].b.[2]" && s.is_empty()
        ));

        assert!(matches!(
            &f(r#"'a'=0"#)[0],
            Script { query, operation: Operation::Set(Value::Integer(n)), .. }
            if query == "'a'" && *n == 0
        ));
        assert!(matches!(
            &f(r#""a"=0"#)[0],
            Script { query, operation: Operation::Set(Value::Integer(n)), .. }
            if query == "\"a\"" && *n == 0
        ));
        assert!(matches!(
            &f(r#"'a=b'=0"#)[0],
            Script { query, operation: Operation::Set(Value::Integer(n)), .. }
            if query == "'a=b'" && *n == 0
        ));
        assert!(matches!(
            &f(r#""a=b"=0"#)[0],
            Script { query, operation: Operation::Set(Value::Integer(n)), .. }
            if query == "\"a=b\"" && *n == 0
        ));

        let r = &f("a=0\nb=1");
        assert_eq!(r.len(), 2);
        assert!(matches!(
            &r[0],
            Script { query, operation: Operation::Set(Value::Integer(n)), .. }
            if query == "a" && *n == 0
        ));
        assert!(matches!(
            &r[1],
            Script { query, operation: Operation::Set(Value::Integer(n)), .. }
            if query == "b" && *n == 1
        ));
        let r = &f("a=0;b=1");
        assert_eq!(r.len(), 2);
        assert!(matches!(
            &r[0],
            Script { query, operation: Operation::Set(Value::Integer(n)), .. }
            if query == "a" && *n == 0
        ));
        assert!(matches!(
            &r[1],
            Script { query, operation: Operation::Set(Value::Integer(n)), .. }
            if query == "b" && *n == 1
        ));
        assert!(matches!(
            &f("'a;b'='c;d'")[0],
            Script { query, operation: Operation::Set(Value::String(s)), .. }
            if query == "'a;b'" && s == "c;d"
        ));
        assert!(matches!(
            &f(r#""a;b"="c;d""#)[0],
            Script { query, operation: Operation::Set(Value::String(s)), .. }
            if query == r#""a;b""# && s == "c;d"
        ));
        assert!(matches!(
            &f("a=\"\"\"\nb\nc\n\"\"\"")[0],
            Script { query, operation: Operation::Set(Value::String(s)), .. }
            if query == "a" && s == "b\nc\n"
        ));
    }

    #[test]
    fn test_overwrite() {
        /// Reads `q` (script query syntax, converted with `convert_query`) from `v`.
        #[track_caller]
        fn read<'a>(v: &'a Value, q: &str) -> Option<&'a Value> {
            v.read(&convert_query(q).unwrap()).unwrap()
        }

        let mut root_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        root_dir.pop();
        let s = fs::read_to_string(
            root_dir.join("openrr-apps/config/sample_robot_client_config_for_urdf_viz.toml"),
        )
        .unwrap();
        let v: &Value = &toml::from_str(&s).unwrap();

        {
            let v = &mut v.clone();
            overwrite(v, r#"urdf_viz_clients_configs[0].name = "a""#).unwrap();
            assert_eq!(
                read(v, "urdf_viz_clients_configs[0].name")
                    .unwrap()
                    .as_str()
                    .unwrap(),
                "a"
            )
        }
        {
            let v = &mut v.clone();
            overwrite(v, r#"urdf_viz_clients_configs[0].joint_names[0] = "b""#).unwrap();
            assert_eq!(
                read(v, "urdf_viz_clients_configs[0].joint_names[0]")
                    .unwrap()
                    .as_str()
                    .unwrap(),
                "b"
            )
        }
        {
            let v = &mut v.clone();
            overwrite(
                v,
                r#"urdf_viz_clients_configs[0].joint_position_limits[0] = {}"#,
            )
            .unwrap();
            assert!(
                read(v, "urdf_viz_clients_configs[0].joint_position_limits[0]")
                    .unwrap()
                    .as_table()
                    .unwrap()
                    .is_empty(),
            )
        }
        {
            let v = &mut v.clone();
            overwrite(
                v,
                r#"urdf_viz_clients_configs[0].joint_position_limits[1] = { lower = -3.0, upper = 3.0 }"#,
            )
            .unwrap();
            let t = read(v, "urdf_viz_clients_configs[0].joint_position_limits[1]")
                .unwrap()
                .as_table()
                .unwrap();
            assert_approx_eq!(t["lower"].as_float().unwrap(), -3.0);
            assert_approx_eq!(t["upper"].as_float().unwrap(), 3.0);
        }
        {
            let v = &mut v.clone();
            overwrite(
                v,
                "urdf_viz_clients_configs[0].joint_position_limits[2].lower = 0.0",
            )
            .unwrap();
            assert_approx_eq!(
                read(
                    v,
                    "urdf_viz_clients_configs[0].joint_position_limits[2].lower"
                )
                .unwrap()
                .as_float()
                .unwrap(),
                0.0
            );
        }
        {
            let v = &mut v.clone();
            overwrite(v, "urdf_viz_clients_configs[0].joint_position_limits =").unwrap();
            assert!(read(v, "urdf_viz_clients_configs[0].joint_position_limits").is_none());
        }
        {
            let v = &mut v.clone();
            overwrite(
                v,
                "openrr_clients_config.ik_solvers_configs.arm_ik_solver =",
            )
            .unwrap();
            assert!(read(v, "openrr_clients_config.ik_solvers_configs.arm_ik_solver").is_none());
        }
        {
            let v = &mut v.clone();
            overwrite(v, "a.b.c =").unwrap();
        }
        {
            let v = &mut v.clone();
            overwrite(v, "urdf_viz_clients_configs[0].joint_names = [\n\"a\"\n]").unwrap();
            assert_eq!(
                *read(v, "urdf_viz_clients_configs[0].joint_names")
                    .unwrap()
                    .as_array()
                    .unwrap(),
                vec![Value::String("a".into())]
            )
        }
        {
            let v = &mut v.clone();
            overwrite(v, "urdf_viz_clients_configs[0].joint_names = [\n\"a\"]").unwrap();
            assert_eq!(
                *read(v, "urdf_viz_clients_configs[0].joint_names")
                    .unwrap()
                    .as_array()
                    .unwrap(),
                vec![Value::String("a".into())]
            )
        }
        {
            let v = &mut v.clone();
            overwrite(
                v,
                "urdf_viz_clients_configs[0].joint_names = [\n\"a\"]\ndummy=\"\"",
            )
            .unwrap();
            assert_eq!(
                *read(v, "urdf_viz_clients_configs[0].joint_names")
                    .unwrap()
                    .as_array()
                    .unwrap(),
                vec![Value::String("a".into())]
            )
        }
        {
            let v = &mut v.clone();
            overwrite(v, "a[0].b = 0").unwrap_err();
        }
        {
            let v = &mut v.clone();
            overwrite(
                v,
                "urdf_viz_clients_configs[0].wrap_with_joint_position_limiter =",
            )
            .unwrap();
            // Go through `read`, which applies `convert_query`. A raw
            // `Value::read` with the unconverted `a[0]` form does not address
            // the array element, so it would report `None` even if the delete
            // had done nothing.
            assert!(read(
                v,
                "urdf_viz_clients_configs[0].wrap_with_joint_position_limiter"
            )
            .is_none());
            overwrite(
                v,
                "urdf_viz_clients_configs[0].wrap_with_joint_position_limiter = false",
            )
            .unwrap();
            assert!(!read(
                v,
                "urdf_viz_clients_configs[0].wrap_with_joint_position_limiter"
            )
            .unwrap()
            .as_bool()
            .unwrap());
        }

        let s = r#"
[gil_gamepad_config.map]
axis_map = [
    ["DPadX", "DPadX"],
    ["LeftStickX", "LeftStickX"],
    ["RightStickX", "RightStickX"],
    ["RightStickY", "RightStickY"],
    ["DPadY", "DPadY"],
    ["LeftStickY", "LeftStickY"],
]
        "#;
        let v: &Value = &toml::from_str(s).unwrap();
        {
            let v = &mut v.clone();
            overwrite(v, r#"gil_gamepad_config.map.axis_map[0][0] = "DPadN""#).unwrap();
            assert_eq!(
                read(v, "gil_gamepad_config.map.axis_map[0][0]")
                    .unwrap()
                    .as_str()
                    .unwrap(),
                "DPadN"
            );
            assert_eq!(
                read(v, "gil_gamepad_config.map.axis_map[0][1]")
                    .unwrap()
                    .as_str()
                    .unwrap(),
                "DPadX"
            );
        }
        {
            let v = &mut v.clone();
            overwrite(
                v,
                "gil_gamepad_config.map.axis_map = [\n[\n\"DPadN\"\n]\n]\n",
            )
            .unwrap();
            assert_eq!(
                read(v, "gil_gamepad_config.map.axis_map[0][0]")
                    .unwrap()
                    .as_str()
                    .unwrap(),
                "DPadN"
            );
            assert!(read(v, "gil_gamepad_config.map.axis_map[0]")
                .unwrap()
                .as_array()
                .unwrap()
                .get(1)
                .is_none());
        }
    }
}