git.delta.rocks / remowt / refs/commits / 6d9cf16dada2

difftreelog

source

crates/remowt-endpoints/src/subprocess.rs6.9 KiBsourcehistory
1use std::collections::HashMap;2use std::io;3use std::process::Stdio;4use std::sync::atomic::{AtomicU64, Ordering};5use std::sync::{Arc, Mutex};67use bifrostlink::declarative::endpoints;8use bifrostlink::Config;9use camino::Utf8PathBuf;10use nix::sys::signal::{self, Signal};11use nix::unistd::Pid;12use serde::{Deserialize, Serialize};13use tokio::io::{AsyncReadExt as _, AsyncWriteExt as _};14use tokio::net::UnixStream;15use tokio::process::{ChildStderr, ChildStdout, Command};16use tokio::sync::{mpsc, watch};17use tracing::{debug, warn};1819pub type ProcId = u64;2021#[derive(Serialize, Deserialize, Debug)]22pub enum StdioSpec {23	Null,24	Socket(Utf8PathBuf),25}2627#[derive(Serialize, Deserialize, Debug)]28pub enum StderrSpec {29	Null,30	Socket(Utf8PathBuf),31	MergeWithStdout,32}3334#[derive(Serialize, Deserialize, Debug)]35pub struct SpawnSpec {36	pub program: String,37	pub args: Vec<String>,38	pub env: Vec<(String, String)>,39	pub env_clear: bool,40	pub cwd: Option<Utf8PathBuf>,41	pub stdin: StdioSpec,42	pub stdout: StdioSpec,43	pub stderr: StderrSpec,44}4546#[derive(Serialize, Deserialize, Debug, thiserror::Error)]47pub enum Error {48	#[error("spawn failed: {0}")]49	Spawn(String),50	#[error("connect to forwarded socket failed: {0}")]51	Connect(String),52	#[error("no process with that id")]53	NoSuchProcess,54	#[error("MergeWithStdout requires stdout=Socket")]55	BadMerge,56	#[error("invalid signal: {0}")]57	BadSignal(i32),58	#[error("kill failed: {0}")]59	Kill(String),60	#[error("io error: {0}")]61	Io(String),62}6364impl From<io::Error> for Error {65	fn from(e: io::Error) -> Self {66		Error::Io(e.to_string())67	}68}6970struct ChildState {71	pid: u32,72	exit_rx: watch::Receiver<Option<Option<i32>>>,73}7475#[derive(Clone, Default)]76pub struct Subprocess {77	children: Arc<Mutex<HashMap<ProcId, ChildState>>>,78	next_id: Arc<AtomicU64>,79}8081impl Subprocess {82	pub fn new() -> Self {83		Self::default()84	}85}8687#[endpoints(ns = 10)]88impl Subprocess {89	#[endpoints(id = 1)]90	async fn spawn(&self, spec: SpawnSpec) -> Result<ProcId, Error> {91		let SpawnSpec {92			program,93			args,94			env,95			env_clear,96			cwd,97			stdin,98			stdout,99			stderr,100		} = spec;101102		if matches!(stderr, StderrSpec::MergeWithStdout) && !matches!(stdout, StdioSpec::Socket(_))103		{104			return Err(Error::BadMerge);105		}106107		let mut cmd = Command::new(&program);108		cmd.args(&args);109		if env_clear {110			cmd.env_clear();111		}112		for (k, v) in &env {113			cmd.env(k, v);114		}115		if let Some(cwd) = &cwd {116			cmd.current_dir(cwd);117		}118		cmd.stdin(match &stdin {119			StdioSpec::Socket(_) => Stdio::piped(),120			StdioSpec::Null => Stdio::null(),121		});122		cmd.stdout(match &stdout {123			StdioSpec::Socket(_) => Stdio::piped(),124			StdioSpec::Null => Stdio::null(),125		});126		cmd.stderr(match &stderr {127			StderrSpec::Socket(_) | StderrSpec::MergeWithStdout => Stdio::piped(),128			StderrSpec::Null => Stdio::null(),129		});130		cmd.kill_on_drop(false);131132		let mut child = cmd.spawn().map_err(|e| Error::Spawn(e.to_string()))?;133		let pid = child134			.id()135			.ok_or_else(|| Error::Spawn("child exited before pid available".to_owned()))?;136137		if let StdioSpec::Socket(path) = &stdin {138			let sock = UnixStream::connect(path)139				.await140				.map_err(|e| Error::Connect(e.to_string()))?;141			let mut stdin_w = child.stdin.take().expect("piped");142			tokio::spawn(async move {143				let (mut sr, _) = tokio::io::split(sock);144				let _ = tokio::io::copy(&mut sr, &mut stdin_w).await;145				let _ = stdin_w.shutdown().await;146			});147		}148149		let stdout_handle = child.stdout.take();150		let stderr_handle = child.stderr.take();151152		match (&stdout, &stderr, stdout_handle, stderr_handle) {153			(StdioSpec::Socket(out_path), StderrSpec::MergeWithStdout, Some(out), Some(err)) => {154				let sock = UnixStream::connect(out_path)155					.await156					.map_err(|e| Error::Connect(e.to_string()))?;157				tokio::spawn(merge_to_sock(out, err, sock));158			}159			(StdioSpec::Socket(out_path), _, Some(out), err_opt) => {160				let sock = UnixStream::connect(out_path)161					.await162					.map_err(|e| Error::Connect(e.to_string()))?;163				tokio::spawn(pump_to_sock(out, sock));164				if let (StderrSpec::Socket(err_path), Some(err)) = (&stderr, err_opt) {165					let err_sock = UnixStream::connect(err_path)166						.await167						.map_err(|e| Error::Connect(e.to_string()))?;168					tokio::spawn(pump_to_sock(err, err_sock));169				}170			}171			(StdioSpec::Null, StderrSpec::Socket(err_path), _, Some(err)) => {172				let sock = UnixStream::connect(err_path)173					.await174					.map_err(|e| Error::Connect(e.to_string()))?;175				tokio::spawn(pump_to_sock(err, sock));176			}177			_ => {}178		}179180		let (exit_tx, exit_rx) = watch::channel(None);181		let id = self.next_id.fetch_add(1, Ordering::Relaxed);182		self.children183			.lock()184			.expect("not poisoned")185			.insert(id, ChildState { pid, exit_rx });186187		debug!(id, pid, program, "subprocess spawned");188		tokio::spawn(async move {189			let result = child.wait().await;190			let code = match result {191				Ok(status) => status.code(),192				Err(e) => {193					warn!(id, "child.wait failed: {e}");194					None195				}196			};197			let _ = exit_tx.send(Some(code));198		});199200		Ok(id)201	}202203	#[endpoints(id = 2)]204	async fn wait(&self, id: ProcId) -> Result<Option<i32>, Error> {205		let mut rx = {206			let map = self.children.lock().expect("not poisoned");207			let entry = map.get(&id).ok_or(Error::NoSuchProcess)?;208			entry.exit_rx.clone()209		};210		rx.wait_for(|v| v.is_some())211			.await212			.map_err(|_| Error::Io("exit channel closed".to_owned()))?;213		let code = rx.borrow().flatten();214		self.children.lock().expect("not poisoned").remove(&id);215		Ok(code)216	}217218	#[endpoints(id = 3)]219	async fn kill(&self, id: ProcId, signal: i32) -> Result<(), Error> {220		let pid = {221			let map = self.children.lock().expect("not poisoned");222			let entry = map.get(&id).ok_or(Error::NoSuchProcess)?;223			entry.pid224		};225		let sig = Signal::try_from(signal).map_err(|_| Error::BadSignal(signal))?;226		signal::kill(Pid::from_raw(pid as i32), sig).map_err(|e| Error::Kill(e.to_string()))?;227		Ok(())228	}229}230231async fn pump_to_sock<R>(mut from: R, sock: UnixStream)232where233	R: tokio::io::AsyncRead + Unpin,234{235	let (_, mut sw) = tokio::io::split(sock);236	let _ = tokio::io::copy(&mut from, &mut sw).await;237	let _ = sw.shutdown().await;238}239240async fn merge_to_sock(mut stdout: ChildStdout, mut stderr: ChildStderr, sock: UnixStream) {241	let (_, mut sw) = tokio::io::split(sock);242	let (tx, mut rx) = mpsc::channel::<Vec<u8>>(64);243	let tx_out = tx.clone();244	let out_pump = tokio::spawn(async move {245		let mut buf = vec![0u8; 4096];246		loop {247			match stdout.read(&mut buf).await {248				Ok(0) | Err(_) => break,249				Ok(n) => {250					if tx_out.send(buf[..n].to_vec()).await.is_err() {251						break;252					}253				}254			}255		}256	});257	let err_pump = tokio::spawn(async move {258		let mut buf = vec![0u8; 4096];259		loop {260			match stderr.read(&mut buf).await {261				Ok(0) | Err(_) => break,262				Ok(n) => {263					if tx.send(buf[..n].to_vec()).await.is_err() {264						break;265					}266				}267			}268		}269	});270	while let Some(chunk) = rx.recv().await {271		if sw.write_all(&chunk).await.is_err() {272			break;273		}274	}275	let _ = out_pump.await;276	let _ = err_pump.await;277	let _ = sw.shutdown().await;278}