git.delta.rocks / jrsonnet / refs/commits / 19e4a93958ae

difftreelog

source

cmds/fleet/src/better_nix_eval.rs13.4 KiBsourcehistory
1use std::ffi::{OsStr, OsString};2use std::process::Stdio;3use std::sync::{Arc, Mutex, OnceLock};45use abort_on_drop::ChildTask;6use anyhow::{anyhow, bail, ensure, Context, Result};7use futures::StreamExt;8use r2d2::{Pool, PooledConnection};9use serde::de::DeserializeOwned;10use serde::Deserialize;11use tokio::io::AsyncWriteExt;12use tokio::process::{ChildStdin, ChildStdout, Command};13use tokio::sync::oneshot;14use tokio_util::codec::{FramedRead, LinesCodec};15use tracing::debug;1617use crate::command::{process_child_stderr, ErrorRecorder, ErrorRecorderT, NixHandler};1819const REPL_DELIMITER: &str = "\"FLEET_MAGIC_REPL_DELIMITER\"";20// To synchronize stderr and stdout. It works, yet I hate this.21// There is no other way to catch errors, because they are coming from different streams, and they are not synchronized in tokio.22const ERROR_DELIMITER: &str = "FLEET_MAGIC_ERROR_DELIMITER";2324pub struct NixSessionInner {25	full_delimiter: String,26	#[allow(dead_code)]27	stderr_handler: ChildTask<Result<()>>,28	error_recorder: ErrorRecorderT,29	read: FramedRead<ChildStdout, LinesCodec>,30	stdin: ChildStdin,31	string_wrapping: (String, String),32	number_wrapping: (String, String),33	error_delimiter: String,3435	next_id: u32,36	free_list: Vec<u32>,37}38const TRAIN_STRING: &str = "\"TRAIN_STRING\"";39const TRAIN_NUMBER: &str = "13141516";4041struct ErrorRecorderHandle {42	handle: ErrorRecorderT,43}44impl ErrorRecorderHandle {}45impl Drop for ErrorRecorderHandle {46	fn drop(&mut self) {47		let mut recorded = self.handle.lock().unwrap();48		assert!(recorded.is_some(), "exclusive");49		*recorded = None;50	}51}5253struct ErrorCollector {54	collected: Arc<Mutex<Vec<String>>>,55	delim: String,56	got_delim: Option<oneshot::Sender<()>>,57}58impl ErrorRecorder for ErrorCollector {59	fn push_message(&mut self, msg: &str) -> bool {60		if msg == self.delim {61			let _ = self.got_delim62				.take()63				.expect("error delim is only expected once")64				.send(());65			 return true;66		}67		let Some(msg) = msg.strip_prefix("@nix ") else {68			return false;69		};70		#[derive(Deserialize)]71		struct ErrorAction {72			action: String,73			level: u32,74			msg: String,75		}76		let Ok(act) = serde_json::from_str::<ErrorAction>(msg) else {77			return false;78		};79		if act.action != "msg" || act.level != 0 {80			return false;81		}82		self.collected.lock().unwrap().push(act.msg);83		true84	}85}8687impl NixSessionInner {88	async fn new(flake: &OsStr, extra_args: impl IntoIterator<Item = &OsStr>) -> Result<Self> {89		let mut cmd = Command::new("nix");90		cmd.arg("repl")91			.arg(flake)92			.arg("--log-format")93			.arg("internal-json");94		for arg in extra_args {95			cmd.arg(arg);96		}97		cmd.stdin(Stdio::piped());98		cmd.stdout(Stdio::piped());99		cmd.stderr(Stdio::piped());100		let cmd = cmd.spawn()?;101		let stdout = cmd.stdout.unwrap();102		let stderr = cmd.stderr.unwrap();103		let mut stdin = cmd.stdin.unwrap();104		let error_recorder = ErrorRecorderT::default();105		let err_recorder = error_recorder.clone();106		let stderr_handler = abort_on_drop::ChildTask::from(tokio::spawn(async move {107			let mut handler = NixHandler::default();108			process_child_stderr(stderr, &mut handler, err_recorder).await109		}));110		// Standard repl hello doesn't work with internal-json logger111		stdin.write_all(REPL_DELIMITER.as_bytes()).await?;112		stdin.write_all(b"\n").await?;113		stdin.flush().await?;114		let mut read = FramedRead::new(stdout, LinesCodec::new());115		let mut full_delimiter = None;116		while let Some(line) = read.next().await {117			let line = line?;118			if line.contains(REPL_DELIMITER) {119				debug!("discovered repl delimiter with added colors: {line}");120				full_delimiter = Some(line.to_owned());121				break;122			}123		}124		let Some(full_delimiter) = full_delimiter else {125			bail!("failed to discover delimiter");126		};127		let mut res = Self {128			full_delimiter,129			error_delimiter: "[[filled after training]]".to_owned(),130			stderr_handler,131			error_recorder,132			read,133			stdin,134			string_wrapping: Default::default(),135			number_wrapping: Default::default(),136137			next_id: 0,138			free_list: vec![],139		};140		res.train().await?;141		Ok(res)142	}143	async fn train(&mut self) -> Result<()> {144		{145			let full_string = self.execute_expression_raw(TRAIN_STRING).await?;146			let string_offset = full_string.find(TRAIN_STRING).expect("contained");147			let string_prefix = &full_string[..string_offset];148			let string_suffix = &full_string[string_offset + TRAIN_STRING.len()..];149			self.string_wrapping = (string_prefix.to_owned(), string_suffix.to_owned());150		}151		{152			let full_number = self.execute_expression_raw(TRAIN_NUMBER).await?;153			let number_offset = full_number.find(TRAIN_NUMBER).expect("contained");154			let number_prefix = &full_number[..number_offset];155			let number_suffix = &full_number[number_offset + TRAIN_NUMBER.len()..];156			self.number_wrapping = (number_prefix.to_owned(), number_suffix.to_owned());157		}158		{159			struct TrainingErrorCollector(Option<oneshot::Sender<String>>);160			impl ErrorRecorder for TrainingErrorCollector {161				fn push_message(&mut self, msg: &str) -> bool {162					if msg.contains(ERROR_DELIMITER) {163						let _ = self164							.0165							.take()166							.expect("error delimiter is sent once")167							.send(msg.to_owned());168					}169					true170				}171			}172			let (tx, rx) = oneshot::channel();173			let _handle = self.record_error(TrainingErrorCollector(Some(tx)));174			self.send_command(ERROR_DELIMITER).await?;175			self.send_command(REPL_DELIMITER).await?;176			self.read_until_delimiter().await?;177			let msg = rx.await?;178			self.error_delimiter = msg;179		}180		Ok(())181	}182	fn record_error(&mut self, v: impl ErrorRecorder + 'static) -> ErrorRecorderHandle {183		{184			let mut recorder = self.error_recorder.lock().unwrap();185			assert!(recorder.is_none(), "recorder is already started");186			*recorder = Some(Box::new(v));187		}188		ErrorRecorderHandle {189			handle: self.error_recorder.clone(),190		}191	}192	async fn send_command(&mut self, cmd: impl AsRef<[u8]>) -> Result<()> {193		self.stdin.write_all(cmd.as_ref()).await?;194		self.stdin.write_all(b"\n").await?;195		Ok(())196	}197	async fn read_until_delimiter(&mut self) -> Result<String> {198		let mut out = String::new();199		while let Some(line) = self.read.next().await {200			let line = line?;201			if line == self.full_delimiter {202				return Ok(out);203			}204			if !out.is_empty() {205				out.push('\n');206			}207			out.push_str(&line);208		}209		bail!("didn't reached delimiter");210	}211	async fn execute_expression_number(&mut self, expr: impl AsRef<[u8]>) -> Result<u64> {212		let num = self.number_wrapping.clone();213		let n = self.execute_expression_wrapping(expr, &num).await?;214		Ok(n.parse::<u64>()?)215	}216	async fn execute_expression_string(&mut self, expr: impl AsRef<[u8]>) -> Result<String> {217		let num = self.string_wrapping.clone();218		let n = self.execute_expression_wrapping(expr, &num).await?;219		let str: String = serde_json::from_str(&n)?;220		Ok(str)221	}222	async fn execute_expression_to_json<V: DeserializeOwned>(223		&mut self,224		expr: impl AsRef<[u8]>,225	) -> Result<V> {226		let mut fexpr = b"builtins.toJSON (".to_vec();227		fexpr.extend_from_slice(expr.as_ref());228		fexpr.push(b')');229		let v = self.execute_expression_string(fexpr).await?;230		Ok(serde_json::from_str(&v)?)231	}232	async fn execute_expression_wrapping(233		&mut self,234		expr: impl AsRef<[u8]>,235		wrapping: &(String, String),236	) -> Result<String> {237		let collected = Arc::new(Mutex::new(vec![]));238		let (etx, erx) = oneshot::channel();239		let _collector = self.record_error(ErrorCollector{collected:collected.clone(), delim: self.error_delimiter.clone(), got_delim: Some(etx)});240		let res = self.execute_expression_raw(expr).await?;241		let _ = self.execute_expression_raw(ERROR_DELIMITER).await?;242		let _ = erx.await;243		if res.is_empty() {244			let c = collected.lock().unwrap();245			if c.is_empty() {246				bail!("expected expression, got nothing")247			}248			bail!("{}", c.join("\n"));249		}250		drop(_collector);251		let Some(res) = res.strip_prefix(&wrapping.0) else {252			bail!("invalid type")253		};254		let Some(res) = res.strip_suffix(&wrapping.1) else {255			bail!("invalid type")256		};257		Ok(res.to_owned())258	}259	async fn execute_expression_empty(&mut self, expr: impl AsRef<[u8]>) -> Result<()> {260		let collected = Arc::new(Mutex::new(vec![]));261		let (etx, erx) = oneshot::channel();262		let _collector = self.record_error(ErrorCollector{collected:collected.clone(), delim: self.error_delimiter.clone(), got_delim: Some(etx)});263		let v = self.execute_expression_raw(expr).await?;264		let _ = self.execute_expression_raw(ERROR_DELIMITER).await;265		let _ = erx.await;266267		let c = collected.lock().unwrap();268		if !c.is_empty() {269			bail!("{}", c.join("\n"));270		}271		ensure!(v.is_empty(), "unexpected expression result");272		Ok(())273	}274	async fn execute_expression_raw(&mut self, expr: impl AsRef<[u8]>) -> Result<String> {275		self.send_command(expr).await?;276		// It will be echoed277		self.send_command(REPL_DELIMITER).await?;278		self.read_until_delimiter().await279	}280	async fn execute_assign(&mut self, expr: impl AsRef<str>) -> Result<u32> {281		let id = self.allocate_id();282		self.execute_expression_empty(format!("sess_field_{id} = {}", expr.as_ref()))283			.await?;284		Ok(id)285	}286287	/// Id should be immediately used288	fn allocate_id(&mut self) -> u32 {289		if let Some(free) = self.free_list.pop() {290			free291		} else {292			let v = self.next_id;293			self.next_id += 1;294			v295		}296	}297	// Nix has no way to deallocate variable, yet GC will correct everything not reachable.298	// async fn free_id(&mut self, id: u32) -> Result<()> {299	// 	self.execute_expression_empty(format!("sess_field_{id} = null"))300	// 		.await?;301	// 	self.free_list.push(id);302	// 	Ok(())303	// }304}305306#[derive(Clone)]307pub struct NixSession(Arc<tokio::sync::Mutex<PooledConnection<NixSessionPoolInner>>>);308309#[derive(Clone, Debug)]310enum Index {311	String(String),312	// Idx(u32),313}314pub struct Field {315	full_path: Vec<Index>,316	session: NixSession,317	value: Option<u32>,318}319impl Field {320	fn root(session: NixSession) -> Self {321		Self {322			full_path: vec![],323			session,324			value: None,325		}326	}327	pub async fn field(session: NixSession, field: &str) -> Result<Self> {328		Self::root(session).get_field_deep([field]).await329	}330	pub async fn get_field(&self, name: &str) -> Result<Self> {331		self.get_field_deep([name]).await332	}333	pub async fn get_field_deep<'a>(334		&self,335		name: impl IntoIterator<Item = &'a str>,336	) -> Result<Self> {337		let mut iter = name.into_iter();338339		let mut full_path = self.full_path.clone();340		let mut query = if let Some(id) = self.value {341			format!("sess_field_{id}")342		} else {343			let first = iter.next().expect("name not empty");344			ensure!(345				!(first.contains('.') | first.contains(' ')),346				"bad name for root query: {first}"347			);348			full_path.push(Index::String(first.to_string()));349			first.to_string()350		};351		for v in iter {352			full_path.push(Index::String(v.to_string()));353			// Escape354			let escaped = nixlike::serialize(v)?;355			let escaped = escaped.trim();356			query.push('.');357			query.push_str(escaped);358		}359360		let vid = self361			.session362			.0363			.lock()364			.await365			.execute_assign(&query)366			.await367			.with_context(|| format!("full path: {:?}", full_path))?;368		Ok(Self {369			full_path,370			session: self.session.clone(),371			value: Some(vid),372		})373	}374	pub async fn as_json<V: DeserializeOwned>(&self) -> Result<V> {375		let id = self.value.expect("can't serialize root field");376		self.session377			.0378			.lock()379			.await380			.execute_expression_to_json(&format!("sess_field_{id}"))381			.await382			.with_context(|| format!("full path: {:?}", self.full_path))383	}384	pub async fn list_fields(&self) -> Result<Vec<String>> {385		let id = self.value.expect("can't list root fields");386		self.session387			.0388			.lock()389			.await390			.execute_expression_to_json(&format!("builtins.attrNames sess_field_{id}"))391			.await392			.with_context(|| format!("full path: {:?}", self.full_path))393	}394}395impl Drop for Field {396	fn drop(&mut self) {397		if let Some(id) = self.value {398			if let Ok(mut lock) = self.session.0.try_lock() {399				lock.free_list.push(id)400			}401			// Leaked402		}403	}404}405struct NixSessionPoolInner {406	flake: OsString,407	nix_args: Vec<OsString>,408}409410#[derive(Debug)]411pub struct NixPoolError(anyhow::Error);412impl From<anyhow::Error> for NixPoolError {413	fn from(value: anyhow::Error) -> Self {414		Self(value)415	}416}417impl std::error::Error for NixPoolError {}418impl std::fmt::Display for NixPoolError {419	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {420		self.0.fmt(f)421	}422}423impl r2d2::ManageConnection for NixSessionPoolInner {424	type Connection = NixSessionInner;425	type Error = NixPoolError;426	fn connect(&self) -> std::result::Result<Self::Connection, Self::Error> {427		let _v = TOKIO_RUNTIME428			.get()429			.expect("missed tokio runtime init!")430			.enter();431		Ok(futures::executor::block_on(NixSessionInner::new(432			self.flake.as_os_str(),433			self.nix_args.iter().map(OsString::as_os_str),434		))?)435	}436437	fn is_valid(&self, conn: &mut Self::Connection) -> std::result::Result<(), Self::Error> {438		let _v = TOKIO_RUNTIME439			.get()440			.expect("missed tokio runtime init!")441			.enter();442		let res = futures::executor::block_on(conn.execute_expression_number("2 + 2"))?;443		if res != 4 {444			return Err(anyhow!("sanity check failed").into());445		};446		Ok(())447	}448449	fn has_broken(&self, _conn: &mut Self::Connection) -> bool {450		false451	}452}453pub struct NixSessionPool(Pool<NixSessionPoolInner>);454impl NixSessionPool {455	pub async fn new(flake: OsString, nix_args: Vec<OsString>) -> Result<Self> {456		let inner = tokio::task::block_in_place(|| {457			r2d2::Builder::<NixSessionPoolInner>::new()458				.min_idle(Some(0))459				.build(NixSessionPoolInner { flake, nix_args })460		})?;461		Ok(Self(inner))462	}463	pub async fn get(&self) -> Result<NixSession> {464		let v = tokio::task::block_in_place(|| self.0.get())?;465		Ok(NixSession(Arc::new(tokio::sync::Mutex::new(v))))466	}467}468469pub static TOKIO_RUNTIME: OnceLock<tokio::runtime::Handle> = OnceLock::new();