1use std::{ffi::OsStr, num::ParseIntError, process::Stdio, sync::Arc};23use better_command::{ClonableHandler, Handler, NixHandler, NoopHandler};4use futures::StreamExt;5use itertools::Itertools as _;6use serde::{Deserialize, de::DeserializeOwned};7use thiserror::Error;8use tokio::{9 io::AsyncWriteExt,10 process::{ChildStderr, ChildStdin, ChildStdout, Command},11 select,12 sync::{Mutex, mpsc, oneshot},13};14use tokio_util::codec::{FramedRead, LinesCodec};15use tracing::{Level, debug, error, warn};1617#[derive(Error, Debug, Clone)]18pub enum Error {19 #[error("failed to create nix repl session: {0}")]20 SessionInit(&'static str),21 #[error("unexpected end of output, nix crashed?")]22 MissingDelimiter,2324 #[error("expression did'nt produce any output")]25 ExpectedOutput,26 #[error("expression produced output, which is unexpected")]27 UnexpectedOutput,2829 #[error("unexpected expression output type")]30 InvalidType,3132 #[error("failed to build attr {attribute}:\n{error}")]33 BuildFailed { attribute: String, error: String },3435 #[error("output: {0}")]36 Json(Arc<serde_json::Error>),37 38 39 #[error("int output: {0}")]40 Int(ParseIntError),41 #[error("pool: {0}")]42 Pool(Arc<r2d2::Error>),43 #[error("io: {0}")]44 Io(Arc<std::io::Error>),4546 47 #[error("at {0}: {1}")]48 InContext(String, Box<Self>),4950 #[error("error: {0}")]51 NixError(String),52}53impl From<r2d2::Error> for Error {54 fn from(value: r2d2::Error) -> Self {55 Self::Pool(Arc::new(value))56 }57}58impl From<std::io::Error> for Error {59 fn from(value: std::io::Error) -> Self {60 Self::Io(Arc::new(value))61 }62}63impl From<serde_json::Error> for Error {64 fn from(value: serde_json::Error) -> Self {65 Self::Json(Arc::new(value))66 }67}68impl Error {69 pub(crate) fn context(self, context: String) -> Self {70 Self::InContext(context, Box::new(self))71 }72}73pub type Result<T, E = Error> = std::result::Result<T, E>;7475enum OutputLine {76 Out(String),77 Err(String),78}79struct OutputHandler {80 rx: mpsc::Receiver<OutputLine>,81 _cancel_handle: oneshot::Receiver<()>,82}83impl OutputHandler {84 fn new(out: ChildStdout, err: ChildStderr) -> Self {85 let mut out = FramedRead::new(out, LinesCodec::new());86 let mut err = FramedRead::new(err, LinesCodec::new());87 let (tx, rx) = mpsc::channel(20);88 let (mut cancelled, _cancel_handle) = oneshot::channel();89 tokio::spawn(async move {90 loop {91 select! {92 93 biased;94 e = err.next() => {95 let Some(Ok(e)) = e else {96 if e.is_some() {97 error!("bad repl stderr: {e:?}");98 }99 continue;100 };101 let _ = tx.send(OutputLine::Err(e)).await;102 }103 o = out.next() => {104 let Some(Ok(o)) = o else {105 if o.is_some() {106 error!("bad repl stdout: {o:?}");107 }108 continue;109 };110 let _ = tx.send(OutputLine::Out(o)).await;111 }112 113 114 _ = cancelled.closed() => {115 break;116 }117 }118 }119 });120 Self { rx, _cancel_handle }121 }122 async fn next(&mut self) -> Option<OutputLine> {123 self.rx.recv().await124 }125}126127#[must_use]128struct ErrorCollector<'i, H> {129 collected: Vec<String>,130 inner: &'i mut H,131}132impl<'i, H> ErrorCollector<'i, H> {133 fn new(inner: &'i mut H) -> Self {134 Self {135 collected: vec![],136 inner,137 }138 }139}140impl<H> ErrorCollector<'_, H> {141 fn handle_line_inner(&mut self, msg: &str) -> bool {142 let Some(msg) = msg.strip_prefix("@nix ") else {143 return false;144 };145 #[derive(Deserialize)]146 struct ErrorAction {147 action: String,148 level: u32,149 msg: String,150 }151 let Ok(act) = serde_json::from_str::<ErrorAction>(msg) else {152 return false;153 };154 if act.action != "msg" || act.level != 0 {155 return false;156 }157 self.collected.push(act.msg);158 true159 }160 fn finish(self) -> Result<()> {161 162 163 164 if !self.collected.is_empty() {165 return Err(Error::NixError(166 self.collected167 .iter()168 .map(|v| {169 if let Some(f) = v.strip_prefix("\u{1b}[31;1merror:\u{1b}[0m ") {170 let v = unindent::unindent(f.trim_start());171 v.trim().to_owned()172 } else {173 v.to_owned()174 }175 })176 .join("\n")177 .to_string(),178 ));179 }180 Ok(())181 }182 fn flush(self) {183 for line in self.collected {184 warn!("{line}");185 }186 }187}188impl<H: Handler> Handler for ErrorCollector<'_, H> {189 fn handle_line(&mut self, e: &str) {190 if self.handle_line_inner(e) {191 return;192 }193 self.inner.handle_line(e)194 }195}196197pub struct NixSessionInner {198 full_delimiter: String,199 nix_handler: ClonableHandler<NixHandler>,200 out: OutputHandler,201 stdin: ChildStdin,202 string_wrapping: (String, String),203 number_wrapping: (String, String),204205 executing_command: Arc<Mutex<()>>,206207 next_id: u32,208 pub(crate) free_list: Vec<u32>,209210 pub nix_system: String,211}212213214const REPL_DELIMITER: &str = "\"FLEET_MAGIC_REPL_DELIMITER\"";215216const TRAIN_STRING: &str = "\"TRAIN_STRING\"";217218const TRAIN_NUMBER: &str = "13141516";219220221222223impl NixSessionInner {224 pub(crate) async fn new(225 flake: &OsStr,226 extra_args: impl IntoIterator<Item = &OsStr>,227 nix_system: String,228 fail_fast: bool,229 ) -> Result<Self> {230 let mut cmd = Command::new("nix");231 cmd.arg("repl")232 .args(["--option", "pure-eval", "true"])233 .arg(flake)234 .arg("--log-format")235 .arg("internal-json");236 if !fail_fast {237 cmd.arg("--keep-going");238 }239 for arg in extra_args {240 cmd.arg(arg);241 }242 cmd.stdin(Stdio::piped());243 cmd.stdout(Stdio::piped());244 cmd.stderr(Stdio::piped());245 let cmd = cmd.spawn()?;246 let stdout = cmd.stdout.unwrap();247 let stderr = cmd.stderr.unwrap();248 let mut out = OutputHandler::new(stdout, stderr);249 let mut stdin = cmd.stdin.unwrap();250 251 stdin.write_all(REPL_DELIMITER.as_bytes()).await?;252 stdin.write_all(b"\n").await?;253 stdin.flush().await?;254 let nix_handler = NixHandler::default();255 let mut full_delimiter = None;256 let mut errors = vec![];257 while let Some(line) = out.next().await {258 let line = match line {259 OutputLine::Out(o) => o,260 OutputLine::Err(_e) => {261 262 errors.push(_e);263 continue;264 }265 };266 if line.contains(REPL_DELIMITER) {267 debug!("discovered repl delimiter with added colors: {line}");268 full_delimiter = Some(line.to_owned());269 break;270 }271 }272 let Some(full_delimiter) = full_delimiter else {273 for e in errors {274 error!("{e}");275 }276 return Err(Error::SessionInit("failed to discover delimiter"));277 };278 let mut res = Self {279 full_delimiter,280 nix_handler: ClonableHandler::new(nix_handler),281 out,282 stdin,283 string_wrapping: Default::default(),284 number_wrapping: Default::default(),285286 executing_command: Arc::new(Mutex::new(())),287288 next_id: 0,289 free_list: vec![],290291 nix_system,292 };293 res.train().await?;294 Ok(res)295 }296 async fn train(&mut self) -> Result<()> {297 {298 let full_string = self299 .execute_expression_raw(TRAIN_STRING, &mut NoopHandler)300 .await?;301 let string_offset = full_string.find(TRAIN_STRING).expect("contained");302 let string_prefix = &full_string[..string_offset];303 let string_suffix = &full_string[string_offset + TRAIN_STRING.len()..];304 self.string_wrapping = (string_prefix.to_owned(), string_suffix.to_owned());305 }306 {307 let full_number = self308 .execute_expression_raw(TRAIN_NUMBER, &mut NoopHandler)309 .await?;310 let number_offset = full_number.find(TRAIN_NUMBER).expect("contained");311 let number_prefix = &full_number[..number_offset];312 let number_suffix = &full_number[number_offset + TRAIN_NUMBER.len()..];313 self.number_wrapping = (number_prefix.to_owned(), number_suffix.to_owned());314 }315 Ok(())316 }317 async fn send_command(&mut self, cmd: impl AsRef<[u8]>) -> Result<()> {318 if tracing::enabled!(Level::DEBUG) && cmd.as_ref() != REPL_DELIMITER.as_bytes() {319 let cmd_str = String::from_utf8_lossy(cmd.as_ref());320 tracing::debug!("{cmd_str}");321 };322 self.stdin.write_all(cmd.as_ref()).await?;323 self.stdin.write_all(b"\n").await?;324 Ok(())325 }326 async fn read_until_delimiter(&mut self, err_handler: &mut dyn Handler) -> Result<String> {327 let mut out = String::new();328 while let Some(line) = self.out.next().await {329 let line = match line {330 OutputLine::Out(out) => out,331 OutputLine::Err(err) => {332 err_handler.handle_line(&err);333 continue;334 }335 };336 if line == self.full_delimiter {337 return Ok(out);338 }339 if !out.is_empty() {340 out.push('\n');341 }342 out.push_str(&line);343 }344 Err(Error::MissingDelimiter)345 }346 pub(crate) async fn execute_expression_number(347 &mut self,348 expr: impl AsRef<[u8]>,349 ) -> Result<u64> {350 let num = self.number_wrapping.clone();351 let n = self.execute_expression_wrapping(expr, &num).await?;352 n.parse::<u64>().map_err(Error::Int)353 }354 async fn execute_expression_string(&mut self, expr: impl AsRef<[u8]>) -> Result<String> {355 356 357 358 359 360 let regex = regex::Regex::new(r#"(?<prefix>[: {,\[]\\")\\\$"#).expect("fixup json");361362 let num = self.string_wrapping.clone();363 let n = self.execute_expression_wrapping(expr, &num).await?;364 let n = regex.replace_all(&n, "$prefix$$");365 let str: String = serde_json::from_str(&n)?;366 Ok(str)367 }368 pub(crate) async fn execute_expression_to_json<V: DeserializeOwned>(369 &mut self,370 expr: impl AsRef<[u8]>,371 ) -> Result<V> {372 let mut fexpr = b"builtins.toJSON (".to_vec();373 fexpr.extend_from_slice(expr.as_ref());374 fexpr.push(b')');375376 Ok(serde_json::from_str(377 &self.execute_expression_string(fexpr).await?,378 )?)379 }380 async fn execute_expression_wrapping(381 &mut self,382 expr: impl AsRef<[u8]>,383 wrapping: &(String, String),384 ) -> Result<String> {385 let mut nix_handler = self.nix_handler.clone();386 let mut collected = ErrorCollector::new(&mut nix_handler);387 let res = self.execute_expression_raw(expr, &mut collected).await?;388 if res.is_empty() {389 collected.finish()?;390 return Err(Error::ExpectedOutput);391 } else {392 collected.flush()393 };394 let Some(res) = res.strip_prefix(&wrapping.0) else {395 return Err(Error::InvalidType);396 };397 let Some(res) = res.strip_suffix(&wrapping.1) else {398 return Err(Error::InvalidType);399 };400 Ok(res.to_owned())401 }402 async fn execute_expression_empty(&mut self, expr: impl AsRef<[u8]>) -> Result<()> {403 let mut nix_handler = self.nix_handler.clone();404 let mut collected = ErrorCollector::new(&mut nix_handler);405 let v = self.execute_expression_raw(expr, &mut collected).await?;406 collected.finish()?;407 if !v.is_empty() {408 return Err(Error::UnexpectedOutput);409 }410 Ok(())411 }412 pub(crate) async fn execute_expression_raw(413 &mut self,414 expr: impl AsRef<[u8]>,415 err_handler: &mut dyn Handler,416 ) -> Result<String> {417 418 let _lock = self.executing_command.clone();419 let _guard = _lock.lock().await;420421 self.send_command(expr).await?;422 423 self.send_command(REPL_DELIMITER).await?;424 self.read_until_delimiter(err_handler).await425 }426 pub(crate) async fn execute_assign(&mut self, expr: impl AsRef<str>) -> Result<u32> {427 let id = self.allocate_id();428 self.execute_expression_empty(format!("sess_field_{id} = {}", expr.as_ref()))429 .await?;430 Ok(id)431 }432433 434 fn allocate_id(&mut self) -> u32 {435 if let Some(free) = self.free_list.pop() {436 free437 } else {438 let v = self.next_id;439 self.next_id += 1;440 v441 }442 }443 444 445 446 447 448 449 450}