1use std::{ffi::OsStr, pin, process::Stdio, sync::Arc, task::Poll};23use anyhow::{anyhow, Result};4use better_command::{Handler, NixHandler, PlainHandler};5use futures::StreamExt;6use itertools::Either;7use openssh::{OverSsh, OwningCommand, Session};8use serde::de::DeserializeOwned;9use tokio::{io::AsyncRead, process::Command, select};10use tokio_util::codec::{BytesCodec, FramedRead, LinesCodec};11use tracing::debug;1213use crate::host::EscalationStrategy;1415fn escape_bash(input: &str, out: &mut String) {16 const TO_ESCAPE: &str = "$ !\"#&'()*,;<>?[\\]^`{|}";17 if input.chars().all(|c| !TO_ESCAPE.contains(c)) {18 out.push_str(input);19 return;20 }21 out.push('\'');22 for (i, v) in input.split('\'').enumerate() {23 if i != 0 {24 out.push_str("'\"'\"'");25 }26 out.push_str(v);27 }28 out.push('\'');29}30fn ostoutf8(os: impl AsRef<OsStr>) -> String {31 os.as_ref().to_str().expect("non-utf8 data").to_owned()32}3334#[derive(Clone, Debug)]35pub struct MyCommand {36 command: String,37 args: Vec<String>,38 env: Vec<(String, String)>,39 ssh_session: Option<Arc<Session>>,40 escalation: EscalationStrategy,41 escalate: bool,42}43impl MyCommand {44 pub fn new_on(45 escalation: EscalationStrategy,46 cmd: impl AsRef<OsStr>,47 session: Arc<Session>,48 ) -> Self {49 assert!(!cmd.as_ref().is_empty());50 Self {51 command: ostoutf8(cmd),52 args: vec![],53 env: vec![],54 ssh_session: Some(session),55 escalation,56 escalate: false,57 }58 }59 pub fn new(escalation: EscalationStrategy, cmd: impl AsRef<OsStr>) -> Self {60 assert!(!cmd.as_ref().is_empty());61 Self {62 command: ostoutf8(cmd),63 args: vec![],64 env: vec![],65 ssh_session: None,66 escalation,67 escalate: false,68 }69 }70 fn new_here(&self, cmd: impl AsRef<OsStr>) -> Self {71 if let Some(ssh_session) = self.ssh_session.clone() {72 Self::new_on(self.escalation, cmd, ssh_session)73 } else {74 Self::new(self.escalation, cmd)75 }76 }7778 fn into_args(self) -> Vec<String> {79 let mut out = Vec::new();80 if !self.env.is_empty() {81 out.push("env".to_owned());82 for (k, v) in self.env {83 assert!(!k.contains('='));84 out.push(format!("{k}={v}"));85 }86 }87 out.push(self.command);88 out.extend(self.args);89 out90 }9192 93 94 95 96 97 fn translate_env_into_env(self) -> Self {98 if self.env.is_empty() {99 return self;100 }101 let mut out = self.new_here("env");102 for (k, v) in self.env {103 assert!(!k.contains('='));104 out.arg(format!("{k}={v}"));105 }106 out.arg(self.command);107 out.args(self.args);108109 out110 }111 fn into_string(self) -> String {112 let mut out = String::new();113 if !self.env.is_empty() {114 out.push_str("env");115 for (k, v) in self.env {116 out.push(' ');117 assert!(!k.contains('='));118 escape_bash(&k, &mut out);119 out.push('=');120 escape_bash(&v, &mut out);121 }122 }123 if !out.is_empty() {124 out.push(' ');125 }126 escape_bash(&self.command, &mut out);127 for arg in self.args {128 out.push(' ');129 escape_bash(&arg, &mut out);130 }131 out132 }133 fn into_command_unchecked_local(self) -> Command {134 let mut out = Command::new(self.command);135 out.args(self.args);136 for (k, v) in self.env {137 out.env(k, v);138 }139 out140 }141 fn into_command(self) -> Result<Either<Command, openssh::OwningCommand<Arc<Session>>>> {142 Ok(if let Some(session) = self.ssh_session.clone() {143 let cmd = self.translate_env_into_env().into_command_unchecked_local();144 Either::Right(145 cmd.over_ssh(session)146 .map_err(|e| anyhow!("ssh error: {e}"))?,147 )148 } else {149 let cmd = self.into_command_unchecked_local();150 Either::Left(cmd)151 })152 }153 pub fn arg(&mut self, arg: impl AsRef<OsStr>) -> &mut Self {154 let arg = arg.as_ref();155 self.args.push(ostoutf8(arg));156 self157 }158 pub fn eqarg(&mut self, arg: impl AsRef<OsStr>, value: impl AsRef<OsStr>) -> &mut Self {159 let arg = arg.as_ref();160 let value = value.as_ref();161 let arg = ostoutf8(arg);162 let value = ostoutf8(value);163 self.arg(format!("{arg}={value}"));164 self165 }166 pub fn comparg(&mut self, arg: impl AsRef<OsStr>, value: impl AsRef<OsStr>) -> &mut Self {167 self.arg(arg);168 self.arg(value);169 self170 }171 pub fn env(&mut self, name: impl AsRef<str>, value: impl AsRef<str>) -> &mut Self {172 self.env173 .push((name.as_ref().to_owned(), value.as_ref().to_owned()));174 self175 }176 pub fn args<V: AsRef<OsStr>>(&mut self, args: impl IntoIterator<Item = V>) -> &mut Self {177 for arg in args.into_iter() {178 let arg = arg.as_ref();179 self.args.push(ostoutf8(arg));180 }181 self182 }183 pub fn sudo(mut self) -> Self {184 self.escalate = true;185 self186 }187 fn wrap_sudo_if_needed(self) -> Self {188 if !self.escalate {189 return self;190 }191 match self.escalation {192 EscalationStrategy::Su => {193 let mut out = self.new_here("su");194 out.arg("-c").arg(self.into_string());195 out196 }197 EscalationStrategy::Sudo => {198 let mut out = self.new_here("sudo");199 out.args(self.into_args());200 out201 }202 EscalationStrategy::Run0 => {203 204 let mut run0 = self.new_here("run0");205 let mut out = self.new_here("script");206207 208 run0.arg("--background=");209 run0.args(self.into_args());210211 out.arg("-q");212 out.arg("/dev/null");213 out.arg("-c");214 out.arg(run0.into_string());215 dbg!(&out);216 out217 }218 }219 }220221 pub async fn run(self) -> Result<()> {222 let str = self.clone().into_string();223 let cmd = self.wrap_sudo_if_needed().into_command()?;224 match cmd {225 Either::Left(cmd) => run_nix_inner(str, cmd, &mut PlainHandler).await?,226 Either::Right(cmd) => run_nix_inner_ssh(str, cmd, &mut PlainHandler).await?,227 };228 Ok(())229 }230 pub async fn run_string(self) -> Result<String> {231 let bytes = self.run_bytes().await?;232 Ok(String::from_utf8(bytes)?)233 }234 pub async fn run_value<T: DeserializeOwned>(self) -> Result<T> {235 let v = self.run_string().await?;236 Ok(serde_json::from_str(&v)?)237 }238 pub async fn run_bytes(self) -> Result<Vec<u8>> {239 let str = self.clone().into_string();240 let cmd = self.wrap_sudo_if_needed().into_command()?;241 let v = match cmd {242 Either::Left(cmd) => run_nix_inner_stdout(str, cmd, &mut PlainHandler).await?,243 Either::Right(cmd) => run_nix_inner_stdout_ssh(str, cmd, &mut PlainHandler).await?,244 };245 Ok(v)246 }247248 pub async fn run_nix_string(mut self) -> Result<String> {249 let str = self.clone().into_string();250 self.arg("--log-format").arg("internal-json");251 let cmd = self.wrap_sudo_if_needed().into_command()?;252 let bytes = match cmd {253 Either::Left(cmd) => run_nix_inner_stdout(str, cmd, &mut NixHandler::default()).await?,254 Either::Right(cmd) => {255 run_nix_inner_stdout_ssh(str, cmd, &mut NixHandler::default()).await?256 }257 };258 Ok(String::from_utf8(bytes)?)259 }260 pub async fn run_nix(mut self) -> Result<()> {261 let str = self.clone().into_string();262 self.arg("--log-format").arg("internal-json");263 let cmd = self.wrap_sudo_if_needed().into_command()?;264 match cmd {265 Either::Left(mut cmd) => {266 cmd.stdout(Stdio::inherit());267 run_nix_inner(str, cmd, &mut NixHandler::default()).await268 }269 Either::Right(mut cmd) => {270 cmd.stdout(openssh::Stdio::inherit());271 run_nix_inner_ssh(str, cmd, &mut NixHandler::default()).await272 }273 }274 }275}276277struct EmptyAsyncRead;278impl AsyncRead for EmptyAsyncRead {279 fn poll_read(280 self: std::pin::Pin<&mut Self>,281 _cx: &mut std::task::Context<'_>,282 _buf: &mut tokio::io::ReadBuf<'_>,283 ) -> Poll<std::io::Result<()>> {284 Poll::Pending285 }286}287288async fn run_nix_inner_stdout(289 str: String,290 cmd: Command,291 handler: &mut dyn Handler,292) -> Result<Vec<u8>> {293 Ok(run_nix_inner_raw(str, cmd, true, handler, None)294 .await?295 .expect("has out"))296}297async fn run_nix_inner(str: String, cmd: Command, handler: &mut dyn Handler) -> Result<()> {298 let v = run_nix_inner_raw(str, cmd, false, handler, None).await?;299 assert!(v.is_none());300 Ok(())301}302async fn run_nix_inner_stdout_ssh(303 str: String,304 cmd: OwningCommand<Arc<Session>>,305 handler: &mut dyn Handler,306) -> Result<Vec<u8>> {307 Ok(run_nix_inner_raw_ssh(str, cmd, true, handler, None)308 .await?309 .expect("has out"))310}311async fn run_nix_inner_ssh(312 str: String,313 cmd: OwningCommand<Arc<Session>>,314 handler: &mut dyn Handler,315) -> Result<()> {316 let v = run_nix_inner_raw_ssh(str, cmd, false, handler, None).await?;317 assert!(v.is_none());318 Ok(())319}320321async fn run_nix_inner_raw(322 str: String,323 mut cmd: Command,324 want_stdout: bool,325 err_handler: &mut dyn Handler,326 mut out_handler: Option<&mut dyn Handler>,327) -> Result<Option<Vec<u8>>> {328 cmd.stderr(Stdio::piped());329 cmd.stdout(Stdio::piped());330 debug!("running command {str:?} on local");331 let mut child = cmd.spawn()?;332 let mut stderr = child.stderr.take().unwrap();333 let stdout = child.stdout.take().unwrap();334 let mut err = FramedRead::new(&mut stderr, LinesCodec::new());335 let mut out: Option<Box<dyn AsyncRead + Unpin>> = Some(Box::new(stdout));336 let mut ob = want_stdout337 .then(|| out.take().unwrap())338 .unwrap_or_else(|| Box::new(EmptyAsyncRead));339 let mut ol = (!want_stdout)340 .then(|| out.take().unwrap())341 .unwrap_or_else(|| Box::new(EmptyAsyncRead));342 let mut ob = FramedRead::new(&mut ob, BytesCodec::new());343 let mut ol = FramedRead::new(&mut ol, LinesCodec::new());344345 346347 let mut out_buf = if want_stdout { Some(vec![]) } else { None };348 loop {349 select! {350 e = err.next() => {351 if let Some(e) = e {352 let e = e?;353 err_handler.handle_line(&e);354 }355 },356 o = ob.next() => {357 if let Some(o) = o {358 out_buf.as_mut().expect("stdout == wants_stdout").extend_from_slice(&o?);359 }360 },361 o = ol.next() => {362 if let Some(o) = o {363 let o = o?;364 if let Some(out) = out_handler.as_mut() {365 out.handle_line(&o)366 } else {367 err_handler.handle_line(&o)368 }369 370 }371 },372 code = child.wait() => {373 let code = code?;374 if !code.success() {375 anyhow::bail!("command '{str}' failed with status {}", code);376 }377 break;378 }379 }380 }381382 Ok(out_buf)383}384async fn run_nix_inner_raw_ssh(385 str: String,386 mut cmd: OwningCommand<Arc<Session>>,387 want_stdout: bool,388 err_handler: &mut dyn Handler,389 mut out_handler: Option<&mut dyn Handler>,390) -> Result<Option<Vec<u8>>> {391 debug!("running command {str:?} over ssh");392 cmd.stderr(openssh::Stdio::piped());393 cmd.stdout(openssh::Stdio::piped());394 let mut child = cmd.spawn().await?;395 let mut stderr = child.stderr().take().unwrap();396 let stdout = child.stdout().take().unwrap();397 let mut err = FramedRead::new(&mut stderr, LinesCodec::new());398 let mut out: Option<Box<dyn AsyncRead + Unpin>> = Some(Box::new(stdout));399 let mut ob = want_stdout400 .then(|| out.take().unwrap())401 .unwrap_or_else(|| Box::new(EmptyAsyncRead));402 let mut ol = (!want_stdout)403 .then(|| out.take().unwrap())404 .unwrap_or_else(|| Box::new(EmptyAsyncRead));405 let mut ob = FramedRead::new(&mut ob, BytesCodec::new());406 let mut ol = FramedRead::new(&mut ol, LinesCodec::new());407408 409410 let mut out_buf = if want_stdout { Some(vec![]) } else { None };411412 let mut wait_future = pin::pin!(child.wait());413 loop {414 select! {415 e = err.next() => {416 if let Some(e) = e {417 let e = e?;418 err_handler.handle_line(&e);419 }420 },421 o = ob.next() => {422 if let Some(o) = o {423 out_buf.as_mut().expect("stdout == wants_stdout").extend_from_slice(&o?);424 }425 },426 o = ol.next() => {427 if let Some(o) = o {428 let o = o?;429 if let Some(out) = out_handler.as_mut() {430 out.handle_line(&o)431 } else {432 err_handler.handle_line(&o)433 }434 435 }436 },437 code = &mut wait_future => {438 let code = code?;439 if !code.success() {440 anyhow::bail!("command '{str}' failed with status {}", code);441 }442 break;443 }444 }445 }446447 Ok(out_buf)448}