Skip to content

Commit

Permalink
Create a socket per user
Browse files Browse the repository at this point in the history
  • Loading branch information
pka committed Oct 12, 2024
1 parent fbf714c commit 84401e7
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 18 deletions.
6 changes: 3 additions & 3 deletions src/bin/shell_compose.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ impl DispatcherProc {
}
fn wait(&self, max_ms: u64) -> Result<(), DispatcherError> {
let mut wait_ms = 0;
while IpcStream::check_connection(SOCKET_NAME).is_err() {
while IpcStream::check_connection().is_err() {
if wait_ms >= max_ms {
return Err(DispatcherError::ProcSpawnTimeoutError);
}
Expand All @@ -68,7 +68,7 @@ fn cli() -> Result<(), DispatcherError> {

init_cli_logger();

if IpcStream::check_connection(SOCKET_NAME).is_err() {
if IpcStream::check_connection().is_err() {
if matches!(cli_command, Ok(CliCommand::Exit)) {
// Background process already exited
return Ok(());
Expand All @@ -78,7 +78,7 @@ fn cli() -> Result<(), DispatcherError> {
dispatcher.wait(2000)?;
}

let mut stream = IpcStream::connect("cli", SOCKET_NAME)?;
let mut stream = IpcStream::connect("cli")?;
let msg: Message = exec_command
.map(Into::into)
.or_else(|_| cli_command.map(Into::into))?;
Expand Down
11 changes: 9 additions & 2 deletions src/bin/shell_composed.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use clap::{CommandFactory, FromArgMatches, Subcommand};
use log::error;
use shell_compose::{
init_daemon_logger, start_ipc_listener, Cli, Dispatcher, ExecCommand, Message, SOCKET_NAME,
init_daemon_logger, start_ipc_listener, Cli, Dispatcher, ExecCommand, IpcStream, Message,
};
use std::fs::remove_file;

fn run_server() {
let cli = Cli::command();
Expand All @@ -20,8 +21,14 @@ fn run_server() {
dispatcher.exec_command(cmd);
}

let socket_name = IpcStream::user_socket_name();
// reclaim_name in interprocess::local_socket::ListenerOptions
// does not work, so we delete the socket first.
if IpcStream::check_connection().is_err() {
remove_file(&socket_name).ok();
}
start_ipc_listener(
SOCKET_NAME,
&socket_name,
move |mut stream| {
let Ok(_connect) = stream.receive_message() else {
return;
Expand Down
48 changes: 39 additions & 9 deletions src/ipc.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::Message;
use crate::{get_user_name, Message};
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use interprocess::local_socket::{prelude::*, GenericNamespaced, ListenerOptions};
use interprocess::local_socket::{prelude::*, GenericFilePath, ListenerOptions};
use log::debug;
use std::io;
use std::io::prelude::*;
Expand Down Expand Up @@ -55,9 +55,22 @@ pub fn start_ipc_listener<F: FnMut(IpcStream) + Send + 'static>(
on_connection_error: Option<fn(io::Error)>,
) -> Result<(), IpcServerError> {
let name = socket
.to_ns_name::<GenericNamespaced>()
.to_fs_name::<GenericFilePath>()
.map_err(IpcServerError::SocketNameError)?;
let listener = match ListenerOptions::new().name(name.clone()).create_sync() {
let mut options = ListenerOptions::new().name(name.clone());
#[cfg(target_family = "unix")]
{
use interprocess::os::unix::local_socket::ListenerOptionsExt;
options = options.mode(0o600);
}
#[cfg(target_family = "windows")]
{
use interprocess::os::windows::{
local_socket::ListenerOptionsExt, security_descriptor::SecurityDescriptor,
};
options = options.security_descriptor(SecurityDescriptor::new().unwrap());
}
let listener = match options.create_sync() {
Err(e) => return Err(IpcServerError::BindError(e)),
Ok(listener) => listener,
};
Expand All @@ -84,7 +97,7 @@ pub fn start_ipc_listener<F: FnMut(IpcStream) + Send + 'static>(
/// Connect to the socket and return the stream.
fn ipc_client_connect(socket_name: &str) -> Result<LocalSocketStream, IpcClientError> {
let name = socket_name
.to_ns_name::<GenericNamespaced>()
.to_fs_name::<GenericFilePath>()
.map_err(IpcClientError::SocketNameError)?;
LocalSocketStream::connect(name).map_err(IpcClientError::ConnectError)
}
Expand Down Expand Up @@ -135,19 +148,36 @@ pub struct IpcStream {

impl IpcStream {
/// Connects to the socket and return the stream
pub fn connect(logname: &str, socket_name: &str) -> Result<Self, IpcClientError> {
let mut stream = ipc_client_connect(socket_name)?;
pub fn connect(logname: &str) -> Result<Self, IpcClientError> {
let socket_name = IpcStream::user_socket_name();
let mut stream = ipc_client_connect(&socket_name)?;
stream.write_serde(&Message::Connect)?;
Ok(IpcStream {
logname: logname.to_string(),
stream,
})
}
/// Check socket connection
pub fn check_connection(socket_name: &str) -> Result<(), IpcClientError> {
IpcStream::connect("check_connection", socket_name)?;
pub fn check_connection() -> Result<(), IpcClientError> {
IpcStream::connect("check_connection")?;
Ok(())
}
pub fn user_socket_name() -> String {
let user = get_user_name().unwrap_or("_".to_string());
IpcStream::socket_name(&user)
}
#[cfg(target_family = "unix")]
fn socket_name(user: &str) -> String {
let tmpdir = std::env::var("TMPDIR").ok();
format!(
"{}/shell-compose-{user}.sock",
tmpdir.as_deref().unwrap_or("/tmp")
)
}
#[cfg(target_family = "windows")]
fn socket_name(user: &str) -> String {
format!(r"\\.\pipe\shell-compose-{user}")
}
/// Check stream
pub fn alive(&mut self) -> Result<(), IpcClientError> {
self.stream.write_serde(&Message::Connect)?;
Expand Down
2 changes: 0 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,3 @@ pub use display::*;
pub use ipc::*;
pub use justfile::*;
pub use runner::*;

pub const SOCKET_NAME: &str = "shell-compose.sock";
22 changes: 20 additions & 2 deletions src/runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ use log::info;
use serde::{Deserialize, Serialize};
use std::collections::VecDeque;
use std::io::{BufRead, BufReader, Read};
use std::process::{Child, Command, Stdio};
use std::process::{self, Child, Command, Stdio};
use std::sync::{mpsc, Arc, Mutex};
use std::thread;
use sysinfo::{ProcessRefreshKind, RefreshKind, System};
use sysinfo::{ProcessRefreshKind, RefreshKind, System, UpdateKind, Users};

/// Child process controller
pub struct Runner {
Expand Down Expand Up @@ -247,3 +247,21 @@ fn output_listener<R: Read>(
channel.send(pid).unwrap();
}
}

/// Current user
pub fn get_user_name() -> Option<String> {
let system = System::new_with_specifics(
RefreshKind::new()
.with_processes(ProcessRefreshKind::new().with_user(UpdateKind::OnlyIfNotSet)),
);
let users = Users::new_with_refreshed_list();
let pid = process::id();
system
.process(sysinfo::Pid::from_u32(pid))
.and_then(|proc| {
proc.effective_user_id()
.or(proc.user_id())
.and_then(|uid| users.get_user_by_id(uid))
.map(|user| user.name().to_string())
})
}

0 comments on commit 84401e7

Please sign in to comment.