diff --git a/src/main.rs b/src/main.rs index f99773b..37b5fda 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,6 +3,7 @@ extern crate nix; #[cfg(windows)] extern crate named_pipe; +use std::convert::TryInto; use std::io::{self, Read, Write}; use std::thread; use byteorder::{ByteOrder, NativeEndian, WriteBytesExt}; @@ -11,23 +12,23 @@ mod proxy_socket; use proxy_socket::ProxySocket; -const BUFFER_SIZE: u32 = 1024 ^ 2; // 1024 ^ 2 is the maximum +const BUFFER_SIZE: usize = 1024 ^ 2; // 1024 ^ 2 is the maximum -fn valid_length(length: u32) -> bool { - return length > 0 && length <= BUFFER_SIZE; +fn valid_length(length: usize) -> bool { + length > 0 && length <= BUFFER_SIZE } -fn read_header() -> u32 { +fn read_header() -> usize { let stdin = io::stdin(); let mut buf = vec![0; 4]; let mut handle = stdin.lock(); handle.read_exact(&mut buf).unwrap(); - NativeEndian::read_u32(&buf) + NativeEndian::read_u32(&buf).try_into().unwrap() } -fn read_body(length: u32, socket: &mut ProxySocket) { - let mut buffer = vec![0; length as usize]; +fn read_body(length: usize, socket: &mut ProxySocket) { + let mut buffer = vec![0; length]; let stdin = io::stdin(); let mut handle = stdin.lock(); @@ -39,7 +40,7 @@ fn read_body(length: u32, socket: &mut ProxySocket) { } fn read_response(socket: &mut ProxySocket) { - let mut buf = vec![0; BUFFER_SIZE as usize]; + let mut buf = vec![0; BUFFER_SIZE]; if let Ok(len) = socket.read(&mut buf) { write_response(&buf[0..len]); } diff --git a/src/proxy_socket.rs b/src/proxy_socket.rs index b78ac5a..fa6fd33 100644 --- a/src/proxy_socket.rs +++ b/src/proxy_socket.rs @@ -4,6 +4,7 @@ use std::io::{self, Read, Write}; #[cfg(not(windows))] use std::os::unix::io::AsRawFd; use std::os::unix::net::UnixStream; +use std::path::PathBuf; use nix::sys::socket; use nix::sys::socket::sockopt::SndBuf; use nix::sys::socket::sockopt::RcvBuf; @@ -32,7 +33,7 @@ impl Write for ProxySocket { } #[cfg(windows)] -pub fn connect(buffer_size: u32) -> io::Result> { +pub fn connect(buffer_size: usize) -> io::Result> { let username = env::var("USERNAME").unwrap(); let pipe_name = format!("\\\\.\\pipe\\keepassxc\\{}\\org.keepassxc.KeePassXC.BrowserServer", username); let client = PipeClient::connect(pipe_name)?; @@ -40,18 +41,46 @@ pub fn connect(buffer_size: u32) -> io::Result> { } #[cfg(not(windows))] -pub fn connect(buffer_size: u32) -> io::Result> { +/// Returns the directories where the socket could possible be located. +/// +/// These directories should be tried in sequence, until one of them is found +/// to contain the socket. +fn get_socket_dirs() -> Vec { + let mut dirs = Vec::new(); + + if !cfg!(target_os = "macos") { + if let Ok(dir) = env::var("XDG_RUNTIME_DIR") { + let xdg_runtime_dir: PathBuf = dir.into(); + + // Sandbox-friendly path. + // Used in KeePassXC >= 2.7.2 and for all versions on Flatpak. + dirs.push(xdg_runtime_dir.join("app/org.keepassxc.KeePassXC/")); + + // Legacy path. + // Used by KeePassXC < 2.7.2. + dirs.push(xdg_runtime_dir); + }; + }; + + // Default for macOS, and final fallback for Linux. + dirs.push(env::temp_dir()); + + dirs +} + +#[cfg(not(windows))] +pub fn connect(buffer_size: usize) -> io::Result> { use std::time::Duration; let socket_name = "org.keepassxc.KeePassXC.BrowserServer"; - let socket = if let Ok(dir) = if cfg!(target_os = "macos") {env::var("TMPDIR") } else { env::var("XDG_RUNTIME_DIR") } { - format!("{}/{}", dir, socket_name) - } else { - format!("/tmp/{}", socket_name) - }; - let s = UnixStream::connect(socket)?; - socket::setsockopt(s.as_raw_fd(), SndBuf, &(buffer_size as usize)).expect("setsockopt for SndBuf failed"); - socket::setsockopt(s.as_raw_fd(), RcvBuf, &(buffer_size as usize)).expect("setsockopt for RcvBuf failed"); + let dirs = get_socket_dirs(); + let s = dirs + .iter() + .find_map(|dir| UnixStream::connect(dir.join(socket_name)).ok()) + .ok_or_else(|| io::Error::from(io::ErrorKind::NotFound))?; + + socket::setsockopt(s.as_raw_fd(), SndBuf, &buffer_size).expect("setsockopt for SndBuf failed"); + socket::setsockopt(s.as_raw_fd(), RcvBuf, &buffer_size).expect("setsockopt for RcvBuf failed"); let timeout: Option = Some(Duration::from_secs(1)); s.set_read_timeout(timeout)?; Ok(ProxySocket { inner: s })