git.delta.rocks / remowt / refs/commits / 5cb6be498b69

difftreelog

feat properly cancel agent task

wplzsrwuYaroslav Bolyukin2024-08-12parent: #9ee216a.patch.diff
in: trunk

3 files changed

modifiedCargo.lockdiffbeforeafterboth
before · Cargo.lock
176 packageslockfile v3
modifiedcmds/remowt-agent/src/main.rsdiffbeforeafterboth
--- a/cmds/remowt-agent/src/main.rs
+++ b/cmds/remowt-agent/src/main.rs
@@ -2,14 +2,14 @@
 use std::collections::{BTreeMap, HashMap};
 use std::io::{stdout, Write};
 use std::marker::PhantomData;
-use std::sync::{Mutex, RwLock};
+use std::sync::{Arc, Mutex, OnceLock};
 use std::{future, process};
 
 use clap::Parser;
 use polkit_shared::{emphasize, BackendRequest, Identity, PidDisplay};
 use tokio::runtime::Handle;
 use tokio::task::{AbortHandle, JoinHandle, LocalSet};
-use tracing::trace;
+use tracing::{info, trace};
 use ui_prompt::dbus::DbusPrompterInterface;
 use ui_prompt::rofi::RofiPrompter;
 use ui_prompt::{PrependSourcePrompter, Prompter, Source};
@@ -58,16 +58,34 @@
     }
 }
 
+struct CancelTaskOnDrop {
+    tasks: Arc<Mutex<HashMap<String, AbortHandle>>>,
+    handle: String,
+}
+impl Drop for CancelTaskOnDrop {
+    fn drop(&mut self) {
+        info!("cancel on drop");
+        if let Some(task) = self
+            .tasks
+            .lock()
+            .expect("not poisoned")
+            .remove(&self.handle)
+        {
+            task.abort();
+        }
+    }
+}
+
 struct Agent {
     helper: PolkitHelperProxy<'static>,
-    tasks: Mutex<HashMap<String, AbortHandle>>,
+    tasks: Arc<Mutex<HashMap<String, AbortHandle>>>,
     connection: Connection,
 }
 impl Agent {
     async fn new(connection: Connection) -> anyhow::Result<Self> {
         Ok(Self {
             helper: PolkitHelperProxy::new(&connection).await?,
-            tasks: Mutex::new(HashMap::new()),
+            tasks: Arc::new(Mutex::new(HashMap::new())),
             connection,
         })
     }
@@ -78,7 +96,7 @@
     /// BeginAuthentication method
     #[allow(clippy::too_many_arguments)]
     async fn begin_authentication(
-        &mut self,
+        &self,
         action_id: String,
         message: String,
         icon_name: String,
@@ -87,12 +105,15 @@
         identities: Vec<Identity>,
     ) -> zbus::fdo::Result<()> {
         use std::fmt::Write;
-        trace!("begin auth");
+        info!("begin auth");
+        let _cancel_guard = Arc::new(OnceLock::new());
         let task = {
             let connection = self.connection.clone();
             let helper = self.helper.clone();
             let cookie = cookie.clone();
+            let _cancel_guard = _cancel_guard.clone();
             tokio::task::spawn(async move {
+                let _cancel_guard = _cancel_guard.clone();
                 trace!("conversation task");
                 let mut description = format!("{message}\n\n<b>Action id:</b> {action_id}",);
                 if let Some(subject) = details.remove("polkit.caller-pid") {
@@ -121,6 +142,7 @@
                     identities.iter().map(|v| v.to_string()).collect();
                 let identity_displays: Vec<&str> =
                     identity_displays.iter().map(|v| v.as_str()).collect();
+                info!("choose identity");
                 let choosen_identity = match identity_displays.len() {
                     0 => {
                         return Err(fdo::Error::AuthFailed(
@@ -139,6 +161,7 @@
                             .await?
                     }
                 };
+                info!("identity chosen");
 
                 let _ = write!(
                     description,
@@ -148,7 +171,10 @@
                 prompter.description = description;
 
                 prompter.source.push(Source(Cow::Borrowed("polkit daemon")));
+                // let connection = Connection::system().await?;
+                // let helper = PolkitHelperProxy::new(&connection).await?;
                 let prompter = TemporaryPrompterInterface::new(connection, prompter).await;
+                info!("init conv");
                 helper
                     .init_conversation(
                         BackendRequest {
@@ -166,24 +192,26 @@
                 Ok(())
             })
         };
-
         self.tasks
             .lock()
             .unwrap()
             .insert(cookie.clone(), task.abort_handle());
-        let result = task.await.expect("join error");
-        // The only way to no reach this line, is to either panic in previous line, or if authorization cancelled,
-        // while cancellation will remove task by itself.
-        // TODO: But still it would be better to have abort guard, which will remove it from HashMap
-        self.tasks.lock().unwrap().remove(&cookie);
+        info!("abort handle stored");
+        let _ = _cancel_guard.set(CancelTaskOnDrop {
+            tasks: self.tasks.clone(),
+            handle: cookie.clone(),
+        });
+
+        let _ = task.await;
 
-        result
+        Ok(())
     }
 
     /// CancelAuthentication method
     async fn cancel_authentication(&self, cookie: &str) -> zbus::fdo::Result<()> {
-        trace!("cancel auth");
+        info!("auth cancelled");
         if let Some(abort) = self.tasks.lock().unwrap().remove(cookie) {
+            info!("abort handle found");
             abort.abort();
         }
         // debug!("Authentication cancled ! {cookie}");
modifiednix/nixos-modules.nixdiffbeforeafterboth
--- a/nix/nixos-modules.nix
+++ b/nix/nixos-modules.nix
@@ -10,6 +10,9 @@
       ];
       systemd.services.remowt-polkit-helper = {
         aliases = ["dbus-lach.polkit.helper1.service"];
+        # Restarting can kill in-progress auth requests.
+        # It is good to have it restarted for security, but I didn't decided on the flow yet, graceful shutdown?..
+        unitConfig.X-RestartIfChanged = false;
       };
     };
   };