1use std::{ffi::OsStr, pin, process::Stdio, sync::Arc, task::Poll};23use anyhow::{Result, anyhow};4use better_command::{Handler, NixHandler, PlainHandler};5use futures::StreamExt;6use itertools::Either;7use openssh::{OverSsh, OwningCommand, Session};8use serde::de::DeserializeOwned;9use tokio::{io::AsyncRead, process::Command, select};10use tokio_util::codec::{BytesCodec, FramedRead, LinesCodec};11use tracing::debug;1213use crate::host::EscalationStrategy;1415fn escape_bash(input: &str, out: &mut String) {16 const TO_ESCAPE: &str = "$ !\"#&'()*,;<>?[\\]^`{|}";17 if input.chars().all(|c| !TO_ESCAPE.contains(c)) {18 out.push_str(input);19 return;20 }21 out.push('\'');22 for (i, v) in input.split('\'').enumerate() {23 if i != 0 {24 out.push_str("'\"'\"'");25 }26 out.push_str(v);27 }28 out.push('\'');29}30fn ostoutf8(os: impl AsRef<OsStr>) -> String {31 os.as_ref().to_str().expect("non-utf8 data").to_owned()32}3334#[derive(Clone, Debug)]35pub struct MyCommand {36 command: String,37 args: Vec<String>,38 env: Vec<(String, String)>,39 ssh_session: Option<Arc<Session>>,40 escalation: EscalationStrategy,41 escalate: bool,42}43impl MyCommand {44 pub fn new_on(45 escalation: EscalationStrategy,46 cmd: impl AsRef<OsStr>,47 session: Arc<Session>,48 ) -> Self {49 assert!(!cmd.as_ref().is_empty());50 Self {51 command: ostoutf8(cmd),52 args: vec![],53 env: vec![],54 ssh_session: Some(session),55 escalation,56 escalate: false,57 }58 }59 pub fn new(escalation: EscalationStrategy, cmd: impl AsRef<OsStr>) -> Self {60 assert!(!cmd.as_ref().is_empty());61 Self {62 command: ostoutf8(cmd),63 args: vec![],64 env: vec![],65 ssh_session: None,66 escalation,67 escalate: false,68 }69 }70 fn new_here(&self, cmd: impl AsRef<OsStr>) -> Self {71 match self.ssh_session.clone() {72 Some(ssh_session) => Self::new_on(self.escalation, cmd, ssh_session),73 _ => Self::new(self.escalation, cmd),74 }75 }7677 fn into_args(self) -> Vec<String> {78 let mut out = Vec::new();79 if !self.env.is_empty() {80 out.push("env".to_owned());81 for (k, v) in self.env {82 assert!(!k.contains('='));83 out.push(format!("{k}={v}"));84 }85 }86 out.push(self.command);87 out.extend(self.args);88 out89 }9091 92 93 94 95 96 fn translate_env_into_env(self) -> Self {97 if self.env.is_empty() {98 return self;99 }100 let mut out = self.new_here("env");101 for (k, v) in self.env {102 assert!(!k.contains('='));103 out.arg(format!("{k}={v}"));104 }105 out.arg(self.command);106 out.args(self.args);107108 out109 }110 fn into_string(self) -> String {111 let mut out = String::new();112 if !self.env.is_empty() {113 out.push_str("env");114 for (k, v) in self.env {115 out.push(' ');116 assert!(!k.contains('='));117 escape_bash(&k, &mut out);118 out.push('=');119 escape_bash(&v, &mut out);120 }121 }122 if !out.is_empty() {123 out.push(' ');124 }125 escape_bash(&self.command, &mut out);126 for arg in self.args {127 out.push(' ');128 escape_bash(&arg, &mut out);129 }130 out131 }132 fn into_command_unchecked_local(self) -> Command {133 let mut out = Command::new(self.command);134 out.args(self.args);135 for (k, v) in self.env {136 out.env(k, v);137 }138 out139 }140 fn into_command(self) -> Result<Either<Command, openssh::OwningCommand<Arc<Session>>>> {141 Ok(match self.ssh_session.clone() {142 Some(session) => {143 let cmd = self.translate_env_into_env().into_command_unchecked_local();144 Either::Right(145 cmd.over_ssh(session)146 .map_err(|e| anyhow!("ssh error: {e}"))?,147 )148 }149 _ => {150 let cmd = self.into_command_unchecked_local();151 Either::Left(cmd)152 }153 })154 }155 pub fn arg(&mut self, arg: impl AsRef<OsStr>) -> &mut Self {156 let arg = arg.as_ref();157 self.args.push(ostoutf8(arg));158 self159 }160 pub fn eqarg(&mut self, arg: impl AsRef<OsStr>, value: impl AsRef<OsStr>) -> &mut Self {161 let arg = arg.as_ref();162 let value = value.as_ref();163 let arg = ostoutf8(arg);164 let value = ostoutf8(value);165 self.arg(format!("{arg}={value}"));166 self167 }168 pub fn comparg(&mut self, arg: impl AsRef<OsStr>, value: impl AsRef<OsStr>) -> &mut Self {169 self.arg(arg);170 self.arg(value);171 self172 }173 pub fn env(&mut self, name: impl AsRef<str>, value: impl AsRef<str>) -> &mut Self {174 self.env175 .push((name.as_ref().to_owned(), value.as_ref().to_owned()));176 self177 }178 pub fn args<V: AsRef<OsStr>>(&mut self, args: impl IntoIterator<Item = V>) -> &mut Self {179 for arg in args.into_iter() {180 let arg = arg.as_ref();181 self.args.push(ostoutf8(arg));182 }183 self184 }185 pub fn sudo(mut self) -> Self {186 self.escalate = true;187 self188 }189 fn wrap_sudo_if_needed(self) -> Self {190 if !self.escalate {191 return self;192 }193 match self.escalation {194 EscalationStrategy::Su => {195 let mut out = self.new_here("su");196 out.arg("-c").arg(self.into_string());197 out198 }199 EscalationStrategy::Sudo => {200 let mut out = self.new_here("sudo");201 out.args(self.into_args());202 out203 }204 EscalationStrategy::Run0 => {205 206 let mut run0 = self.new_here("run0");207 let mut out = self.new_here("script");208209 210 run0.arg("--background=");211 run0.args(self.into_args());212213 out.arg("-q");214 out.arg("/dev/null");215 out.arg("-c");216 out.arg(run0.into_string());217 dbg!(&out);218 out219 }220 }221 }222223 pub async fn run(self) -> Result<()> {224 let str = self.clone().into_string();225 let cmd = self.wrap_sudo_if_needed().into_command()?;226 match cmd {227 Either::Left(cmd) => run_nix_inner(str, cmd, &mut PlainHandler).await?,228 Either::Right(cmd) => run_nix_inner_ssh(str, cmd, &mut PlainHandler).await?,229 };230 Ok(())231 }232 pub async fn run_string(self) -> Result<String> {233 let bytes = self.run_bytes().await?;234 Ok(String::from_utf8(bytes)?)235 }236 pub async fn run_value<T: DeserializeOwned>(self) -> Result<T> {237 let v = self.run_string().await?;238 Ok(serde_json::from_str(&v)?)239 }240 pub async fn run_bytes(self) -> Result<Vec<u8>> {241 let str = self.clone().into_string();242 let cmd = self.wrap_sudo_if_needed().into_command()?;243 let v = match cmd {244 Either::Left(cmd) => run_nix_inner_stdout(str, cmd, &mut PlainHandler).await?,245 Either::Right(cmd) => run_nix_inner_stdout_ssh(str, cmd, &mut PlainHandler).await?,246 };247 Ok(v)248 }249250 pub async fn run_nix_string(mut self) -> Result<String> {251 let str = self.clone().into_string();252 self.arg("--log-format").arg("internal-json");253 let cmd = self.wrap_sudo_if_needed().into_command()?;254 let bytes = match cmd {255 Either::Left(cmd) => run_nix_inner_stdout(str, cmd, &mut NixHandler::default()).await?,256 Either::Right(cmd) => {257 run_nix_inner_stdout_ssh(str, cmd, &mut NixHandler::default()).await?258 }259 };260 Ok(String::from_utf8(bytes)?)261 }262 pub async fn run_nix(mut self) -> Result<()> {263 let str = self.clone().into_string();264 self.arg("--log-format").arg("internal-json");265 let cmd = self.wrap_sudo_if_needed().into_command()?;266 match cmd {267 Either::Left(mut cmd) => {268 cmd.stdout(Stdio::inherit());269 run_nix_inner(str, cmd, &mut NixHandler::default()).await270 }271 Either::Right(mut cmd) => {272 cmd.stdout(openssh::Stdio::inherit());273 run_nix_inner_ssh(str, cmd, &mut NixHandler::default()).await274 }275 }276 }277}278279struct EmptyAsyncRead;280impl AsyncRead for EmptyAsyncRead {281 fn poll_read(282 self: std::pin::Pin<&mut Self>,283 _cx: &mut std::task::Context<'_>,284 _buf: &mut tokio::io::ReadBuf<'_>,285 ) -> Poll<std::io::Result<()>> {286 Poll::Pending287 }288}289290async fn run_nix_inner_stdout(291 str: String,292 cmd: Command,293 handler: &mut dyn Handler,294) -> Result<Vec<u8>> {295 Ok(run_nix_inner_raw(str, cmd, true, handler, None)296 .await?297 .expect("has out"))298}299async fn run_nix_inner(str: String, cmd: Command, handler: &mut dyn Handler) -> Result<()> {300 let v = run_nix_inner_raw(str, cmd, false, handler, None).await?;301 assert!(v.is_none());302 Ok(())303}304async fn run_nix_inner_stdout_ssh(305 str: String,306 cmd: OwningCommand<Arc<Session>>,307 handler: &mut dyn Handler,308) -> Result<Vec<u8>> {309 Ok(run_nix_inner_raw_ssh(str, cmd, true, handler, None)310 .await?311 .expect("has out"))312}313async fn run_nix_inner_ssh(314 str: String,315 cmd: OwningCommand<Arc<Session>>,316 handler: &mut dyn Handler,317) -> Result<()> {318 let v = run_nix_inner_raw_ssh(str, cmd, false, handler, None).await?;319 assert!(v.is_none());320 Ok(())321}322323async fn run_nix_inner_raw(324 str: String,325 mut cmd: Command,326 want_stdout: bool,327 err_handler: &mut dyn Handler,328 mut out_handler: Option<&mut dyn Handler>,329) -> Result<Option<Vec<u8>>> {330 cmd.stderr(Stdio::piped());331 cmd.stdout(Stdio::piped());332 debug!("running command {str:?} on local");333 let mut child = cmd.spawn()?;334 let mut stderr = child.stderr.take().unwrap();335 let stdout = child.stdout.take().unwrap();336 let mut err = FramedRead::new(&mut stderr, LinesCodec::new());337 let mut out: Option<Box<dyn AsyncRead + Unpin>> = Some(Box::new(stdout));338 let mut ob = want_stdout339 .then(|| out.take().unwrap())340 .unwrap_or_else(|| Box::new(EmptyAsyncRead));341 let mut ol = (!want_stdout)342 .then(|| out.take().unwrap())343 .unwrap_or_else(|| Box::new(EmptyAsyncRead));344 let mut ob = FramedRead::new(&mut ob, BytesCodec::new());345 let mut ol = FramedRead::new(&mut ol, LinesCodec::new());346347 348349 let mut out_buf = if want_stdout { Some(vec![]) } else { None };350 loop {351 select! {352 e = err.next() => {353 if let Some(e) = e {354 let e = e?;355 err_handler.handle_line(&e);356 }357 },358 o = ob.next() => {359 if let Some(o) = o {360 out_buf.as_mut().expect("stdout == wants_stdout").extend_from_slice(&o?);361 }362 },363 o = ol.next() => {364 if let Some(o) = o {365 let o = o?;366 if let Some(out) = out_handler.as_mut() {367 out.handle_line(&o)368 } else {369 err_handler.handle_line(&o)370 }371 372 }373 },374 code = child.wait() => {375 let code = code?;376 if !code.success() {377 anyhow::bail!("command '{str}' failed with status {}", code);378 }379 break;380 }381 }382 }383384 Ok(out_buf)385}386async fn run_nix_inner_raw_ssh(387 str: String,388 mut cmd: OwningCommand<Arc<Session>>,389 want_stdout: bool,390 err_handler: &mut dyn Handler,391 mut out_handler: Option<&mut dyn Handler>,392) -> Result<Option<Vec<u8>>> {393 debug!("running command {str:?} over ssh");394 cmd.stderr(openssh::Stdio::piped());395 cmd.stdout(openssh::Stdio::piped());396 let mut child = cmd.spawn().await?;397 let mut stderr = child.stderr().take().unwrap();398 let stdout = child.stdout().take().unwrap();399 let mut err = FramedRead::new(&mut stderr, LinesCodec::new());400 let mut out: Option<Box<dyn AsyncRead + Unpin>> = Some(Box::new(stdout));401 let mut ob = want_stdout402 .then(|| out.take().unwrap())403 .unwrap_or_else(|| Box::new(EmptyAsyncRead));404 let mut ol = (!want_stdout)405 .then(|| out.take().unwrap())406 .unwrap_or_else(|| Box::new(EmptyAsyncRead));407 let mut ob = FramedRead::new(&mut ob, BytesCodec::new());408 let mut ol = FramedRead::new(&mut ol, LinesCodec::new());409410 411412 let mut out_buf = if want_stdout { Some(vec![]) } else { None };413414 let mut wait_future = pin::pin!(child.wait());415 loop {416 select! {417 e = err.next() => {418 if let Some(e) = e {419 let e = e?;420 err_handler.handle_line(&e);421 }422 },423 o = ob.next() => {424 if let Some(o) = o {425 out_buf.as_mut().expect("stdout == wants_stdout").extend_from_slice(&o?);426 }427 },428 o = ol.next() => {429 if let Some(o) = o {430 let o = o?;431 if let Some(out) = out_handler.as_mut() {432 out.handle_line(&o)433 } else {434 err_handler.handle_line(&o)435 }436 437 }438 },439 code = &mut wait_future => {440 let code = code?;441 if !code.success() {442 anyhow::bail!("command '{str}' failed with status {}", code);443 }444 break;445 }446 }447 }448449 Ok(out_buf)450}