diff --git a/src/main.rs b/src/main.rs index b175d6b..a848782 100644 --- a/src/main.rs +++ b/src/main.rs @@ -12,75 +12,64 @@ mod proxy_socket; use proxy_socket::ProxySocket; -const BUFFER_SIZE: usize = 1024 ^ 2; // 1024 ^ 2 is the maximum - -fn valid_length(length: usize) -> bool { - length > 0 && length <= BUFFER_SIZE -} - -// Read a header (message size) from stdin (e.g.: from the browser). -fn read_header() -> Result { - let stdin = stdin(); - let mut buf = vec![0; 4]; - let mut handle = stdin.lock(); - - handle.read_exact(&mut buf)?; - - NativeEndian::read_u32(&buf) - .try_into() - .map_err(|err| Error::new(ErrorKind::InvalidData, err)) -} - -// Handle a whole request/response cycle +// > The maximum size of a single message from the application is 1 MB. // -// Read a message body from stdin (e.g.: from the browser), and echo it back to the browser's -// socket. Then await a response from the socket and relay that back to the browser. -fn read_body(length: usize, socket: &mut ProxySocket) -> Result<()> { - let mut buffer = vec![0; length]; - let stdin = stdin(); - let mut handle = stdin.lock(); +// From: https://developer.mozilla.org/en-US/docs/Mozilla/Add-ons/WebExtensions/Native_messaging#app_side +const BUFFER_SIZE: usize = 1024 * 1024; - handle.read_exact(&mut buffer)?; +/// Reads from stdin and writes to the socket. +/// Returns on error. +fn stdin_to_socket(socket: &mut ProxySocket) -> Result<()> { + let mut handle = stdin().lock(); + let mut len = vec![0; std::mem::size_of::()]; + + loop { + handle.read_exact(&mut len)?; + let length: usize = NativeEndian::read_u32(&len) + .try_into() + .map_err(|err| Error::new(ErrorKind::InvalidData, err))?; + + let mut buffer = vec![0; length]; + handle.read_exact(&mut buffer)?; - if valid_length(length) { socket.write_all(&buffer)?; socket.flush()?; - read_response(socket)?; } - - Ok(()) } -// Read a response (from KP's socket) and echo it back to the browser. -fn read_response(socket: &mut ProxySocket) -> Result<()>{ - let mut buf = vec![0; BUFFER_SIZE]; - if let Ok(len) = socket.read(&mut buf) { - write_response(&buf[0..len])?; +/// Reads from the socket and writes to stdout. +/// Returns on error. +fn socket_to_stdout(socket: &mut ProxySocket) -> Result<()> { + let mut out = stdout().lock(); + let mut buf = [0; BUFFER_SIZE]; + + loop { + if let Ok(len) = socket.read(&mut buf) { + // If a message is larger than the maximum, ignore it entirely. These are disallowed + // by the browser anyway, so sending one would be a protocol violation. + if len <= BUFFER_SIZE { + out.write_u32::(len as u32)?; + out.write_all(&buf[..len])?; + out.flush()?; + }; + } else { + // TOOD: is the socket is closed, we should try to reconnect. + + return Err(Error::from(ErrorKind::BrokenPipe)); + } } +} + +fn main() -> Result<()> { + let mut socket = proxy_socket::connect(BUFFER_SIZE)?; + let mut socket_clone = socket.try_clone()?; + + thread::spawn(move || socket_to_stdout(&mut socket_clone).unwrap()); + + // If stdin is closed, that means that Firefox has exited, so we exit too. + // If the socket is closed, this will (eventually) fail too, however, this can later be + // refactored to reconnect the underlying ProxySocket. + stdin_to_socket(&mut socket).unwrap(); Ok(()) } - -// Write a response to stdout (e.g.: to the browser). -fn write_response(buf: &[u8]) -> Result<()> { - let stdout = stdout(); - let mut out = stdout.lock(); - - out.write_u32::(buf.len() as u32)?; - out.write_all(buf)?; - out.flush()?; - - Ok(()) -} - -fn main() { - let mut socket = proxy_socket::connect(BUFFER_SIZE).unwrap(); - - // Start thread for user input reading - let ui = thread::spawn(move || loop { - let length = read_header().unwrap(); - read_body(length, &mut socket).unwrap(); - }); - - let _ui_res = ui.join().unwrap(); -} diff --git a/src/proxy_socket.rs b/src/proxy_socket.rs index 4e6668c..2214ec8 100644 --- a/src/proxy_socket.rs +++ b/src/proxy_socket.rs @@ -2,12 +2,14 @@ use std::env; use std::io::{self, Read, Write}; #[cfg(not(windows))] -use std::os::unix::io::AsRawFd; -use std::os::unix::net::UnixStream; -use std::path::PathBuf; -use nix::sys::socket; -use nix::sys::socket::sockopt::SndBuf; -use nix::sys::socket::sockopt::RcvBuf; +use { + nix::sys::socket::{ + setsockopt, + sockopt::{RcvBuf, SndBuf}, + }, + std::os::unix::{io::AsRawFd, net::UnixStream}, + std::path::PathBuf, +}; #[cfg(windows)] use named_pipe::PipeClient; @@ -16,6 +18,20 @@ pub struct ProxySocket { inner: T, } +#[cfg(not(windows))] +impl ProxySocket { + pub(crate) fn try_clone(&self) -> io::Result { + let inner = self.inner.try_clone()?; + Ok(Self { inner }) + } +} +#[cfg(windows)] +impl ProxySocket { + pub(crate) fn try_clone(&self) -> io::Result { + todo!(); + } +} + impl Read for ProxySocket { fn read(&mut self, buf: &mut [u8]) -> io::Result { self.inner.read(buf) @@ -33,9 +49,12 @@ impl Write for ProxySocket { } #[cfg(windows)] -pub fn connect(buffer_size: usize) -> io::Result> { +pub fn connect(_buffer_size: usize) -> io::Result> { let username = env::var("USERNAME").unwrap(); - let pipe_name = format!("\\\\.\\pipe\\keepassxc\\{}\\org.keepassxc.KeePassXC.BrowserServer", username); + let pipe_name = format!( + "\\\\.\\pipe\\keepassxc\\{}\\org.keepassxc.KeePassXC.BrowserServer", + username + ); let client = PipeClient::connect(pipe_name)?; Ok(ProxySocket { inner: client }) } @@ -70,8 +89,6 @@ fn get_socket_dirs() -> Vec { #[cfg(not(windows))] pub fn connect(buffer_size: usize) -> io::Result> { - use std::time::Duration; - let socket_name = "org.keepassxc.KeePassXC.BrowserServer"; let dirs = get_socket_dirs(); let s = dirs @@ -79,10 +96,11 @@ pub fn connect(buffer_size: usize) -> io::Result> { .find_map(|dir| UnixStream::connect(dir.join(socket_name)).ok()) .ok_or_else(|| io::Error::from(io::ErrorKind::NotFound))?; - socket::setsockopt(s.as_raw_fd(), SndBuf, &buffer_size)?; - socket::setsockopt(s.as_raw_fd(), RcvBuf, &buffer_size)?; + setsockopt(s.as_raw_fd(), SndBuf, &buffer_size)?; + setsockopt(s.as_raw_fd(), RcvBuf, &buffer_size)?; + + // Make sure reads are blocking. + s.set_nonblocking(false)?; - let timeout: Option = Some(Duration::from_secs(1)); - s.set_read_timeout(timeout)?; Ok(ProxySocket { inner: s }) }