1use std::io;2use std::pin::Pin;3use std::task::{Context, Poll};45use anyhow::{anyhow, Result};6use camino::Utf8PathBuf;7use russh::client::Msg;8use russh::{Channel, ChannelStream};9use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};10use tokio::net::{UnixListener, UnixStream};11use tokio::sync::oneshot;1213pub enum RemowtListener {14 Ssh(oneshot::Receiver<Channel<Msg>>),15 Local(UnixListener, Utf8PathBuf),16}1718impl RemowtListener {19 pub async fn accept(self) -> Result<RemowtStream> {20 match self {21 RemowtListener::Ssh(rx) => {22 let ch = rx23 .await24 .map_err(|_| anyhow!("agent never connected the forwarded socket"))?;25 Ok(RemowtStream::Ssh(ch.into_stream()))26 }27 RemowtListener::Local(listener, path) => {28 let (stream, _) = listener.accept().await?;29 let _ = std::fs::remove_file(&path);30 Ok(RemowtStream::Local(stream))31 }32 }33 }34}3536pub enum RemowtStream {37 Ssh(ChannelStream<Msg>),38 Local(UnixStream),39}4041impl AsyncRead for RemowtStream {42 fn poll_read(43 self: Pin<&mut Self>,44 cx: &mut Context<'_>,45 buf: &mut ReadBuf<'_>,46 ) -> Poll<io::Result<()>> {47 match self.get_mut() {48 RemowtStream::Ssh(s) => Pin::new(s).poll_read(cx, buf),49 RemowtStream::Local(s) => Pin::new(s).poll_read(cx, buf),50 }51 }52}5354impl AsyncWrite for RemowtStream {55 fn poll_write(56 self: Pin<&mut Self>,57 cx: &mut Context<'_>,58 buf: &[u8],59 ) -> Poll<io::Result<usize>> {60 match self.get_mut() {61 RemowtStream::Ssh(s) => Pin::new(s).poll_write(cx, buf),62 RemowtStream::Local(s) => Pin::new(s).poll_write(cx, buf),63 }64 }6566 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {67 match self.get_mut() {68 RemowtStream::Ssh(s) => Pin::new(s).poll_flush(cx),69 RemowtStream::Local(s) => Pin::new(s).poll_flush(cx),70 }71 }7273 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {74 match self.get_mut() {75 RemowtStream::Ssh(s) => Pin::new(s).poll_shutdown(cx),76 RemowtStream::Local(s) => Pin::new(s).poll_shutdown(cx),77 }78 }79}