1use std::{ffi::OsStr, num::ParseIntError, process::Stdio, sync::Arc};23use better_command::{ClonableHandler, Handler, NixHandler, NoopHandler};4use futures::StreamExt;5use itertools::Itertools as _;6use serde::{de::DeserializeOwned, Deserialize};7use thiserror::Error;8use tokio::{9 io::AsyncWriteExt,10 process::{ChildStderr, ChildStdin, ChildStdout, Command},11 select,12 sync::{mpsc, oneshot, Mutex},13};14use tokio_util::codec::{FramedRead, LinesCodec};15use tracing::{debug, error, warn, Level};1617#[derive(Error, Debug)]18pub enum Error {19 #[error("failed to create nix repl session: {0}")]20 SessionInit(&'static str),21 #[error("unexpected end of output, nix crashed?")]22 MissingDelimiter,2324 #[error("expression did'nt produce any output")]25 ExpectedOutput,26 #[error("expression produced output, which is unexpected")]27 UnexpectedOutput,2829 #[error("unexpected expression output type")]30 InvalidType,3132 #[error("failed to build attr {attribute}:\n{error}")]33 BuildFailed { attribute: String, error: String },3435 #[error("output: {0}")]36 Json(#[from] serde_json::Error),37 38 39 #[error("int output: {0}")]40 Int(ParseIntError),41 #[error("pool: {0}")]42 Pool(#[from] r2d2::Error),43 #[error("io: {0}")]44 Io(#[from] std::io::Error),4546 47 #[error("at {0}: {1}")]48 InContext(String, Box<Self>),4950 #[error("error: {0}")]51 NixError(String),52}53impl Error {54 pub(crate) fn context(self, context: String) -> Self {55 Self::InContext(context, Box::new(self))56 }57}58pub type Result<T, E = Error> = std::result::Result<T, E>;5960enum OutputLine {61 Out(String),62 Err(String),63}64struct OutputHandler {65 rx: mpsc::Receiver<OutputLine>,66 _cancel_handle: oneshot::Receiver<()>,67}68impl OutputHandler {69 fn new(out: ChildStdout, err: ChildStderr) -> Self {70 let mut out = FramedRead::new(out, LinesCodec::new());71 let mut err = FramedRead::new(err, LinesCodec::new());72 let (tx, rx) = mpsc::channel(20);73 let (mut cancelled, _cancel_handle) = oneshot::channel();74 tokio::spawn(async move {75 loop {76 select! {77 78 biased;79 e = err.next() => {80 let Some(Ok(e)) = e else {81 if e.is_some() {82 error!("bad repl stderr: {e:?}");83 }84 continue;85 };86 let _ = tx.send(OutputLine::Err(e)).await;87 }88 o = out.next() => {89 let Some(Ok(o)) = o else {90 if o.is_some() {91 error!("bad repl stdout: {o:?}");92 }93 continue;94 };95 let _ = tx.send(OutputLine::Out(o)).await;96 }97 98 99 _ = cancelled.closed() => {100 break;101 }102 }103 }104 });105 Self { rx, _cancel_handle }106 }107 async fn next(&mut self) -> Option<OutputLine> {108 self.rx.recv().await109 }110}111112#[must_use]113struct ErrorCollector<'i, H> {114 collected: Vec<String>,115 inner: &'i mut H,116}117impl<'i, H> ErrorCollector<'i, H> {118 fn new(inner: &'i mut H) -> Self {119 Self {120 collected: vec![],121 inner,122 }123 }124}125impl<H> ErrorCollector<'_, H> {126 fn handle_line_inner(&mut self, msg: &str) -> bool {127 let Some(msg) = msg.strip_prefix("@nix ") else {128 return false;129 };130 #[derive(Deserialize)]131 struct ErrorAction {132 action: String,133 level: u32,134 msg: String,135 }136 let Ok(act) = serde_json::from_str::<ErrorAction>(msg) else {137 return false;138 };139 if act.action != "msg" || act.level != 0 {140 return false;141 }142 self.collected.push(act.msg);143 true144 }145 fn finish(self) -> Result<()> {146 147 148 149 if !self.collected.is_empty() {150 return Err(Error::NixError(format!(151 "{}",152 self.collected153 .iter()154 .map(|v| {155 if let Some(f) = v.strip_prefix("\u{1b}[31;1merror:\u{1b}[0m ") {156 let v = unindent::unindent(f.trim_start());157 v.trim().to_owned()158 } else {159 v.to_owned()160 }161 })162 .join("\n"),163 )));164 }165 Ok(())166 }167 fn flush(self) {168 for line in self.collected {169 warn!("{line}");170 }171 }172}173impl<H: Handler> Handler for ErrorCollector<'_, H> {174 fn handle_line(&mut self, e: &str) {175 if self.handle_line_inner(e) {176 return;177 }178 self.inner.handle_line(e)179 }180}181182pub struct NixSessionInner {183 full_delimiter: String,184 nix_handler: ClonableHandler<NixHandler>,185 out: OutputHandler,186 stdin: ChildStdin,187 string_wrapping: (String, String),188 number_wrapping: (String, String),189190 executing_command: Arc<Mutex<()>>,191192 next_id: u32,193 pub(crate) free_list: Vec<u32>,194}195196197const REPL_DELIMITER: &str = "\"FLEET_MAGIC_REPL_DELIMITER\"";198199const TRAIN_STRING: &str = "\"TRAIN_STRING\"";200201const TRAIN_NUMBER: &str = "13141516";202203204205206impl NixSessionInner {207 pub(crate) async fn new(208 flake: &OsStr,209 extra_args: impl IntoIterator<Item = &OsStr>,210 ) -> Result<Self> {211 let mut cmd = Command::new("nix");212 cmd.arg("repl")213 .arg(flake)214 .arg("--log-format")215 .arg("internal-json");216 for arg in extra_args {217 cmd.arg(arg);218 }219 cmd.stdin(Stdio::piped());220 cmd.stdout(Stdio::piped());221 cmd.stderr(Stdio::piped());222 let cmd = cmd.spawn()?;223 let stdout = cmd.stdout.unwrap();224 let stderr = cmd.stderr.unwrap();225 let mut out = OutputHandler::new(stdout, stderr);226 let mut stdin = cmd.stdin.unwrap();227 228 stdin.write_all(REPL_DELIMITER.as_bytes()).await?;229 stdin.write_all(b"\n").await?;230 stdin.flush().await?;231 let nix_handler = NixHandler::default();232 let mut full_delimiter = None;233 let mut errors = vec![];234 while let Some(line) = out.next().await {235 let line = match line {236 OutputLine::Out(o) => o,237 OutputLine::Err(_e) => {238 239 errors.push(_e);240 continue;241 }242 };243 if line.contains(REPL_DELIMITER) {244 debug!("discovered repl delimiter with added colors: {line}");245 full_delimiter = Some(line.to_owned());246 break;247 }248 }249 let Some(full_delimiter) = full_delimiter else {250 for e in errors {251 error!("{e}");252 }253 return Err(Error::SessionInit("failed to discover delimiter"));254 };255 let mut res = Self {256 full_delimiter,257 nix_handler: ClonableHandler::new(nix_handler),258 out,259 stdin,260 string_wrapping: Default::default(),261 number_wrapping: Default::default(),262263 executing_command: Arc::new(Mutex::new(())),264265 next_id: 0,266 free_list: vec![],267 };268 res.train().await?;269 Ok(res)270 }271 async fn train(&mut self) -> Result<()> {272 {273 let full_string = self274 .execute_expression_raw(TRAIN_STRING, &mut NoopHandler)275 .await?;276 let string_offset = full_string.find(TRAIN_STRING).expect("contained");277 let string_prefix = &full_string[..string_offset];278 let string_suffix = &full_string[string_offset + TRAIN_STRING.len()..];279 self.string_wrapping = (string_prefix.to_owned(), string_suffix.to_owned());280 }281 {282 let full_number = self283 .execute_expression_raw(TRAIN_NUMBER, &mut NoopHandler)284 .await?;285 let number_offset = full_number.find(TRAIN_NUMBER).expect("contained");286 let number_prefix = &full_number[..number_offset];287 let number_suffix = &full_number[number_offset + TRAIN_NUMBER.len()..];288 self.number_wrapping = (number_prefix.to_owned(), number_suffix.to_owned());289 }290 Ok(())291 }292 async fn send_command(&mut self, cmd: impl AsRef<[u8]>) -> Result<()> {293 if tracing::enabled!(Level::DEBUG) && cmd.as_ref() != REPL_DELIMITER.as_bytes() {294 let cmd_str = String::from_utf8_lossy(cmd.as_ref());295 tracing::debug!("{cmd_str}");296 };297 self.stdin.write_all(cmd.as_ref()).await?;298 self.stdin.write_all(b"\n").await?;299 Ok(())300 }301 async fn read_until_delimiter(&mut self, err_handler: &mut dyn Handler) -> Result<String> {302 let mut out = String::new();303 while let Some(line) = self.out.next().await {304 let line = match line {305 OutputLine::Out(out) => out,306 OutputLine::Err(err) => {307 err_handler.handle_line(&err);308 continue;309 }310 };311 if line == self.full_delimiter {312 return Ok(out);313 }314 if !out.is_empty() {315 out.push('\n');316 }317 out.push_str(&line);318 }319 return Err(Error::MissingDelimiter);320 }321 pub(crate) async fn execute_expression_number(322 &mut self,323 expr: impl AsRef<[u8]>,324 ) -> Result<u64> {325 let num = self.number_wrapping.clone();326 let n = self.execute_expression_wrapping(expr, &num).await?;327 n.parse::<u64>().map_err(Error::Int)328 }329 async fn execute_expression_string(&mut self, expr: impl AsRef<[u8]>) -> Result<String> {330 let num = self.string_wrapping.clone();331 let n = self.execute_expression_wrapping(expr, &num).await?;332 let str: String = serde_json::from_str(&n)?;333 Ok(str)334 }335 pub(crate) async fn execute_expression_to_json<V: DeserializeOwned>(336 &mut self,337 expr: impl AsRef<[u8]>,338 ) -> Result<V> {339 let mut fexpr = b"builtins.toJSON (".to_vec();340 fexpr.extend_from_slice(expr.as_ref());341 fexpr.push(b')');342 let v = self.execute_expression_string(fexpr).await?;343 Ok(serde_json::from_str(&v)?)344 }345 async fn execute_expression_wrapping(346 &mut self,347 expr: impl AsRef<[u8]>,348 wrapping: &(String, String),349 ) -> Result<String> {350 let mut nix_handler = self.nix_handler.clone();351 let mut collected = ErrorCollector::new(&mut nix_handler);352 let res = self.execute_expression_raw(expr, &mut collected).await?;353 if res.is_empty() {354 collected.finish()?;355 return Err(Error::ExpectedOutput);356 } else {357 collected.flush()358 };359 let Some(res) = res.strip_prefix(&wrapping.0) else {360 return Err(Error::InvalidType);361 };362 let Some(res) = res.strip_suffix(&wrapping.1) else {363 return Err(Error::InvalidType);364 };365 Ok(res.to_owned())366 }367 async fn execute_expression_empty(&mut self, expr: impl AsRef<[u8]>) -> Result<()> {368 let mut nix_handler = self.nix_handler.clone();369 let mut collected = ErrorCollector::new(&mut nix_handler);370 let v = self.execute_expression_raw(expr, &mut collected).await?;371 collected.finish()?;372 if !v.is_empty() {373 return Err(Error::UnexpectedOutput);374 }375 Ok(())376 }377 pub(crate) async fn execute_expression_raw(378 &mut self,379 expr: impl AsRef<[u8]>,380 err_handler: &mut dyn Handler,381 ) -> Result<String> {382 383 let _lock = self.executing_command.clone();384 let _guard = _lock.lock().await;385386 self.send_command(expr).await?;387 388 self.send_command(REPL_DELIMITER).await?;389 self.read_until_delimiter(err_handler).await390 }391 pub(crate) async fn execute_assign(&mut self, expr: impl AsRef<str>) -> Result<u32> {392 let id = self.allocate_id();393 self.execute_expression_empty(format!("sess_field_{id} = {}", expr.as_ref()))394 .await?;395 Ok(id)396 }397398 399 fn allocate_id(&mut self) -> u32 {400 if let Some(free) = self.free_list.pop() {401 free402 } else {403 let v = self.next_id;404 self.next_id += 1;405 v406 }407 }408 409 410 411 412 413 414 415}