--- a/cmds/fleet/src/better_nix_eval.rs +++ b/cmds/fleet/src/better_nix_eval.rs @@ -1,36 +1,32 @@ use std::ffi::{OsStr, OsString}; +use std::fmt::Display; use std::process::Stdio; -use std::sync::{Arc, Mutex, OnceLock}; +use std::sync::{Arc, OnceLock}; -use abort_on_drop::ChildTask; use anyhow::{anyhow, bail, ensure, Context, Result}; use futures::StreamExt; +use itertools::Itertools; use r2d2::{Pool, PooledConnection}; use serde::de::DeserializeOwned; use serde::Deserialize; use tokio::io::AsyncWriteExt; -use tokio::process::{ChildStdin, ChildStdout, Command}; -use tokio::sync::oneshot; +use tokio::process::{ChildStderr, ChildStdin, ChildStdout, Command}; +use tokio::select; +use tokio::sync::{mpsc, oneshot}; use tokio_util::codec::{FramedRead, LinesCodec}; -use tracing::debug; +use tracing::{debug, error, warn}; -use crate::command::{process_child_stderr, ErrorRecorder, ErrorRecorderT, NixHandler}; +use crate::command::{ClonableHandler, Handler, NixHandler, NoopHandler}; const REPL_DELIMITER: &str = "\"FLEET_MAGIC_REPL_DELIMITER\""; -// To synchronize stderr and stdout. It works, yet I hate this. -// There is no other way to catch errors, because they are coming from different streams, and they are not synchronized in tokio. -const ERROR_DELIMITER: &str = "FLEET_MAGIC_ERROR_DELIMITER"; pub struct NixSessionInner { full_delimiter: String, - #[allow(dead_code)] - stderr_handler: ChildTask>, - error_recorder: ErrorRecorderT, - read: FramedRead, + nix_handler: ClonableHandler, + out: OutputHandler, stdin: ChildStdin, string_wrapping: (String, String), number_wrapping: (String, String), - error_delimiter: String, next_id: u32, free_list: Vec, @@ -38,32 +34,21 @@ const TRAIN_STRING: &str = "\"TRAIN_STRING\""; const TRAIN_NUMBER: &str = "13141516"; -struct ErrorRecorderHandle { - handle: ErrorRecorderT, +#[must_use] +struct ErrorCollector<'i, H> { + collected: Vec, + inner: &'i mut H, } -impl ErrorRecorderHandle {} -impl Drop for ErrorRecorderHandle { - fn drop(&mut self) { - let mut recorded = self.handle.lock().unwrap(); - assert!(recorded.is_some(), "exclusive"); - *recorded = None; +impl<'i, H> ErrorCollector<'i, H> { + fn new(inner: &'i mut H) -> Self { + Self { + collected: vec![], + inner, + } } -} - -struct ErrorCollector { - collected: Arc>>, - delim: String, - got_delim: Option>, } -impl ErrorRecorder for ErrorCollector { - fn push_message(&mut self, msg: &str) -> bool { - if msg == self.delim { - let _ = self.got_delim - .take() - .expect("error delim is only expected once") - .send(()); - return true; - } +impl ErrorCollector<'_, H> { + fn handle_line_inner(&mut self, msg: &str) -> bool { let Some(msg) = msg.strip_prefix("@nix ") else { return false; }; @@ -79,11 +64,92 @@ if act.action != "msg" || act.level != 0 { return false; } - self.collected.lock().unwrap().push(act.msg); + self.collected.push(act.msg); true } + fn finish(self) -> Result<()> { + // fn dedent(s: String) -> String { + // s.split('\n').filter(|s| !s.trim().is_empty()).map(|v| v.) + // } + if !self.collected.is_empty() { + bail!("{}", self.collected.iter().map(|v| { + if let Some(f) = v.strip_prefix("\u{1b}[31;1merror:\u{1b}[0m ") { + let v = unindent::unindent(f.trim_start()); + v.trim().to_owned() + } else { + v.to_owned() + } + }).join("\n")); + } + Ok(()) + } + fn flush(self) { + for line in self.collected { + warn!("{line}"); + } + } } +impl Handler for ErrorCollector<'_, H> { + fn handle_line(&mut self, e: &str) { + if self.handle_line_inner(e) { + return; + } + self.inner.handle_line(e) + } +} +enum OutputLine { + Out(String), + Err(String), +} +struct OutputHandler { + rx: mpsc::Receiver, + _cancel_handle: oneshot::Receiver<()>, +} +impl OutputHandler { + fn new(out: ChildStdout, err: ChildStderr) -> Self { + let mut out = FramedRead::new(out, LinesCodec::new()); + let mut err = FramedRead::new(err, LinesCodec::new()); + let (tx, rx) = mpsc::channel(20); + let (mut cancelled, _cancel_handle) = oneshot::channel(); + tokio::spawn(async move { + loop { + select! { + // We should receive errors earlier than synchronization + biased; + e = err.next() => { + let Some(Ok(e)) = e else { + if e.is_some() { + error!("bad repl stderr: {e:?}"); + } + continue; + }; + let _ = tx.send(OutputLine::Err(e)).await; + } + o = out.next() => { + let Some(Ok(o)) = o else { + if o.is_some() { + error!("bad repl stdout: {o:?}"); + } + continue; + }; + let _ = tx.send(OutputLine::Out(o)).await; + } + // Reader doesn't care about stdout, as this is cancelled. + // Error still might be useful, to process leftover span closures? + _ = cancelled.closed() => { + break; + } + } + } + }); + Self { rx, _cancel_handle } + } + async fn next(&mut self) -> Option { + self.rx.recv().await + } +} + impl NixSessionInner { async fn new(flake: &OsStr, extra_args: impl IntoIterator) -> Result { let mut cmd = Command::new("nix"); @@ -100,21 +166,23 @@ let cmd = cmd.spawn()?; let stdout = cmd.stdout.unwrap(); let stderr = cmd.stderr.unwrap(); + let mut out = OutputHandler::new(stdout, stderr); let mut stdin = cmd.stdin.unwrap(); - let error_recorder = ErrorRecorderT::default(); - let err_recorder = error_recorder.clone(); - let stderr_handler = abort_on_drop::ChildTask::from(tokio::spawn(async move { - let mut handler = NixHandler::default(); - process_child_stderr(stderr, &mut handler, err_recorder).await - })); // Standard repl hello doesn't work with internal-json logger stdin.write_all(REPL_DELIMITER.as_bytes()).await?; stdin.write_all(b"\n").await?; stdin.flush().await?; - let mut read = FramedRead::new(stdout, LinesCodec::new()); + let nix_handler = NixHandler::default(); let mut full_delimiter = None; - while let Some(line) = read.next().await { - let line = line?; + while let Some(line) = out.next().await { + let line = match line { + OutputLine::Out(o) => o, + OutputLine::Err(_e) => { + // Handle startup errors, but skip repl hello? + //nix_handler.handle_line(&e); + continue; + } + }; if line.contains(REPL_DELIMITER) { debug!("discovered repl delimiter with added colors: {line}"); full_delimiter = Some(line.to_owned()); @@ -126,10 +194,8 @@ }; let mut res = Self { full_delimiter, - error_delimiter: "[[filled after training]]".to_owned(), - stderr_handler, - error_recorder, - read, + nix_handler: ClonableHandler::new(nix_handler), + out, stdin, string_wrapping: Default::default(), number_wrapping: Default::default(), @@ -142,62 +208,40 @@ } async fn train(&mut self) -> Result<()> { { - let full_string = self.execute_expression_raw(TRAIN_STRING).await?; + let full_string = self + .execute_expression_raw(TRAIN_STRING, &mut NoopHandler) + .await?; let string_offset = full_string.find(TRAIN_STRING).expect("contained"); let string_prefix = &full_string[..string_offset]; let string_suffix = &full_string[string_offset + TRAIN_STRING.len()..]; self.string_wrapping = (string_prefix.to_owned(), string_suffix.to_owned()); } { - let full_number = self.execute_expression_raw(TRAIN_NUMBER).await?; + let full_number = self + .execute_expression_raw(TRAIN_NUMBER, &mut NoopHandler) + .await?; let number_offset = full_number.find(TRAIN_NUMBER).expect("contained"); let number_prefix = &full_number[..number_offset]; let number_suffix = &full_number[number_offset + TRAIN_NUMBER.len()..]; self.number_wrapping = (number_prefix.to_owned(), number_suffix.to_owned()); } - { - struct TrainingErrorCollector(Option>); - impl ErrorRecorder for TrainingErrorCollector { - fn push_message(&mut self, msg: &str) -> bool { - if msg.contains(ERROR_DELIMITER) { - let _ = self - .0 - .take() - .expect("error delimiter is sent once") - .send(msg.to_owned()); - } - true - } - } - let (tx, rx) = oneshot::channel(); - let _handle = self.record_error(TrainingErrorCollector(Some(tx))); - self.send_command(ERROR_DELIMITER).await?; - self.send_command(REPL_DELIMITER).await?; - self.read_until_delimiter().await?; - let msg = rx.await?; - self.error_delimiter = msg; - } Ok(()) } - fn record_error(&mut self, v: impl ErrorRecorder + 'static) -> ErrorRecorderHandle { - { - let mut recorder = self.error_recorder.lock().unwrap(); - assert!(recorder.is_none(), "recorder is already started"); - *recorder = Some(Box::new(v)); - } - ErrorRecorderHandle { - handle: self.error_recorder.clone(), - } - } async fn send_command(&mut self, cmd: impl AsRef<[u8]>) -> Result<()> { self.stdin.write_all(cmd.as_ref()).await?; self.stdin.write_all(b"\n").await?; Ok(()) } - async fn read_until_delimiter(&mut self) -> Result { + async fn read_until_delimiter(&mut self, err_handler: &mut dyn Handler) -> Result { let mut out = String::new(); - while let Some(line) = self.read.next().await { - let line = line?; + while let Some(line) = self.out.next().await { + let line = match line { + OutputLine::Out(out) => out, + OutputLine::Err(err) => { + err_handler.handle_line(&err); + continue; + } + }; if line == self.full_delimiter { return Ok(out); } @@ -234,20 +278,15 @@ expr: impl AsRef<[u8]>, wrapping: &(String, String), ) -> Result { - let collected = Arc::new(Mutex::new(vec![])); - let (etx, erx) = oneshot::channel(); - let _collector = self.record_error(ErrorCollector{collected:collected.clone(), delim: self.error_delimiter.clone(), got_delim: Some(etx)}); - let res = self.execute_expression_raw(expr).await?; - let _ = self.execute_expression_raw(ERROR_DELIMITER).await?; - let _ = erx.await; + let mut nix_handler = self.nix_handler.clone(); + let mut collected = ErrorCollector::new(&mut nix_handler); + let res = self.execute_expression_raw(expr, &mut collected).await?; if res.is_empty() { - let c = collected.lock().unwrap(); - if c.is_empty() { - bail!("expected expression, got nothing") - } - bail!("{}", c.join("\n")); - } - drop(_collector); + collected.finish()?; + bail!("expected expression, got nothing") + } else { + collected.flush() + }; let Some(res) = res.strip_prefix(&wrapping.0) else { bail!("invalid type") }; @@ -257,25 +296,22 @@ Ok(res.to_owned()) } async fn execute_expression_empty(&mut self, expr: impl AsRef<[u8]>) -> Result<()> { - let collected = Arc::new(Mutex::new(vec![])); - let (etx, erx) = oneshot::channel(); - let _collector = self.record_error(ErrorCollector{collected:collected.clone(), delim: self.error_delimiter.clone(), got_delim: Some(etx)}); - let v = self.execute_expression_raw(expr).await?; - let _ = self.execute_expression_raw(ERROR_DELIMITER).await; - let _ = erx.await; - - let c = collected.lock().unwrap(); - if !c.is_empty() { - bail!("{}", c.join("\n")); - } + let mut nix_handler = self.nix_handler.clone(); + let mut collected = ErrorCollector::new(&mut nix_handler); + let v = self.execute_expression_raw(expr, &mut collected).await?; + collected.finish()?; ensure!(v.is_empty(), "unexpected expression result"); Ok(()) } - async fn execute_expression_raw(&mut self, expr: impl AsRef<[u8]>) -> Result { + async fn execute_expression_raw( + &mut self, + expr: impl AsRef<[u8]>, + err_handler: &mut dyn Handler, + ) -> Result { self.send_command(expr).await?; // It will be echoed self.send_command(REPL_DELIMITER).await?; - self.read_until_delimiter().await + self.read_until_delimiter(err_handler).await } async fn execute_assign(&mut self, expr: impl AsRef) -> Result { let id = self.allocate_id(); @@ -306,11 +342,31 @@ #[derive(Clone)] pub struct NixSession(Arc>>); -#[derive(Clone, Debug)] +#[derive(Clone)] enum Index { String(String), // Idx(u32), } +impl Display for Index { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Index::String(k) => { + let v = nixlike::format_identifier(k.as_str()); + write!(f, ".{v}") + } + } + } +} +struct PathDisplay<'i>(&'i [Index]); +impl Display for PathDisplay<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "flake")?; + for i in self.0 { + write!(f, "{i}")?; + } + Ok(()) + } +} pub struct Field { full_path: Vec, session: NixSession, @@ -327,6 +383,13 @@ pub async fn field(session: NixSession, field: &str) -> Result { Self::root(session).get_field_deep([field]).await } + pub async fn get_json_deep<'a, V: DeserializeOwned>( + &self, + name: impl IntoIterator, + ) -> Result { + let field = self.get_field_deep(name).await?; + field.as_json().await + } pub async fn get_field(&self, name: &str) -> Result { self.get_field_deep([name]).await } @@ -364,7 +427,7 @@ .await .execute_assign(&query) .await - .with_context(|| format!("full path: {:?}", full_path))?; + .with_context(|| format!("full path: {}", PathDisplay(&full_path)))?; Ok(Self { full_path, session: self.session.clone(), @@ -379,7 +442,7 @@ .await .execute_expression_to_json(&format!("sess_field_{id}")) .await - .with_context(|| format!("full path: {:?}", self.full_path)) + .with_context(|| format!("full path: {}", PathDisplay(&self.full_path))) } pub async fn list_fields(&self) -> Result> { let id = self.value.expect("can't list root fields"); @@ -389,7 +452,7 @@ .await .execute_expression_to_json(&format!("builtins.attrNames sess_field_{id}")) .await - .with_context(|| format!("full path: {:?}", self.full_path)) + .with_context(|| format!("full path: {}", PathDisplay(&self.full_path))) } } impl Drop for Field {