git.delta.rocks / remowt / refs/commits / 62007a6e971b

difftreelog

source

crates/remowt-plugin/src/host.rs3.2 KiBsourcehistory
1use std::ffi::OsStr;2use std::io;3use std::process::Stdio;4use std::sync::Mutex;56use bifrostlink::{Port, Rpc, Rtt, WeakRpc};7use bytes::{Bytes, BytesMut};8use tokio::io::{AsyncReadExt as _, AsyncWriteExt as _};9use tokio::process::{Child, ChildStdin, ChildStdout, Command};1011use remowt_link_shared::plugin::{Error, PluginEndpoints, PluginHost};12use remowt_link_shared::{Address, BifConfig};1314pub fn serve(rpc: &mut Rpc<BifConfig>) {15	let host = Host {16		rpc: rpc.clone().downgrade(),17		children: Mutex::new(Vec::new()),18	};19	PluginEndpoints(host).register_endpoints(rpc);20}2122struct Host {23	rpc: WeakRpc<BifConfig>,24	children: Mutex<Vec<Child>>,25}2627impl Host {28	fn spawn(&self, id: u16, path: impl AsRef<OsStr>) -> Result<(), Error> {29		let rpc = self.rpc.clone().upgrade().ok_or(Error::Gone)?;3031		let mut child = Command::new(path)32			.arg(id.to_string())33			.stdin(Stdio::piped())34			.stdout(Stdio::piped())35			.kill_on_drop(true)36			.spawn()37			.map_err(|e| Error::Spawn(e.to_string()))?;38		let stdin = child.stdin.take().expect("stdin piped");39		let stdout = child.stdout.take().expect("stdout piped");4041		rpc.add_direct(Address::Plugin(id), child_port(stdout, stdin), Rtt(0));42		self.children.lock().expect("not poisoned").push(child);43		Ok(())44	}45}4647impl PluginHost for Host {48	async fn load_plugin(&self, id: u16, name: String) -> Result<(), Error> {49		// TODO: Right now loads plugin next to the binary...50		// But with our CA addressed schema, the plugins should be located in content-addressed subdir...51		// Maybe it should just be scrapped in favor of load_plugin_path.52		if name.is_empty() || name == "." || name == ".." || name.contains(['/', '\0']) {53			return Err(Error::BadName);54		}55		let exe = std::env::current_exe().map_err(|e| Error::Spawn(e.to_string()))?;56		let dir = exe57			.parent()58			.ok_or_else(|| Error::Spawn("primary agent has no parent directory".to_owned()))?;59		self.spawn(id, dir.join(&name))60	}6162	async fn load_plugin_path(&self, id: u16, path: String) -> Result<(), Error> {63		if path.is_empty() || path.contains('\0') {64			return Err(Error::BadName);65		}66		self.spawn(id, path)67	}68}6970fn child_port(mut stdout: ChildStdout, mut stdin: ChildStdin) -> Port {71	Port::new(|mut rx, tx| async move {72		let reader = async move {73			loop {74				let len = match stdout.read_u32().await {75					Ok(len) => len,76					Err(e) => {77						tracing::error!("plugin stdout read failed: {e}");78						break;79					}80				};81				let mut buf = BytesMut::zeroed(len as usize);82				if let Err(e) = stdout.read_exact(&mut buf).await {83					tracing::error!("plugin stdout read failed: {e}");84					break;85				}86				if tx.send(buf.freeze()).is_err() {87					break;88				}89			}90		};91		let writer = async move {92			while let Some(msg) = rx.recv().await {93				if let Err(e) = write_frame(&mut stdin, msg).await {94					tracing::error!("plugin stdin write failed: {e}");95					break;96				}97			}98		};99		tokio::join!(reader, writer);100	})101}102103async fn write_frame(stdin: &mut ChildStdin, msg: Bytes) -> io::Result<()> {104	let len = u32::try_from(msg.len())105		.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "message larger than 4GB"))?;106	stdin.write_u32(len).await?;107	stdin.write_all(&msg).await?;108	stdin.flush().await?;109	Ok(())110}