git.delta.rocks / jrsonnet / refs/commits / e85b4da8a439

difftreelog

source

cmds/fleet/src/better_nix_eval.rs14.1 KiBsourcehistory
1use std::ffi::{OsStr, OsString};2use std::fmt::Display;3use std::process::Stdio;4use std::sync::{Arc, OnceLock};56use anyhow::{anyhow, bail, ensure, Context, Result};7use futures::StreamExt;8use itertools::Itertools;9use r2d2::{Pool, PooledConnection};10use serde::de::DeserializeOwned;11use serde::Deserialize;12use tokio::io::AsyncWriteExt;13use tokio::process::{ChildStderr, ChildStdin, ChildStdout, Command};14use tokio::select;15use tokio::sync::{mpsc, oneshot};16use tokio_util::codec::{FramedRead, LinesCodec};17use tracing::{debug, error, warn};1819use crate::command::{ClonableHandler, Handler, NixHandler, NoopHandler};2021const REPL_DELIMITER: &str = "\"FLEET_MAGIC_REPL_DELIMITER\"";2223pub struct NixSessionInner {24	full_delimiter: String,25	nix_handler: ClonableHandler<NixHandler>,26	out: OutputHandler,27	stdin: ChildStdin,28	string_wrapping: (String, String),29	number_wrapping: (String, String),3031	next_id: u32,32	free_list: Vec<u32>,33}34const TRAIN_STRING: &str = "\"TRAIN_STRING\"";35const TRAIN_NUMBER: &str = "13141516";3637#[must_use]38struct ErrorCollector<'i, H> {39	collected: Vec<String>,40	inner: &'i mut H,41}42impl<'i, H> ErrorCollector<'i, H> {43	fn new(inner: &'i mut H) -> Self {44		Self {45			collected: vec![],46			inner,47		}48	}49}50impl<H> ErrorCollector<'_, H> {51	fn handle_line_inner(&mut self, msg: &str) -> bool {52		let Some(msg) = msg.strip_prefix("@nix ") else {53			return false;54		};55		#[derive(Deserialize)]56		struct ErrorAction {57			action: String,58			level: u32,59			msg: String,60		}61		let Ok(act) = serde_json::from_str::<ErrorAction>(msg) else {62			return false;63		};64		if act.action != "msg" || act.level != 0 {65			return false;66		}67		self.collected.push(act.msg);68		true69	}70	fn finish(self) -> Result<()> {71		// fn dedent(s: String) -> String {72		// 	s.split('\n').filter(|s| !s.trim().is_empty()).map(|v| v.)73		// }74		if !self.collected.is_empty() {75			bail!("{}", self.collected.iter().map(|v| {76				if let Some(f) = v.strip_prefix("\u{1b}[31;1merror:\u{1b}[0m ") {77					let v = unindent::unindent(f.trim_start());78					v.trim().to_owned()79				} else {80					v.to_owned()81				}82			}).join("\n"));83		}84		Ok(())85	}86	fn flush(self) {87		for line in self.collected {88			warn!("{line}");89		}90	}91}92impl<H: Handler> Handler for ErrorCollector<'_, H> {93	fn handle_line(&mut self, e: &str) {94		if self.handle_line_inner(e) {95			return;96		}97		self.inner.handle_line(e)98	}99}100101enum OutputLine {102	Out(String),103	Err(String),104}105struct OutputHandler {106	rx: mpsc::Receiver<OutputLine>,107	_cancel_handle: oneshot::Receiver<()>,108}109impl OutputHandler {110	fn new(out: ChildStdout, err: ChildStderr) -> Self {111		let mut out = FramedRead::new(out, LinesCodec::new());112		let mut err = FramedRead::new(err, LinesCodec::new());113		let (tx, rx) = mpsc::channel(20);114		let (mut cancelled, _cancel_handle) = oneshot::channel();115		tokio::spawn(async move {116			loop {117				select! {118					// We should receive errors earlier than synchronization119					biased;120					e = err.next() => {121						let Some(Ok(e)) = e else {122							if e.is_some() {123								error!("bad repl stderr: {e:?}");124							}125							continue;126						};127						let _ = tx.send(OutputLine::Err(e)).await;128					}129					o = out.next() => {130						let Some(Ok(o)) = o else {131							if o.is_some() {132								error!("bad repl stdout: {o:?}");133							}134							continue;135						};136						let _ = tx.send(OutputLine::Out(o)).await;137					}138					// Reader doesn't care about stdout, as this is cancelled.139					// Error still might be useful, to process leftover span closures?140					_ = cancelled.closed() => {141						break;142					}143				}144			}145		});146		Self { rx, _cancel_handle }147	}148	async fn next(&mut self) -> Option<OutputLine> {149		self.rx.recv().await150	}151}152153impl NixSessionInner {154	async fn new(flake: &OsStr, extra_args: impl IntoIterator<Item = &OsStr>) -> Result<Self> {155		let mut cmd = Command::new("nix");156		cmd.arg("repl")157			.arg(flake)158			.arg("--log-format")159			.arg("internal-json");160		for arg in extra_args {161			cmd.arg(arg);162		}163		cmd.stdin(Stdio::piped());164		cmd.stdout(Stdio::piped());165		cmd.stderr(Stdio::piped());166		let cmd = cmd.spawn()?;167		let stdout = cmd.stdout.unwrap();168		let stderr = cmd.stderr.unwrap();169		let mut out = OutputHandler::new(stdout, stderr);170		let mut stdin = cmd.stdin.unwrap();171		// Standard repl hello doesn't work with internal-json logger172		stdin.write_all(REPL_DELIMITER.as_bytes()).await?;173		stdin.write_all(b"\n").await?;174		stdin.flush().await?;175		let nix_handler = NixHandler::default();176		let mut full_delimiter = None;177		while let Some(line) = out.next().await {178			let line = match line {179				OutputLine::Out(o) => o,180				OutputLine::Err(_e) => {181					// Handle startup errors, but skip repl hello?182					//nix_handler.handle_line(&e);183					continue;184				}185			};186			if line.contains(REPL_DELIMITER) {187				debug!("discovered repl delimiter with added colors: {line}");188				full_delimiter = Some(line.to_owned());189				break;190			}191		}192		let Some(full_delimiter) = full_delimiter else {193			bail!("failed to discover delimiter");194		};195		let mut res = Self {196			full_delimiter,197			nix_handler: ClonableHandler::new(nix_handler),198			out,199			stdin,200			string_wrapping: Default::default(),201			number_wrapping: Default::default(),202203			next_id: 0,204			free_list: vec![],205		};206		res.train().await?;207		Ok(res)208	}209	async fn train(&mut self) -> Result<()> {210		{211			let full_string = self212				.execute_expression_raw(TRAIN_STRING, &mut NoopHandler)213				.await?;214			let string_offset = full_string.find(TRAIN_STRING).expect("contained");215			let string_prefix = &full_string[..string_offset];216			let string_suffix = &full_string[string_offset + TRAIN_STRING.len()..];217			self.string_wrapping = (string_prefix.to_owned(), string_suffix.to_owned());218		}219		{220			let full_number = self221				.execute_expression_raw(TRAIN_NUMBER, &mut NoopHandler)222				.await?;223			let number_offset = full_number.find(TRAIN_NUMBER).expect("contained");224			let number_prefix = &full_number[..number_offset];225			let number_suffix = &full_number[number_offset + TRAIN_NUMBER.len()..];226			self.number_wrapping = (number_prefix.to_owned(), number_suffix.to_owned());227		}228		Ok(())229	}230	async fn send_command(&mut self, cmd: impl AsRef<[u8]>) -> Result<()> {231		self.stdin.write_all(cmd.as_ref()).await?;232		self.stdin.write_all(b"\n").await?;233		Ok(())234	}235	async fn read_until_delimiter(&mut self, err_handler: &mut dyn Handler) -> Result<String> {236		let mut out = String::new();237		while let Some(line) = self.out.next().await {238			let line = match line {239				OutputLine::Out(out) => out,240				OutputLine::Err(err) => {241					err_handler.handle_line(&err);242					continue;243				}244			};245			if line == self.full_delimiter {246				return Ok(out);247			}248			if !out.is_empty() {249				out.push('\n');250			}251			out.push_str(&line);252		}253		bail!("didn't reached delimiter");254	}255	async fn execute_expression_number(&mut self, expr: impl AsRef<[u8]>) -> Result<u64> {256		let num = self.number_wrapping.clone();257		let n = self.execute_expression_wrapping(expr, &num).await?;258		Ok(n.parse::<u64>()?)259	}260	async fn execute_expression_string(&mut self, expr: impl AsRef<[u8]>) -> Result<String> {261		let num = self.string_wrapping.clone();262		let n = self.execute_expression_wrapping(expr, &num).await?;263		let str: String = serde_json::from_str(&n)?;264		Ok(str)265	}266	async fn execute_expression_to_json<V: DeserializeOwned>(267		&mut self,268		expr: impl AsRef<[u8]>,269	) -> Result<V> {270		let mut fexpr = b"builtins.toJSON (".to_vec();271		fexpr.extend_from_slice(expr.as_ref());272		fexpr.push(b')');273		let v = self.execute_expression_string(fexpr).await?;274		Ok(serde_json::from_str(&v)?)275	}276	async fn execute_expression_wrapping(277		&mut self,278		expr: impl AsRef<[u8]>,279		wrapping: &(String, String),280	) -> Result<String> {281		let mut nix_handler = self.nix_handler.clone();282		let mut collected = ErrorCollector::new(&mut nix_handler);283		let res = self.execute_expression_raw(expr, &mut collected).await?;284		if res.is_empty() {285			collected.finish()?;286			bail!("expected expression, got nothing")287		} else {288			collected.flush()289		};290		let Some(res) = res.strip_prefix(&wrapping.0) else {291			bail!("invalid type")292		};293		let Some(res) = res.strip_suffix(&wrapping.1) else {294			bail!("invalid type")295		};296		Ok(res.to_owned())297	}298	async fn execute_expression_empty(&mut self, expr: impl AsRef<[u8]>) -> Result<()> {299		let mut nix_handler = self.nix_handler.clone();300		let mut collected = ErrorCollector::new(&mut nix_handler);301		let v = self.execute_expression_raw(expr, &mut collected).await?;302		collected.finish()?;303		ensure!(v.is_empty(), "unexpected expression result");304		Ok(())305	}306	async fn execute_expression_raw(307		&mut self,308		expr: impl AsRef<[u8]>,309		err_handler: &mut dyn Handler,310	) -> Result<String> {311		self.send_command(expr).await?;312		// It will be echoed313		self.send_command(REPL_DELIMITER).await?;314		self.read_until_delimiter(err_handler).await315	}316	async fn execute_assign(&mut self, expr: impl AsRef<str>) -> Result<u32> {317		let id = self.allocate_id();318		self.execute_expression_empty(format!("sess_field_{id} = {}", expr.as_ref()))319			.await?;320		Ok(id)321	}322323	/// Id should be immediately used324	fn allocate_id(&mut self) -> u32 {325		if let Some(free) = self.free_list.pop() {326			free327		} else {328			let v = self.next_id;329			self.next_id += 1;330			v331		}332	}333	// Nix has no way to deallocate variable, yet GC will correct everything not reachable.334	// async fn free_id(&mut self, id: u32) -> Result<()> {335	// 	self.execute_expression_empty(format!("sess_field_{id} = null"))336	// 		.await?;337	// 	self.free_list.push(id);338	// 	Ok(())339	// }340}341342#[derive(Clone)]343pub struct NixSession(Arc<tokio::sync::Mutex<PooledConnection<NixSessionPoolInner>>>);344345#[derive(Clone)]346enum Index {347	String(String),348	// Idx(u32),349}350impl Display for Index {351	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {352		match self {353			Index::String(k) => {354				let v = nixlike::format_identifier(k.as_str());355				write!(f, ".{v}")356			}357		}358	}359}360struct PathDisplay<'i>(&'i [Index]);361impl Display for PathDisplay<'_> {362	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {363		write!(f, "flake")?;364		for i in self.0 {365			write!(f, "{i}")?;366		}367		Ok(())368	}369}370pub struct Field {371	full_path: Vec<Index>,372	session: NixSession,373	value: Option<u32>,374}375impl Field {376	fn root(session: NixSession) -> Self {377		Self {378			full_path: vec![],379			session,380			value: None,381		}382	}383	pub async fn field(session: NixSession, field: &str) -> Result<Self> {384		Self::root(session).get_field_deep([field]).await385	}386	pub async fn get_json_deep<'a, V: DeserializeOwned>(387		&self,388		name: impl IntoIterator<Item = &'a str>,389	) -> Result<V> {390		let field = self.get_field_deep(name).await?;391		field.as_json().await392	}393	pub async fn get_field(&self, name: &str) -> Result<Self> {394		self.get_field_deep([name]).await395	}396	pub async fn get_field_deep<'a>(397		&self,398		name: impl IntoIterator<Item = &'a str>,399	) -> Result<Self> {400		let mut iter = name.into_iter();401402		let mut full_path = self.full_path.clone();403		let mut query = if let Some(id) = self.value {404			format!("sess_field_{id}")405		} else {406			let first = iter.next().expect("name not empty");407			ensure!(408				!(first.contains('.') | first.contains(' ')),409				"bad name for root query: {first}"410			);411			full_path.push(Index::String(first.to_string()));412			first.to_string()413		};414		for v in iter {415			full_path.push(Index::String(v.to_string()));416			// Escape417			let escaped = nixlike::serialize(v)?;418			let escaped = escaped.trim();419			query.push('.');420			query.push_str(escaped);421		}422423		let vid = self424			.session425			.0426			.lock()427			.await428			.execute_assign(&query)429			.await430			.with_context(|| format!("full path: {}", PathDisplay(&full_path)))?;431		Ok(Self {432			full_path,433			session: self.session.clone(),434			value: Some(vid),435		})436	}437	pub async fn as_json<V: DeserializeOwned>(&self) -> Result<V> {438		let id = self.value.expect("can't serialize root field");439		self.session440			.0441			.lock()442			.await443			.execute_expression_to_json(&format!("sess_field_{id}"))444			.await445			.with_context(|| format!("full path: {}", PathDisplay(&self.full_path)))446	}447	pub async fn list_fields(&self) -> Result<Vec<String>> {448		let id = self.value.expect("can't list root fields");449		self.session450			.0451			.lock()452			.await453			.execute_expression_to_json(&format!("builtins.attrNames sess_field_{id}"))454			.await455			.with_context(|| format!("full path: {}", PathDisplay(&self.full_path)))456	}457}458impl Drop for Field {459	fn drop(&mut self) {460		if let Some(id) = self.value {461			if let Ok(mut lock) = self.session.0.try_lock() {462				lock.free_list.push(id)463			}464			// Leaked465		}466	}467}468struct NixSessionPoolInner {469	flake: OsString,470	nix_args: Vec<OsString>,471}472473#[derive(Debug)]474pub struct NixPoolError(anyhow::Error);475impl From<anyhow::Error> for NixPoolError {476	fn from(value: anyhow::Error) -> Self {477		Self(value)478	}479}480impl std::error::Error for NixPoolError {}481impl std::fmt::Display for NixPoolError {482	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {483		self.0.fmt(f)484	}485}486impl r2d2::ManageConnection for NixSessionPoolInner {487	type Connection = NixSessionInner;488	type Error = NixPoolError;489	fn connect(&self) -> std::result::Result<Self::Connection, Self::Error> {490		let _v = TOKIO_RUNTIME491			.get()492			.expect("missed tokio runtime init!")493			.enter();494		Ok(futures::executor::block_on(NixSessionInner::new(495			self.flake.as_os_str(),496			self.nix_args.iter().map(OsString::as_os_str),497		))?)498	}499500	fn is_valid(&self, conn: &mut Self::Connection) -> std::result::Result<(), Self::Error> {501		let _v = TOKIO_RUNTIME502			.get()503			.expect("missed tokio runtime init!")504			.enter();505		let res = futures::executor::block_on(conn.execute_expression_number("2 + 2"))?;506		if res != 4 {507			return Err(anyhow!("sanity check failed").into());508		};509		Ok(())510	}511512	fn has_broken(&self, _conn: &mut Self::Connection) -> bool {513		false514	}515}516pub struct NixSessionPool(Pool<NixSessionPoolInner>);517impl NixSessionPool {518	pub async fn new(flake: OsString, nix_args: Vec<OsString>) -> Result<Self> {519		let inner = tokio::task::block_in_place(|| {520			r2d2::Builder::<NixSessionPoolInner>::new()521				.min_idle(Some(0))522				.build(NixSessionPoolInner { flake, nix_args })523		})?;524		Ok(Self(inner))525	}526	pub async fn get(&self) -> Result<NixSession> {527		let v = tokio::task::block_in_place(|| self.0.get())?;528		Ok(NixSession(Arc::new(tokio::sync::Mutex::new(v))))529	}530}531532pub static TOKIO_RUNTIME: OnceLock<tokio::runtime::Handle> = OnceLock::new();