1use std::{ffi::OsStr, pin, process::Stdio, sync::Arc, task::Poll};23use anyhow::{anyhow, Result};4use better_command::{Handler, NixHandler, PlainHandler};5use futures::StreamExt;6use itertools::Either;7use openssh::{OverSsh, OwningCommand, Session};8use tokio::{io::AsyncRead, process::Command, select};9use tokio_util::codec::{BytesCodec, FramedRead, LinesCodec};10use tracing::debug;1112use crate::host::EscalationStrategy;1314fn escape_bash(input: &str, out: &mut String) {15 const TO_ESCAPE: &str = "$ !\"#&'()*,;<>?[\\]^`{|}";16 if input.chars().all(|c| !TO_ESCAPE.contains(c)) {17 out.push_str(input);18 return;19 }20 out.push('\'');21 for (i, v) in input.split('\'').enumerate() {22 if i != 0 {23 out.push_str("'\"'\"'");24 }25 out.push_str(v);26 }27 out.push('\'');28}29fn ostoutf8(os: impl AsRef<OsStr>) -> String {30 os.as_ref().to_str().expect("non-utf8 data").to_owned()31}3233#[derive(Clone, Debug)]34pub struct MyCommand {35 command: String,36 args: Vec<String>,37 env: Vec<(String, String)>,38 ssh_session: Option<Arc<Session>>,39 escalation: EscalationStrategy,40 escalate: bool,41}42impl MyCommand {43 pub fn new_on(44 escalation: EscalationStrategy,45 cmd: impl AsRef<OsStr>,46 session: Arc<Session>,47 ) -> Self {48 assert!(!cmd.as_ref().is_empty());49 Self {50 command: ostoutf8(cmd),51 args: vec![],52 env: vec![],53 ssh_session: Some(session),54 escalation,55 escalate: false,56 }57 }58 pub fn new(escalation: EscalationStrategy, cmd: impl AsRef<OsStr>) -> Self {59 assert!(!cmd.as_ref().is_empty());60 Self {61 command: ostoutf8(cmd),62 args: vec![],63 env: vec![],64 ssh_session: None,65 escalation,66 escalate: false,67 }68 }69 fn new_here(&self, cmd: impl AsRef<OsStr>) -> Self {70 if let Some(ssh_session) = self.ssh_session.clone() {71 Self::new_on(self.escalation, cmd, ssh_session)72 } else {73 Self::new(self.escalation, cmd)74 }75 }7677 fn into_args(self) -> Vec<String> {78 let mut out = Vec::new();79 if !self.env.is_empty() {80 out.push("env".to_owned());81 for (k, v) in self.env {82 assert!(!k.contains('='));83 out.push(format!("{k}={v}"));84 }85 }86 out.push(self.command);87 out.extend(self.args);88 out89 }9091 92 93 94 95 96 fn translate_env_into_env(self) -> Self {97 if self.env.is_empty() {98 return self;99 }100 let mut out = self.new_here("env");101 for (k, v) in self.env {102 assert!(!k.contains('='));103 out.arg(format!("{k}={v}"));104 }105 out.arg(self.command);106 out.args(self.args);107108 out109 }110 fn into_string(self) -> String {111 let mut out = String::new();112 if !self.env.is_empty() {113 out.push_str("env");114 for (k, v) in self.env {115 out.push(' ');116 assert!(!k.contains('='));117 escape_bash(&k, &mut out);118 out.push('=');119 escape_bash(&v, &mut out);120 }121 }122 if !out.is_empty() {123 out.push(' ');124 }125 escape_bash(&self.command, &mut out);126 for arg in self.args {127 out.push(' ');128 escape_bash(&arg, &mut out);129 }130 out131 }132 fn into_command_unchecked_local(self) -> Command {133 let mut out = Command::new(self.command);134 out.args(self.args);135 for (k, v) in self.env {136 out.env(k, v);137 }138 out139 }140 fn into_command(self) -> Result<Either<Command, openssh::OwningCommand<Arc<Session>>>> {141 Ok(if let Some(session) = self.ssh_session.clone() {142 let cmd = self.translate_env_into_env().into_command_unchecked_local();143 Either::Right(144 cmd.over_ssh(session)145 .map_err(|e| anyhow!("ssh error: {e}"))?,146 )147 } else {148 let cmd = self.into_command_unchecked_local();149 Either::Left(cmd)150 })151 }152 pub fn arg(&mut self, arg: impl AsRef<OsStr>) -> &mut Self {153 let arg = arg.as_ref();154 self.args.push(ostoutf8(arg));155 self156 }157 pub fn eqarg(&mut self, arg: impl AsRef<OsStr>, value: impl AsRef<OsStr>) -> &mut Self {158 let arg = arg.as_ref();159 let value = value.as_ref();160 let arg = ostoutf8(arg);161 let value = ostoutf8(value);162 self.arg(format!("{arg}={value}"));163 self164 }165 pub fn comparg(&mut self, arg: impl AsRef<OsStr>, value: impl AsRef<OsStr>) -> &mut Self {166 self.arg(arg);167 self.arg(value);168 self169 }170 pub fn env(&mut self, name: impl AsRef<str>, value: impl AsRef<str>) -> &mut Self {171 self.env172 .push((name.as_ref().to_owned(), value.as_ref().to_owned()));173 self174 }175 pub fn args<V: AsRef<OsStr>>(&mut self, args: impl IntoIterator<Item = V>) -> &mut Self {176 for arg in args.into_iter() {177 let arg = arg.as_ref();178 self.args.push(ostoutf8(arg));179 }180 self181 }182 pub fn sudo(mut self) -> Self {183 self.escalate = true;184 self185 }186 fn wrap_sudo_if_needed(self) -> Self {187 if !self.escalate {188 return self;189 }190 match self.escalation {191 EscalationStrategy::Su => {192 let mut out = self.new_here("su");193 out.arg("-c").arg(self.into_string());194 out195 }196 EscalationStrategy::Sudo => {197 let mut out = self.new_here("sudo");198 out.args(self.into_args());199 out200 }201 EscalationStrategy::Run0 => {202 203 let mut run0 = self.new_here("run0");204 let mut out = self.new_here("script");205206 207 run0.arg("--background=");208 run0.args(self.into_args());209210 out.arg("-q");211 out.arg("/dev/null");212 out.arg("-c");213 out.arg(run0.into_string());214 dbg!(&out);215 out216 }217 }218 }219220 pub async fn run(self) -> Result<()> {221 let str = self.clone().into_string();222 let cmd = self.wrap_sudo_if_needed().into_command()?;223 match cmd {224 Either::Left(cmd) => run_nix_inner(str, cmd, &mut PlainHandler).await?,225 Either::Right(cmd) => run_nix_inner_ssh(str, cmd, &mut PlainHandler).await?,226 };227 Ok(())228 }229 pub async fn run_string(self) -> Result<String> {230 let bytes = self.run_bytes().await?;231 Ok(String::from_utf8(bytes)?)232 }233 pub async fn run_bytes(self) -> Result<Vec<u8>> {234 let str = self.clone().into_string();235 let cmd = self.wrap_sudo_if_needed().into_command()?;236 let v = match cmd {237 Either::Left(cmd) => run_nix_inner_stdout(str, cmd, &mut PlainHandler).await?,238 Either::Right(cmd) => run_nix_inner_stdout_ssh(str, cmd, &mut PlainHandler).await?,239 };240 Ok(v)241 }242243 pub async fn run_nix_string(mut self) -> Result<String> {244 let str = self.clone().into_string();245 self.arg("--log-format").arg("internal-json");246 let cmd = self.wrap_sudo_if_needed().into_command()?;247 let bytes = match cmd {248 Either::Left(cmd) => run_nix_inner_stdout(str, cmd, &mut NixHandler::default()).await?,249 Either::Right(cmd) => {250 run_nix_inner_stdout_ssh(str, cmd, &mut NixHandler::default()).await?251 }252 };253 Ok(String::from_utf8(bytes)?)254 }255 pub async fn run_nix(mut self) -> Result<()> {256 let str = self.clone().into_string();257 self.arg("--log-format").arg("internal-json");258 let cmd = self.wrap_sudo_if_needed().into_command()?;259 match cmd {260 Either::Left(mut cmd) => {261 cmd.stdout(Stdio::inherit());262 run_nix_inner(str, cmd, &mut NixHandler::default()).await263 }264 Either::Right(mut cmd) => {265 cmd.stdout(openssh::Stdio::inherit());266 run_nix_inner_ssh(str, cmd, &mut NixHandler::default()).await267 }268 }269 }270}271272struct EmptyAsyncRead;273impl AsyncRead for EmptyAsyncRead {274 fn poll_read(275 self: std::pin::Pin<&mut Self>,276 _cx: &mut std::task::Context<'_>,277 _buf: &mut tokio::io::ReadBuf<'_>,278 ) -> Poll<std::io::Result<()>> {279 Poll::Pending280 }281}282283async fn run_nix_inner_stdout(284 str: String,285 cmd: Command,286 handler: &mut dyn Handler,287) -> Result<Vec<u8>> {288 Ok(run_nix_inner_raw(str, cmd, true, handler, None)289 .await?290 .expect("has out"))291}292async fn run_nix_inner(str: String, cmd: Command, handler: &mut dyn Handler) -> Result<()> {293 let v = run_nix_inner_raw(str, cmd, false, handler, None).await?;294 assert!(v.is_none());295 Ok(())296}297async fn run_nix_inner_stdout_ssh(298 str: String,299 cmd: OwningCommand<Arc<Session>>,300 handler: &mut dyn Handler,301) -> Result<Vec<u8>> {302 Ok(run_nix_inner_raw_ssh(str, cmd, true, handler, None)303 .await?304 .expect("has out"))305}306async fn run_nix_inner_ssh(307 str: String,308 cmd: OwningCommand<Arc<Session>>,309 handler: &mut dyn Handler,310) -> Result<()> {311 let v = run_nix_inner_raw_ssh(str, cmd, false, handler, None).await?;312 assert!(v.is_none());313 Ok(())314}315316async fn run_nix_inner_raw(317 str: String,318 mut cmd: Command,319 want_stdout: bool,320 err_handler: &mut dyn Handler,321 mut out_handler: Option<&mut dyn Handler>,322) -> Result<Option<Vec<u8>>> {323 cmd.stderr(Stdio::piped());324 cmd.stdout(Stdio::piped());325 debug!("running command {str:?} on local");326 let mut child = cmd.spawn()?;327 let mut stderr = child.stderr.take().unwrap();328 let stdout = child.stdout.take().unwrap();329 let mut err = FramedRead::new(&mut stderr, LinesCodec::new());330 let mut out: Option<Box<dyn AsyncRead + Unpin>> = Some(Box::new(stdout));331 let mut ob = want_stdout332 .then(|| out.take().unwrap())333 .unwrap_or_else(|| Box::new(EmptyAsyncRead));334 let mut ol = (!want_stdout)335 .then(|| out.take().unwrap())336 .unwrap_or_else(|| Box::new(EmptyAsyncRead));337 let mut ob = FramedRead::new(&mut ob, BytesCodec::new());338 let mut ol = FramedRead::new(&mut ol, LinesCodec::new());339340 341342 let mut out_buf = if want_stdout { Some(vec![]) } else { None };343 loop {344 select! {345 e = err.next() => {346 if let Some(e) = e {347 let e = e?;348 err_handler.handle_line(&e);349 }350 },351 o = ob.next() => {352 if let Some(o) = o {353 out_buf.as_mut().expect("stdout == wants_stdout").extend_from_slice(&o?);354 }355 },356 o = ol.next() => {357 if let Some(o) = o {358 let o = o?;359 if let Some(out) = out_handler.as_mut() {360 out.handle_line(&o)361 } else {362 err_handler.handle_line(&o)363 }364 365 }366 },367 code = child.wait() => {368 let code = code?;369 if !code.success() {370 anyhow::bail!("command '{str}' failed with status {}", code);371 }372 break;373 }374 }375 }376377 Ok(out_buf)378}379async fn run_nix_inner_raw_ssh(380 str: String,381 mut cmd: OwningCommand<Arc<Session>>,382 want_stdout: bool,383 err_handler: &mut dyn Handler,384 mut out_handler: Option<&mut dyn Handler>,385) -> Result<Option<Vec<u8>>> {386 debug!("running command {str:?} over ssh");387 cmd.stderr(openssh::Stdio::piped());388 cmd.stdout(openssh::Stdio::piped());389 let mut child = cmd.spawn().await?;390 let mut stderr = child.stderr().take().unwrap();391 let stdout = child.stdout().take().unwrap();392 let mut err = FramedRead::new(&mut stderr, LinesCodec::new());393 let mut out: Option<Box<dyn AsyncRead + Unpin>> = Some(Box::new(stdout));394 let mut ob = want_stdout395 .then(|| out.take().unwrap())396 .unwrap_or_else(|| Box::new(EmptyAsyncRead));397 let mut ol = (!want_stdout)398 .then(|| out.take().unwrap())399 .unwrap_or_else(|| Box::new(EmptyAsyncRead));400 let mut ob = FramedRead::new(&mut ob, BytesCodec::new());401 let mut ol = FramedRead::new(&mut ol, LinesCodec::new());402403 404405 let mut out_buf = if want_stdout { Some(vec![]) } else { None };406407 let mut wait_future = pin::pin!(child.wait());408 loop {409 select! {410 e = err.next() => {411 if let Some(e) = e {412 let e = e?;413 err_handler.handle_line(&e);414 }415 },416 o = ob.next() => {417 if let Some(o) = o {418 out_buf.as_mut().expect("stdout == wants_stdout").extend_from_slice(&o?);419 }420 },421 o = ol.next() => {422 if let Some(o) = o {423 let o = o?;424 if let Some(out) = out_handler.as_mut() {425 out.handle_line(&o)426 } else {427 err_handler.handle_line(&o)428 }429 430 }431 },432 code = &mut wait_future => {433 let code = code?;434 if !code.success() {435 anyhow::bail!("command '{str}' failed with status {}", code);436 }437 break;438 }439 }440 }441442 Ok(out_buf)443}