1use std::{ffi::OsStr, pin, process::Stdio, sync::Arc, task::Poll};23use anyhow::{anyhow, Result};4use better_command::{Handler, NixHandler, PlainHandler};5use futures::StreamExt;6use itertools::Either;7use openssh::{OverSsh, OwningCommand, Session};8use tokio::{io::AsyncRead, process::Command, select};9use tokio_util::codec::{BytesCodec, FramedRead, LinesCodec};10use tracing::debug;1112fn escape_bash(input: &str, out: &mut String) {13 const TO_ESCAPE: &str = "$ !\"#&'()*,;<>?[\\]^`{|}";14 if input.chars().all(|c| !TO_ESCAPE.contains(c)) {15 out.push_str(input);16 return;17 }18 out.push('\'');19 for (i, v) in input.split('\'').enumerate() {20 if i != 0 {21 out.push_str("'\"'\"'");22 }23 out.push_str(v);24 }25 out.push('\'');26}27fn ostoutf8(os: impl AsRef<OsStr>) -> String {28 os.as_ref().to_str().expect("non-utf8 data").to_owned()29}30#[derive(Clone)]31pub struct MyCommand {32 command: String,33 args: Vec<String>,34 env: Vec<(String, String)>,35 ssh_session: Option<Arc<Session>>,36}37impl MyCommand {38 pub fn new_on(cmd: impl AsRef<OsStr>, session: Arc<Session>) -> Self {39 assert!(!cmd.as_ref().is_empty());40 Self {41 command: ostoutf8(cmd),42 args: vec![],43 env: vec![],44 ssh_session: Some(session),45 }46 }47 pub fn new(cmd: impl AsRef<OsStr>) -> Self {48 assert!(!cmd.as_ref().is_empty());49 Self {50 command: ostoutf8(cmd),51 args: vec![],52 env: vec![],53 ssh_session: None,54 }55 }56 fn into_args(self) -> Vec<String> {57 let mut out = Vec::new();58 if !self.env.is_empty() {59 out.push("env".to_owned());60 for (k, v) in self.env {61 assert!(!k.contains('='));62 out.push(format!("{k}={v}"));63 }64 }65 out.push(self.command);66 out.extend(self.args);67 out68 }6970 71 72 73 74 75 fn translate_env_into_env(self) -> Self {76 if self.env.is_empty() {77 return self;78 }79 let mut out = Self::new("env");80 out.ssh_session = self.ssh_session;81 for (k, v) in self.env {82 assert!(!k.contains('='));83 out.arg(format!("{k}={v}"));84 }85 out.arg(self.command);86 out.args(self.args);8788 out89 }90 fn into_string(self) -> String {91 let mut out = String::new();92 if !self.env.is_empty() {93 out.push_str("env");94 for (k, v) in self.env {95 out.push(' ');96 assert!(!k.contains('='));97 escape_bash(&k, &mut out);98 out.push('=');99 escape_bash(&v, &mut out);100 }101 }102 if !out.is_empty() {103 out.push(' ');104 }105 escape_bash(&self.command, &mut out);106 for arg in self.args {107 out.push(' ');108 escape_bash(&arg, &mut out);109 }110 out111 }112 fn into_command(self) -> Command {113 let mut out = Command::new(self.command);114 out.args(self.args);115 for (k, v) in self.env {116 out.env(k, v);117 }118 out119 }120 fn into_command_new(self) -> Result<Either<Command, openssh::OwningCommand<Arc<Session>>>> {121 Ok(if let Some(session) = self.ssh_session.clone() {122 let cmd = self.translate_env_into_env().into_command();123 Either::Right(124 cmd.over_ssh(session)125 .map_err(|e| anyhow!("ssh error: {e}"))?,126 )127 } else {128 let cmd = self.into_command();129 Either::Left(cmd)130 })131 }132 pub fn arg(&mut self, arg: impl AsRef<OsStr>) -> &mut Self {133 let arg = arg.as_ref();134 self.args.push(ostoutf8(arg));135 self136 }137 pub fn eqarg(&mut self, arg: impl AsRef<OsStr>, value: impl AsRef<OsStr>) -> &mut Self {138 let arg = arg.as_ref();139 let value = value.as_ref();140 let arg = ostoutf8(arg);141 let value = ostoutf8(value);142 self.arg(format!("{arg}={value}"));143 self144 }145 pub fn comparg(&mut self, arg: impl AsRef<OsStr>, value: impl AsRef<OsStr>) -> &mut Self {146 self.arg(arg);147 self.arg(value);148 self149 }150 pub fn env(&mut self, name: impl AsRef<str>, value: impl AsRef<str>) -> &mut Self {151 self.env152 .push((name.as_ref().to_owned(), value.as_ref().to_owned()));153 self154 }155 pub fn args<V: AsRef<OsStr>>(&mut self, args: impl IntoIterator<Item = V>) -> &mut Self {156 for arg in args.into_iter() {157 let arg = arg.as_ref();158 self.args.push(ostoutf8(arg));159 }160 self161 }162 pub fn sudo(mut self) -> Self {163 164 165 166 167 if std::env::var_os("NO_SUDO").is_some() {168 let mut out = Self::new("su");169 out.ssh_session = self.ssh_session.take();170 out.arg("-c").arg(self.into_string());171 out172 } else {173 let mut out = Self::new("sudo");174 out.ssh_session = self.ssh_session.take();175 out.args(self.into_args());176 out177 }178 }179180 pub async fn run(self) -> Result<()> {181 let str = self.clone().into_string();182 let cmd = self.into_command_new()?;183 match cmd {184 Either::Left(cmd) => run_nix_inner(str, cmd, &mut PlainHandler).await?,185 Either::Right(cmd) => run_nix_inner_ssh(str, cmd, &mut PlainHandler).await?,186 };187 Ok(())188 }189 pub async fn run_string(self) -> Result<String> {190 let bytes = self.run_bytes().await?;191 Ok(String::from_utf8(bytes)?)192 }193 pub async fn run_bytes(self) -> Result<Vec<u8>> {194 let str = self.clone().into_string();195 let cmd = self.into_command_new()?;196 let v = match cmd {197 Either::Left(cmd) => run_nix_inner_stdout(str, cmd, &mut PlainHandler).await?,198 Either::Right(cmd) => run_nix_inner_stdout_ssh(str, cmd, &mut PlainHandler).await?,199 };200 Ok(v)201 }202203 pub async fn run_nix_string(self) -> Result<String> {204 let str = self.clone().into_string();205 let mut cmd = self.into_command();206 cmd.arg("--log-format").arg("internal-json");207 let bytes = run_nix_inner_stdout(str, cmd, &mut NixHandler::default()).await?;208 Ok(String::from_utf8(bytes)?)209 }210 pub async fn run_nix(self) -> Result<()> {211 let str = self.clone().into_string();212 let mut cmd = self.into_command();213 cmd.arg("--log-format").arg("internal-json");214 cmd.stdout(Stdio::inherit());215 run_nix_inner(str, cmd, &mut NixHandler::default()).await216 }217}218219struct EmptyAsyncRead;220impl AsyncRead for EmptyAsyncRead {221 fn poll_read(222 self: std::pin::Pin<&mut Self>,223 _cx: &mut std::task::Context<'_>,224 _buf: &mut tokio::io::ReadBuf<'_>,225 ) -> Poll<std::io::Result<()>> {226 Poll::Pending227 }228}229230async fn run_nix_inner_stdout(231 str: String,232 cmd: Command,233 handler: &mut dyn Handler,234) -> Result<Vec<u8>> {235 Ok(run_nix_inner_raw(str, cmd, true, handler, None)236 .await?237 .expect("has out"))238}239async fn run_nix_inner(str: String, cmd: Command, handler: &mut dyn Handler) -> Result<()> {240 let v = run_nix_inner_raw(str, cmd, false, handler, None).await?;241 assert!(v.is_none());242 Ok(())243}244async fn run_nix_inner_stdout_ssh(245 str: String,246 cmd: OwningCommand<Arc<Session>>,247 handler: &mut dyn Handler,248) -> Result<Vec<u8>> {249 Ok(run_nix_inner_raw_ssh(str, cmd, true, handler, None)250 .await?251 .expect("has out"))252}253async fn run_nix_inner_ssh(254 str: String,255 cmd: OwningCommand<Arc<Session>>,256 handler: &mut dyn Handler,257) -> Result<()> {258 let v = run_nix_inner_raw_ssh(str, cmd, false, handler, None).await?;259 assert!(v.is_none());260 Ok(())261}262263async fn run_nix_inner_raw(264 str: String,265 mut cmd: Command,266 want_stdout: bool,267 err_handler: &mut dyn Handler,268 mut out_handler: Option<&mut dyn Handler>,269) -> Result<Option<Vec<u8>>> {270 cmd.stderr(Stdio::piped());271 cmd.stdout(Stdio::piped());272 debug!("running command {str:?} on local");273 let mut child = cmd.spawn()?;274 let mut stderr = child.stderr.take().unwrap();275 let stdout = child.stdout.take().unwrap();276 let mut err = FramedRead::new(&mut stderr, LinesCodec::new());277 let mut out: Option<Box<dyn AsyncRead + Unpin>> = Some(Box::new(stdout));278 let mut ob = want_stdout279 .then(|| out.take().unwrap())280 .unwrap_or_else(|| Box::new(EmptyAsyncRead));281 let mut ol = (!want_stdout)282 .then(|| out.take().unwrap())283 .unwrap_or_else(|| Box::new(EmptyAsyncRead));284 let mut ob = FramedRead::new(&mut ob, BytesCodec::new());285 let mut ol = FramedRead::new(&mut ol, LinesCodec::new());286287 288289 let mut out_buf = if want_stdout { Some(vec![]) } else { None };290 loop {291 select! {292 e = err.next() => {293 if let Some(e) = e {294 let e = e?;295 err_handler.handle_line(&e);296 }297 },298 o = ob.next() => {299 if let Some(o) = o {300 out_buf.as_mut().expect("stdout == wants_stdout").extend_from_slice(&o?);301 }302 },303 o = ol.next() => {304 if let Some(o) = o {305 let o = o?;306 if let Some(out) = out_handler.as_mut() {307 out.handle_line(&o)308 } else {309 err_handler.handle_line(&o)310 }311 312 }313 },314 code = child.wait() => {315 let code = code?;316 if !code.success() {317 anyhow::bail!("command '{str}' failed with status {}", code);318 }319 break;320 }321 }322 }323324 Ok(out_buf)325}326async fn run_nix_inner_raw_ssh(327 str: String,328 mut cmd: OwningCommand<Arc<Session>>,329 want_stdout: bool,330 err_handler: &mut dyn Handler,331 mut out_handler: Option<&mut dyn Handler>,332) -> Result<Option<Vec<u8>>> {333 debug!("running command {str:?} over ssh");334 cmd.stderr(openssh::Stdio::piped());335 cmd.stdout(openssh::Stdio::piped());336 let mut child = cmd.spawn().await?;337 let mut stderr = child.stderr().take().unwrap();338 let stdout = child.stdout().take().unwrap();339 let mut err = FramedRead::new(&mut stderr, LinesCodec::new());340 let mut out: Option<Box<dyn AsyncRead + Unpin>> = Some(Box::new(stdout));341 let mut ob = want_stdout342 .then(|| out.take().unwrap())343 .unwrap_or_else(|| Box::new(EmptyAsyncRead));344 let mut ol = (!want_stdout)345 .then(|| out.take().unwrap())346 .unwrap_or_else(|| Box::new(EmptyAsyncRead));347 let mut ob = FramedRead::new(&mut ob, BytesCodec::new());348 let mut ol = FramedRead::new(&mut ol, LinesCodec::new());349350 351352 let mut out_buf = if want_stdout { Some(vec![]) } else { None };353354 let mut wait_future = pin::pin!(child.wait());355 loop {356 select! {357 e = err.next() => {358 if let Some(e) = e {359 let e = e?;360 err_handler.handle_line(&e);361 }362 },363 o = ob.next() => {364 if let Some(o) = o {365 out_buf.as_mut().expect("stdout == wants_stdout").extend_from_slice(&o?);366 }367 },368 o = ol.next() => {369 if let Some(o) = o {370 let o = o?;371 if let Some(out) = out_handler.as_mut() {372 out.handle_line(&o)373 } else {374 err_handler.handle_line(&o)375 }376 377 }378 },379 code = &mut wait_future => {380 let code = code?;381 if !code.success() {382 anyhow::bail!("command '{str}' failed with status {}", code);383 }384 break;385 }386 }387 }388389 Ok(out_buf)390}