1use std::ffi::{OsStr, OsString};2use std::process::Stdio;3use std::sync::{Arc, Mutex, OnceLock};45use abort_on_drop::ChildTask;6use anyhow::{anyhow, bail, ensure, Context, Result};7use futures::StreamExt;8use r2d2::{Pool, PooledConnection};9use serde::de::DeserializeOwned;10use serde::Deserialize;11use tokio::io::AsyncWriteExt;12use tokio::process::{ChildStdin, ChildStdout, Command};13use tokio::sync::oneshot;14use tokio_util::codec::{FramedRead, LinesCodec};15use tracing::debug;1617use crate::command::{process_child_stderr, ErrorRecorder, ErrorRecorderT, NixHandler};1819const REPL_DELIMITER: &str = "\"FLEET_MAGIC_REPL_DELIMITER\"";202122const ERROR_DELIMITER: &str = "FLEET_MAGIC_ERROR_DELIMITER";2324pub struct NixSessionInner {25 full_delimiter: String,26 #[allow(dead_code)]27 stderr_handler: ChildTask<Result<()>>,28 error_recorder: ErrorRecorderT,29 read: FramedRead<ChildStdout, LinesCodec>,30 stdin: ChildStdin,31 string_wrapping: (String, String),32 number_wrapping: (String, String),33 error_delimiter: String,3435 next_id: u32,36 free_list: Vec<u32>,37}38const TRAIN_STRING: &str = "\"TRAIN_STRING\"";39const TRAIN_NUMBER: &str = "13141516";4041struct ErrorRecorderHandle {42 handle: ErrorRecorderT,43}44impl ErrorRecorderHandle {}45impl Drop for ErrorRecorderHandle {46 fn drop(&mut self) {47 let mut recorded = self.handle.lock().unwrap();48 assert!(recorded.is_some(), "exclusive");49 *recorded = None;50 }51}5253struct ErrorCollector {54 collected: Arc<Mutex<Vec<String>>>,55 delim: String,56 got_delim: Option<oneshot::Sender<()>>,57}58impl ErrorRecorder for ErrorCollector {59 fn push_message(&mut self, msg: &str) -> bool {60 if msg == self.delim {61 let _ = self.got_delim62 .take()63 .expect("error delim is only expected once")64 .send(());65 return true;66 }67 let Some(msg) = msg.strip_prefix("@nix ") else {68 return false;69 };70 #[derive(Deserialize)]71 struct ErrorAction {72 action: String,73 level: u32,74 msg: String,75 }76 let Ok(act) = serde_json::from_str::<ErrorAction>(msg) else {77 return false;78 };79 if act.action != "msg" || act.level != 0 {80 return false;81 }82 self.collected.lock().unwrap().push(act.msg);83 true84 }85}8687impl NixSessionInner {88 async fn new(flake: &OsStr, extra_args: impl IntoIterator<Item = &OsStr>) -> Result<Self> {89 let mut cmd = Command::new("nix");90 cmd.arg("repl")91 .arg(flake)92 .arg("--log-format")93 .arg("internal-json");94 for arg in extra_args {95 cmd.arg(arg);96 }97 cmd.stdin(Stdio::piped());98 cmd.stdout(Stdio::piped());99 cmd.stderr(Stdio::piped());100 let cmd = cmd.spawn()?;101 let stdout = cmd.stdout.unwrap();102 let stderr = cmd.stderr.unwrap();103 let mut stdin = cmd.stdin.unwrap();104 let error_recorder = ErrorRecorderT::default();105 let err_recorder = error_recorder.clone();106 let stderr_handler = abort_on_drop::ChildTask::from(tokio::spawn(async move {107 let mut handler = NixHandler::default();108 process_child_stderr(stderr, &mut handler, err_recorder).await109 }));110 111 stdin.write_all(REPL_DELIMITER.as_bytes()).await?;112 stdin.write_all(b"\n").await?;113 stdin.flush().await?;114 let mut read = FramedRead::new(stdout, LinesCodec::new());115 let mut full_delimiter = None;116 while let Some(line) = read.next().await {117 let line = line?;118 if line.contains(REPL_DELIMITER) {119 debug!("discovered repl delimiter with added colors: {line}");120 full_delimiter = Some(line.to_owned());121 break;122 }123 }124 let Some(full_delimiter) = full_delimiter else {125 bail!("failed to discover delimiter");126 };127 let mut res = Self {128 full_delimiter,129 error_delimiter: "[[filled after training]]".to_owned(),130 stderr_handler,131 error_recorder,132 read,133 stdin,134 string_wrapping: Default::default(),135 number_wrapping: Default::default(),136137 next_id: 0,138 free_list: vec![],139 };140 res.train().await?;141 Ok(res)142 }143 async fn train(&mut self) -> Result<()> {144 {145 let full_string = self.execute_expression_raw(TRAIN_STRING).await?;146 let string_offset = full_string.find(TRAIN_STRING).expect("contained");147 let string_prefix = &full_string[..string_offset];148 let string_suffix = &full_string[string_offset + TRAIN_STRING.len()..];149 self.string_wrapping = (string_prefix.to_owned(), string_suffix.to_owned());150 }151 {152 let full_number = self.execute_expression_raw(TRAIN_NUMBER).await?;153 let number_offset = full_number.find(TRAIN_NUMBER).expect("contained");154 let number_prefix = &full_number[..number_offset];155 let number_suffix = &full_number[number_offset + TRAIN_NUMBER.len()..];156 self.number_wrapping = (number_prefix.to_owned(), number_suffix.to_owned());157 }158 {159 struct TrainingErrorCollector(Option<oneshot::Sender<String>>);160 impl ErrorRecorder for TrainingErrorCollector {161 fn push_message(&mut self, msg: &str) -> bool {162 if msg.contains(ERROR_DELIMITER) {163 let _ = self164 .0165 .take()166 .expect("error delimiter is sent once")167 .send(msg.to_owned());168 }169 true170 }171 }172 let (tx, rx) = oneshot::channel();173 let _handle = self.record_error(TrainingErrorCollector(Some(tx)));174 self.send_command(ERROR_DELIMITER).await?;175 self.send_command(REPL_DELIMITER).await?;176 self.read_until_delimiter().await?;177 let msg = rx.await?;178 self.error_delimiter = msg;179 }180 Ok(())181 }182 fn record_error(&mut self, v: impl ErrorRecorder + 'static) -> ErrorRecorderHandle {183 {184 let mut recorder = self.error_recorder.lock().unwrap();185 assert!(recorder.is_none(), "recorder is already started");186 *recorder = Some(Box::new(v));187 }188 ErrorRecorderHandle {189 handle: self.error_recorder.clone(),190 }191 }192 async fn send_command(&mut self, cmd: impl AsRef<[u8]>) -> Result<()> {193 self.stdin.write_all(cmd.as_ref()).await?;194 self.stdin.write_all(b"\n").await?;195 Ok(())196 }197 async fn read_until_delimiter(&mut self) -> Result<String> {198 let mut out = String::new();199 while let Some(line) = self.read.next().await {200 let line = line?;201 if line == self.full_delimiter {202 return Ok(out);203 }204 if !out.is_empty() {205 out.push('\n');206 }207 out.push_str(&line);208 }209 bail!("didn't reached delimiter");210 }211 async fn execute_expression_number(&mut self, expr: impl AsRef<[u8]>) -> Result<u64> {212 let num = self.number_wrapping.clone();213 let n = self.execute_expression_wrapping(expr, &num).await?;214 Ok(n.parse::<u64>()?)215 }216 async fn execute_expression_string(&mut self, expr: impl AsRef<[u8]>) -> Result<String> {217 let num = self.string_wrapping.clone();218 let n = self.execute_expression_wrapping(expr, &num).await?;219 let str: String = serde_json::from_str(&n)?;220 Ok(str)221 }222 async fn execute_expression_to_json<V: DeserializeOwned>(223 &mut self,224 expr: impl AsRef<[u8]>,225 ) -> Result<V> {226 let mut fexpr = b"builtins.toJSON (".to_vec();227 fexpr.extend_from_slice(expr.as_ref());228 fexpr.push(b')');229 let v = self.execute_expression_string(fexpr).await?;230 Ok(serde_json::from_str(&v)?)231 }232 async fn execute_expression_wrapping(233 &mut self,234 expr: impl AsRef<[u8]>,235 wrapping: &(String, String),236 ) -> Result<String> {237 let collected = Arc::new(Mutex::new(vec![]));238 let (etx, erx) = oneshot::channel();239 let _collector = self.record_error(ErrorCollector{collected:collected.clone(), delim: self.error_delimiter.clone(), got_delim: Some(etx)});240 let res = self.execute_expression_raw(expr).await?;241 let _ = self.execute_expression_raw(ERROR_DELIMITER).await?;242 let _ = erx.await;243 if res.is_empty() {244 let c = collected.lock().unwrap();245 if c.is_empty() {246 bail!("expected expression, got nothing")247 }248 bail!("{}", c.join("\n"));249 }250 drop(_collector);251 let Some(res) = res.strip_prefix(&wrapping.0) else {252 bail!("invalid type")253 };254 let Some(res) = res.strip_suffix(&wrapping.1) else {255 bail!("invalid type")256 };257 Ok(res.to_owned())258 }259 async fn execute_expression_empty(&mut self, expr: impl AsRef<[u8]>) -> Result<()> {260 let collected = Arc::new(Mutex::new(vec![]));261 let (etx, erx) = oneshot::channel();262 let _collector = self.record_error(ErrorCollector{collected:collected.clone(), delim: self.error_delimiter.clone(), got_delim: Some(etx)});263 let v = self.execute_expression_raw(expr).await?;264 let _ = self.execute_expression_raw(ERROR_DELIMITER).await;265 let _ = erx.await;266267 let c = collected.lock().unwrap();268 if !c.is_empty() {269 bail!("{}", c.join("\n"));270 }271 ensure!(v.is_empty(), "unexpected expression result");272 Ok(())273 }274 async fn execute_expression_raw(&mut self, expr: impl AsRef<[u8]>) -> Result<String> {275 self.send_command(expr).await?;276 277 self.send_command(REPL_DELIMITER).await?;278 self.read_until_delimiter().await279 }280 async fn execute_assign(&mut self, expr: impl AsRef<str>) -> Result<u32> {281 let id = self.allocate_id();282 self.execute_expression_empty(format!("sess_field_{id} = {}", expr.as_ref()))283 .await?;284 Ok(id)285 }286287 288 fn allocate_id(&mut self) -> u32 {289 if let Some(free) = self.free_list.pop() {290 free291 } else {292 let v = self.next_id;293 self.next_id += 1;294 v295 }296 }297 298 299 300 301 302 303 304}305306#[derive(Clone)]307pub struct NixSession(Arc<tokio::sync::Mutex<PooledConnection<NixSessionPoolInner>>>);308309#[derive(Clone, Debug)]310enum Index {311 String(String),312 313}314pub struct Field {315 full_path: Vec<Index>,316 session: NixSession,317 value: Option<u32>,318}319impl Field {320 fn root(session: NixSession) -> Self {321 Self {322 full_path: vec![],323 session,324 value: None,325 }326 }327 pub async fn field(session: NixSession, field: &str) -> Result<Self> {328 Self::root(session).get_field_deep([field]).await329 }330 pub async fn get_field(&self, name: &str) -> Result<Self> {331 self.get_field_deep([name]).await332 }333 pub async fn get_field_deep<'a>(334 &self,335 name: impl IntoIterator<Item = &'a str>,336 ) -> Result<Self> {337 let mut iter = name.into_iter();338339 let mut full_path = self.full_path.clone();340 let mut query = if let Some(id) = self.value {341 format!("sess_field_{id}")342 } else {343 let first = iter.next().expect("name not empty");344 ensure!(345 !(first.contains('.') | first.contains(' ')),346 "bad name for root query: {first}"347 );348 full_path.push(Index::String(first.to_string()));349 first.to_string()350 };351 for v in iter {352 full_path.push(Index::String(v.to_string()));353 354 let escaped = nixlike::serialize(v)?;355 let escaped = escaped.trim();356 query.push('.');357 query.push_str(escaped);358 }359360 let vid = self361 .session362 .0363 .lock()364 .await365 .execute_assign(&query)366 .await367 .with_context(|| format!("full path: {:?}", full_path))?;368 Ok(Self {369 full_path,370 session: self.session.clone(),371 value: Some(vid),372 })373 }374 pub async fn as_json<V: DeserializeOwned>(&self) -> Result<V> {375 let id = self.value.expect("can't serialize root field");376 self.session377 .0378 .lock()379 .await380 .execute_expression_to_json(&format!("sess_field_{id}"))381 .await382 .with_context(|| format!("full path: {:?}", self.full_path))383 }384 pub async fn list_fields(&self) -> Result<Vec<String>> {385 let id = self.value.expect("can't list root fields");386 self.session387 .0388 .lock()389 .await390 .execute_expression_to_json(&format!("builtins.attrNames sess_field_{id}"))391 .await392 .with_context(|| format!("full path: {:?}", self.full_path))393 }394}395impl Drop for Field {396 fn drop(&mut self) {397 if let Some(id) = self.value {398 if let Ok(mut lock) = self.session.0.try_lock() {399 lock.free_list.push(id)400 }401 402 }403 }404}405struct NixSessionPoolInner {406 flake: OsString,407 nix_args: Vec<OsString>,408}409410#[derive(Debug)]411pub struct NixPoolError(anyhow::Error);412impl From<anyhow::Error> for NixPoolError {413 fn from(value: anyhow::Error) -> Self {414 Self(value)415 }416}417impl std::error::Error for NixPoolError {}418impl std::fmt::Display for NixPoolError {419 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {420 self.0.fmt(f)421 }422}423impl r2d2::ManageConnection for NixSessionPoolInner {424 type Connection = NixSessionInner;425 type Error = NixPoolError;426 fn connect(&self) -> std::result::Result<Self::Connection, Self::Error> {427 let _v = TOKIO_RUNTIME428 .get()429 .expect("missed tokio runtime init!")430 .enter();431 Ok(futures::executor::block_on(NixSessionInner::new(432 self.flake.as_os_str(),433 self.nix_args.iter().map(OsString::as_os_str),434 ))?)435 }436437 fn is_valid(&self, conn: &mut Self::Connection) -> std::result::Result<(), Self::Error> {438 let _v = TOKIO_RUNTIME439 .get()440 .expect("missed tokio runtime init!")441 .enter();442 let res = futures::executor::block_on(conn.execute_expression_number("2 + 2"))?;443 if res != 4 {444 return Err(anyhow!("sanity check failed").into());445 };446 Ok(())447 }448449 fn has_broken(&self, _conn: &mut Self::Connection) -> bool {450 false451 }452}453pub struct NixSessionPool(Pool<NixSessionPoolInner>);454impl NixSessionPool {455 pub async fn new(flake: OsString, nix_args: Vec<OsString>) -> Result<Self> {456 let inner = tokio::task::block_in_place(|| {457 r2d2::Builder::<NixSessionPoolInner>::new()458 .min_idle(Some(0))459 .build(NixSessionPoolInner { flake, nix_args })460 })?;461 Ok(Self(inner))462 }463 pub async fn get(&self) -> Result<NixSession> {464 let v = tokio::task::block_in_place(|| self.0.get())?;465 Ok(NixSession(Arc::new(tokio::sync::Mutex::new(v))))466 }467}468469pub static TOKIO_RUNTIME: OnceLock<tokio::runtime::Handle> = OnceLock::new();