difftreelog
fix handle nix repl errors without extra synchronization
in: trunk
1 file changed
cmds/fleet/src/better_nix_eval.rsdiffbeforeafterboth1use std::ffi::{OsStr, OsString};1use std::ffi::{OsStr, OsString};2use std::fmt::Display;2use std::process::Stdio;3use std::process::Stdio;3use std::sync::{Arc, Mutex, OnceLock};4use std::sync::{Arc, OnceLock};455use abort_on_drop::ChildTask;6use anyhow::{anyhow, bail, ensure, Context, Result};6use anyhow::{anyhow, bail, ensure, Context, Result};7use futures::StreamExt;7use futures::StreamExt;8use itertools::Itertools;8use r2d2::{Pool, PooledConnection};9use r2d2::{Pool, PooledConnection};9use serde::de::DeserializeOwned;10use serde::de::DeserializeOwned;10use serde::Deserialize;11use serde::Deserialize;11use tokio::io::AsyncWriteExt;12use tokio::io::AsyncWriteExt;12use tokio::process::{ChildStdin, ChildStdout, Command};13use tokio::process::{ChildStderr, ChildStdin, ChildStdout, Command};13use tokio::sync::oneshot;14use tokio::select;15use tokio::sync::{mpsc, oneshot};14use tokio_util::codec::{FramedRead, LinesCodec};16use tokio_util::codec::{FramedRead, LinesCodec};15use tracing::debug;17use tracing::{debug, error, warn};161817use crate::command::{process_child_stderr, ErrorRecorder, ErrorRecorderT, NixHandler};19use crate::command::{ClonableHandler, Handler, NixHandler, NoopHandler};182019const REPL_DELIMITER: &str = "\"FLEET_MAGIC_REPL_DELIMITER\"";21const REPL_DELIMITER: &str = "\"FLEET_MAGIC_REPL_DELIMITER\"";20// To synchronize stderr and stdout. It works, yet I hate this.21// There is no other way to catch errors, because they are coming from different streams, and they are not synchronized in tokio.22const ERROR_DELIMITER: &str = "FLEET_MAGIC_ERROR_DELIMITER";232224pub struct NixSessionInner {23pub struct NixSessionInner {25 full_delimiter: String,24 full_delimiter: String,26 #[allow(dead_code)]25 nix_handler: ClonableHandler<NixHandler>,27 stderr_handler: ChildTask<Result<()>>,28 error_recorder: ErrorRecorderT,26 out: OutputHandler,29 read: FramedRead<ChildStdout, LinesCodec>,30 stdin: ChildStdin,27 stdin: ChildStdin,31 string_wrapping: (String, String),28 string_wrapping: (String, String),32 number_wrapping: (String, String),29 number_wrapping: (String, String),33 error_delimiter: String,343035 next_id: u32,31 next_id: u32,36 free_list: Vec<u32>,32 free_list: Vec<u32>,37}33}38const TRAIN_STRING: &str = "\"TRAIN_STRING\"";34const TRAIN_STRING: &str = "\"TRAIN_STRING\"";39const TRAIN_NUMBER: &str = "13141516";35const TRAIN_NUMBER: &str = "13141516";403637#[must_use]41struct ErrorRecorderHandle {38struct ErrorCollector<'i, H> {42 handle: ErrorRecorderT,39 collected: Vec<String>,40 inner: &'i mut H,43}41}44impl ErrorRecorderHandle {}42impl<'i, H> ErrorCollector<'i, H> {45impl Drop for ErrorRecorderHandle {46 fn drop(&mut self) {43 fn new(inner: &'i mut H) -> Self {47 let mut recorded = self.handle.lock().unwrap();44 Self {48 assert!(recorded.is_some(), "exclusive");45 collected: vec![],49 *recorded = None;46 inner,47 }50 }48 }51}49}5253struct ErrorCollector {50impl<H> ErrorCollector<'_, H> {54 collected: Arc<Mutex<Vec<String>>>,55 delim: String,56 got_delim: Option<oneshot::Sender<()>>,57}58impl ErrorRecorder for ErrorCollector {59 fn push_message(&mut self, msg: &str) -> bool {51 fn handle_line_inner(&mut self, msg: &str) -> bool {60 if msg == self.delim {61 let _ = self.got_delim62 .take()63 .expect("error delim is only expected once")64 .send(());65 return true;66 }67 let Some(msg) = msg.strip_prefix("@nix ") else {52 let Some(msg) = msg.strip_prefix("@nix ") else {68 return false;53 return false;69 };54 };79 if act.action != "msg" || act.level != 0 {64 if act.action != "msg" || act.level != 0 {80 return false;65 return false;81 }66 }82 self.collected.lock().unwrap().push(act.msg);67 self.collected.push(act.msg);83 true68 true84 }69 }70 fn finish(self) -> Result<()> {71 // fn dedent(s: String) -> String {72 // s.split('\n').filter(|s| !s.trim().is_empty()).map(|v| v.)73 // }74 if !self.collected.is_empty() {75 bail!("{}", self.collected.iter().map(|v| {76 if let Some(f) = v.strip_prefix("\u{1b}[31;1merror:\u{1b}[0m ") {77 let v = unindent::unindent(f.trim_start());78 v.trim().to_owned()79 } else {80 v.to_owned()81 }82 }).join("\n"));83 }84 Ok(())85 }86 fn flush(self) {87 for line in self.collected {88 warn!("{line}");89 }90 }85}91}92impl<H: Handler> Handler for ErrorCollector<'_, H> {93 fn handle_line(&mut self, e: &str) {94 if self.handle_line_inner(e) {95 return;96 }97 self.inner.handle_line(e)98 }99}86100101enum OutputLine {102 Out(String),103 Err(String),104}105struct OutputHandler {106 rx: mpsc::Receiver<OutputLine>,107 _cancel_handle: oneshot::Receiver<()>,108}109impl OutputHandler {110 fn new(out: ChildStdout, err: ChildStderr) -> Self {111 let mut out = FramedRead::new(out, LinesCodec::new());112 let mut err = FramedRead::new(err, LinesCodec::new());113 let (tx, rx) = mpsc::channel(20);114 let (mut cancelled, _cancel_handle) = oneshot::channel();115 tokio::spawn(async move {116 loop {117 select! {118 // We should receive errors earlier than synchronization119 biased;120 e = err.next() => {121 let Some(Ok(e)) = e else {122 if e.is_some() {123 error!("bad repl stderr: {e:?}");124 }125 continue;126 };127 let _ = tx.send(OutputLine::Err(e)).await;128 }129 o = out.next() => {130 let Some(Ok(o)) = o else {131 if o.is_some() {132 error!("bad repl stdout: {o:?}");133 }134 continue;135 };136 let _ = tx.send(OutputLine::Out(o)).await;137 }138 // Reader doesn't care about stdout, as this is cancelled.139 // Error still might be useful, to process leftover span closures?140 _ = cancelled.closed() => {141 break;142 }143 }144 }145 });146 Self { rx, _cancel_handle }147 }148 async fn next(&mut self) -> Option<OutputLine> {149 self.rx.recv().await150 }151}15287impl NixSessionInner {153impl NixSessionInner {88 async fn new(flake: &OsStr, extra_args: impl IntoIterator<Item = &OsStr>) -> Result<Self> {154 async fn new(flake: &OsStr, extra_args: impl IntoIterator<Item = &OsStr>) -> Result<Self> {89 let mut cmd = Command::new("nix");155 let mut cmd = Command::new("nix");100 let cmd = cmd.spawn()?;166 let cmd = cmd.spawn()?;101 let stdout = cmd.stdout.unwrap();167 let stdout = cmd.stdout.unwrap();102 let stderr = cmd.stderr.unwrap();168 let stderr = cmd.stderr.unwrap();169 let mut out = OutputHandler::new(stdout, stderr);103 let mut stdin = cmd.stdin.unwrap();170 let mut stdin = cmd.stdin.unwrap();104 let error_recorder = ErrorRecorderT::default();105 let err_recorder = error_recorder.clone();106 let stderr_handler = abort_on_drop::ChildTask::from(tokio::spawn(async move {107 let mut handler = NixHandler::default();108 process_child_stderr(stderr, &mut handler, err_recorder).await109 }));110 // Standard repl hello doesn't work with internal-json logger171 // Standard repl hello doesn't work with internal-json logger111 stdin.write_all(REPL_DELIMITER.as_bytes()).await?;172 stdin.write_all(REPL_DELIMITER.as_bytes()).await?;112 stdin.write_all(b"\n").await?;173 stdin.write_all(b"\n").await?;113 stdin.flush().await?;174 stdin.flush().await?;114 let mut read = FramedRead::new(stdout, LinesCodec::new());175 let nix_handler = NixHandler::default();115 let mut full_delimiter = None;176 let mut full_delimiter = None;116 while let Some(line) = read.next().await {177 while let Some(line) = out.next().await {117 let line = line?;178 let line = match line {179 OutputLine::Out(o) => o,180 OutputLine::Err(_e) => {181 // Handle startup errors, but skip repl hello?182 //nix_handler.handle_line(&e);183 continue;184 }185 };118 if line.contains(REPL_DELIMITER) {186 if line.contains(REPL_DELIMITER) {119 debug!("discovered repl delimiter with added colors: {line}");187 debug!("discovered repl delimiter with added colors: {line}");120 full_delimiter = Some(line.to_owned());188 full_delimiter = Some(line.to_owned());126 };194 };127 let mut res = Self {195 let mut res = Self {128 full_delimiter,196 full_delimiter,129 error_delimiter: "[[filled after training]]".to_owned(),197 nix_handler: ClonableHandler::new(nix_handler),130 stderr_handler,131 error_recorder,132 read,198 out,133 stdin,199 stdin,134 string_wrapping: Default::default(),200 string_wrapping: Default::default(),135 number_wrapping: Default::default(),201 number_wrapping: Default::default(),142 }208 }143 async fn train(&mut self) -> Result<()> {209 async fn train(&mut self) -> Result<()> {144 {210 {145 let full_string = self.execute_expression_raw(TRAIN_STRING).await?;211 let full_string = self212 .execute_expression_raw(TRAIN_STRING, &mut NoopHandler)213 .await?;146 let string_offset = full_string.find(TRAIN_STRING).expect("contained");214 let string_offset = full_string.find(TRAIN_STRING).expect("contained");147 let string_prefix = &full_string[..string_offset];215 let string_prefix = &full_string[..string_offset];148 let string_suffix = &full_string[string_offset + TRAIN_STRING.len()..];216 let string_suffix = &full_string[string_offset + TRAIN_STRING.len()..];149 self.string_wrapping = (string_prefix.to_owned(), string_suffix.to_owned());217 self.string_wrapping = (string_prefix.to_owned(), string_suffix.to_owned());150 }218 }151 {219 {152 let full_number = self.execute_expression_raw(TRAIN_NUMBER).await?;220 let full_number = self221 .execute_expression_raw(TRAIN_NUMBER, &mut NoopHandler)222 .await?;153 let number_offset = full_number.find(TRAIN_NUMBER).expect("contained");223 let number_offset = full_number.find(TRAIN_NUMBER).expect("contained");154 let number_prefix = &full_number[..number_offset];224 let number_prefix = &full_number[..number_offset];155 let number_suffix = &full_number[number_offset + TRAIN_NUMBER.len()..];225 let number_suffix = &full_number[number_offset + TRAIN_NUMBER.len()..];156 self.number_wrapping = (number_prefix.to_owned(), number_suffix.to_owned());226 self.number_wrapping = (number_prefix.to_owned(), number_suffix.to_owned());157 }227 }158 {159 struct TrainingErrorCollector(Option<oneshot::Sender<String>>);160 impl ErrorRecorder for TrainingErrorCollector {161 fn push_message(&mut self, msg: &str) -> bool {162 if msg.contains(ERROR_DELIMITER) {163 let _ = self164 .0165 .take()166 .expect("error delimiter is sent once")167 .send(msg.to_owned());168 }169 true170 }171 }172 let (tx, rx) = oneshot::channel();173 let _handle = self.record_error(TrainingErrorCollector(Some(tx)));174 self.send_command(ERROR_DELIMITER).await?;175 self.send_command(REPL_DELIMITER).await?;176 self.read_until_delimiter().await?;177 let msg = rx.await?;178 self.error_delimiter = msg;179 }180 Ok(())228 Ok(())181 }229 }182 fn record_error(&mut self, v: impl ErrorRecorder + 'static) -> ErrorRecorderHandle {183 {184 let mut recorder = self.error_recorder.lock().unwrap();185 assert!(recorder.is_none(), "recorder is already started");186 *recorder = Some(Box::new(v));187 }188 ErrorRecorderHandle {189 handle: self.error_recorder.clone(),190 }191 }192 async fn send_command(&mut self, cmd: impl AsRef<[u8]>) -> Result<()> {230 async fn send_command(&mut self, cmd: impl AsRef<[u8]>) -> Result<()> {193 self.stdin.write_all(cmd.as_ref()).await?;231 self.stdin.write_all(cmd.as_ref()).await?;194 self.stdin.write_all(b"\n").await?;232 self.stdin.write_all(b"\n").await?;195 Ok(())233 Ok(())196 }234 }197 async fn read_until_delimiter(&mut self) -> Result<String> {235 async fn read_until_delimiter(&mut self, err_handler: &mut dyn Handler) -> Result<String> {198 let mut out = String::new();236 let mut out = String::new();199 while let Some(line) = self.read.next().await {237 while let Some(line) = self.out.next().await {200 let line = line?;238 let line = match line {239 OutputLine::Out(out) => out,240 OutputLine::Err(err) => {241 err_handler.handle_line(&err);242 continue;243 }244 };201 if line == self.full_delimiter {245 if line == self.full_delimiter {202 return Ok(out);246 return Ok(out);203 }247 }234 expr: impl AsRef<[u8]>,278 expr: impl AsRef<[u8]>,235 wrapping: &(String, String),279 wrapping: &(String, String),236 ) -> Result<String> {280 ) -> Result<String> {237 let collected = Arc::new(Mutex::new(vec![]));281 let mut nix_handler = self.nix_handler.clone();238 let (etx, erx) = oneshot::channel();282 let mut collected = ErrorCollector::new(&mut nix_handler);239 let _collector = self.record_error(ErrorCollector{collected:collected.clone(), delim: self.error_delimiter.clone(), got_delim: Some(etx)});240 let res = self.execute_expression_raw(expr).await?;283 let res = self.execute_expression_raw(expr, &mut collected).await?;241 let _ = self.execute_expression_raw(ERROR_DELIMITER).await?;242 let _ = erx.await;243 if res.is_empty() {284 if res.is_empty() {244 let c = collected.lock().unwrap();285 collected.finish()?;245 if c.is_empty() {286 bail!("expected expression, got nothing")246 bail!("expected expression, got nothing")247 }287 } else {248 bail!("{}", c.join("\n"));288 collected.flush()249 }289 };250 drop(_collector);251 let Some(res) = res.strip_prefix(&wrapping.0) else {290 let Some(res) = res.strip_prefix(&wrapping.0) else {252 bail!("invalid type")291 bail!("invalid type")253 };292 };257 Ok(res.to_owned())296 Ok(res.to_owned())258 }297 }259 async fn execute_expression_empty(&mut self, expr: impl AsRef<[u8]>) -> Result<()> {298 async fn execute_expression_empty(&mut self, expr: impl AsRef<[u8]>) -> Result<()> {260 let collected = Arc::new(Mutex::new(vec![]));299 let mut nix_handler = self.nix_handler.clone();261 let (etx, erx) = oneshot::channel();300 let mut collected = ErrorCollector::new(&mut nix_handler);262 let _collector = self.record_error(ErrorCollector{collected:collected.clone(), delim: self.error_delimiter.clone(), got_delim: Some(etx)});263 let v = self.execute_expression_raw(expr).await?;301 let v = self.execute_expression_raw(expr, &mut collected).await?;264 let _ = self.execute_expression_raw(ERROR_DELIMITER).await;265 let _ = erx.await;302 collected.finish()?;266267 let c = collected.lock().unwrap();268 if !c.is_empty() {269 bail!("{}", c.join("\n"));270 }271 ensure!(v.is_empty(), "unexpected expression result");303 ensure!(v.is_empty(), "unexpected expression result");272 Ok(())304 Ok(())273 }305 }274 async fn execute_expression_raw(&mut self, expr: impl AsRef<[u8]>) -> Result<String> {306 async fn execute_expression_raw(307 &mut self,308 expr: impl AsRef<[u8]>,309 err_handler: &mut dyn Handler,310 ) -> Result<String> {275 self.send_command(expr).await?;311 self.send_command(expr).await?;276 // It will be echoed312 // It will be echoed277 self.send_command(REPL_DELIMITER).await?;313 self.send_command(REPL_DELIMITER).await?;278 self.read_until_delimiter().await314 self.read_until_delimiter(err_handler).await279 }315 }280 async fn execute_assign(&mut self, expr: impl AsRef<str>) -> Result<u32> {316 async fn execute_assign(&mut self, expr: impl AsRef<str>) -> Result<u32> {281 let id = self.allocate_id();317 let id = self.allocate_id();306#[derive(Clone)]342#[derive(Clone)]307pub struct NixSession(Arc<tokio::sync::Mutex<PooledConnection<NixSessionPoolInner>>>);343pub struct NixSession(Arc<tokio::sync::Mutex<PooledConnection<NixSessionPoolInner>>>);308344309#[derive(Clone, Debug)]345#[derive(Clone)]310enum Index {346enum Index {311 String(String),347 String(String),312 // Idx(u32),348 // Idx(u32),313}349}350impl Display for Index {351 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {352 match self {353 Index::String(k) => {354 let v = nixlike::format_identifier(k.as_str());355 write!(f, ".{v}")356 }357 }358 }359}360struct PathDisplay<'i>(&'i [Index]);361impl Display for PathDisplay<'_> {362 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {363 write!(f, "flake")?;364 for i in self.0 {365 write!(f, "{i}")?;366 }367 Ok(())368 }369}314pub struct Field {370pub struct Field {315 full_path: Vec<Index>,371 full_path: Vec<Index>,316 session: NixSession,372 session: NixSession,326 }382 }327 pub async fn field(session: NixSession, field: &str) -> Result<Self> {383 pub async fn field(session: NixSession, field: &str) -> Result<Self> {328 Self::root(session).get_field_deep([field]).await384 Self::root(session).get_field_deep([field]).await385 }386 pub async fn get_json_deep<'a, V: DeserializeOwned>(387 &self,388 name: impl IntoIterator<Item = &'a str>,389 ) -> Result<V> {390 let field = self.get_field_deep(name).await?;391 field.as_json().await329 }392 }330 pub async fn get_field(&self, name: &str) -> Result<Self> {393 pub async fn get_field(&self, name: &str) -> Result<Self> {331 self.get_field_deep([name]).await394 self.get_field_deep([name]).await364 .await427 .await365 .execute_assign(&query)428 .execute_assign(&query)366 .await429 .await367 .with_context(|| format!("full path: {:?}", full_path))?;430 .with_context(|| format!("full path: {}", PathDisplay(&full_path)))?;368 Ok(Self {431 Ok(Self {369 full_path,432 full_path,370 session: self.session.clone(),433 session: self.session.clone(),379 .await442 .await380 .execute_expression_to_json(&format!("sess_field_{id}"))443 .execute_expression_to_json(&format!("sess_field_{id}"))381 .await444 .await382 .with_context(|| format!("full path: {:?}", self.full_path))445 .with_context(|| format!("full path: {}", PathDisplay(&self.full_path)))383 }446 }384 pub async fn list_fields(&self) -> Result<Vec<String>> {447 pub async fn list_fields(&self) -> Result<Vec<String>> {385 let id = self.value.expect("can't list root fields");448 let id = self.value.expect("can't list root fields");389 .await452 .await390 .execute_expression_to_json(&format!("builtins.attrNames sess_field_{id}"))453 .execute_expression_to_json(&format!("builtins.attrNames sess_field_{id}"))391 .await454 .await392 .with_context(|| format!("full path: {:?}", self.full_path))455 .with_context(|| format!("full path: {}", PathDisplay(&self.full_path)))393 }456 }394}457}395impl Drop for Field {458impl Drop for Field {