diff --git a/rinit-protos/protos/rinit.proto b/rinit-protos/protos/rinit.proto index 4325353..63e5c1d 100644 --- a/rinit-protos/protos/rinit.proto +++ b/rinit-protos/protos/rinit.proto @@ -2,7 +2,10 @@ syntax = "proto2"; package rinit.api; -message Error { required string message = 1; } +message Error { + required uint32 code = 1; + required string message = 2; +} message Env { required bytes name = 1; @@ -23,6 +26,11 @@ message Rfd { }; } +enum IfKind { + IFKIND_VPN = 0; + IFKIND_INET = 1; +} + message Request { required uint64 request_id = 1; oneof command { @@ -52,6 +60,7 @@ message Response { SyncFsResponse sync_fs = 9; NetCtlResponse net_ctl = 10; NetHostResponse net_host = 11; + ProcessDiedNotification process_died = 12; Error error = 99; } @@ -70,7 +79,7 @@ message RunProcessRequest { optional bool is_entrypoint = 8; } -message KillProcessRequest { required uint64 pid = 1; } +message KillProcessRequest { required uint64 process_id = 1; } message MountVolumeRequest { required bytes tag = 1; @@ -98,7 +107,7 @@ message NetCtlRequest { required bytes mask = 2; required bytes gateway = 3; required bytes if_addr = 4; - required uint32 if_kind = 5; + required IfKind if_kind = 5; required uint32 flags = 6; } @@ -126,3 +135,9 @@ message SyncFsResponse {} message NetCtlResponse {} message NetHostResponse {} + +message ProcessDiedNotification { + required uint64 pid = 1; + required uint32 exit_status = 2; + required uint32 reason_type = 3; +} diff --git a/rinit/src/fs.rs b/rinit/src/fs.rs index 19d0b8e..a3c7d16 100644 --- a/rinit/src/fs.rs +++ b/rinit/src/fs.rs @@ -1,4 +1,6 @@ use std::{ + fs::File, + io::{BufRead, BufReader}, os::unix::fs::{MetadataExt, PermissionsExt}, path::{Path, PathBuf}, }; @@ -344,3 +346,51 @@ pub fn mount_sysroot() -> std::io::Result<()> { Ok(()) } + +pub fn find_device_major(name: &str) -> std::io::Result { + let file = File::open("/proc/devices")?; + let reader = BufReader::new(file); + + let mut major = -1; + let mut in_character_devices = false; + + for line in reader.lines() { + let line = line?; + if line == "Character devices:" { + in_character_devices = true; + } else if line.is_empty() || line == "Block devices:" { + if in_character_devices { + break; + } + } else if in_character_devices { + let parts: Vec<&str> = line.split_whitespace().collect(); + if parts.len() == 2 { + if let (Ok(entry_major), entry_name) = (parts[0].parse::(), parts[1]) { + if entry_name == name { + major = entry_major; + break; + } + } + } + } + } + + Ok(major) +} + +pub fn nvidia_gpu_count() -> i32 { + let mut counter = 0; + + for i in 0..256 { + let path = format!("/sys/class/drm/card{}", i); + let path = Path::new(&path); + + if !path.exists() { + break; + } + + counter = i + 1; + } + + counter +} diff --git a/rinit/src/handlers.rs b/rinit/src/handlers.rs index 8c70ba6..461bf8d 100644 --- a/rinit/src/handlers.rs +++ b/rinit/src/handlers.rs @@ -11,33 +11,19 @@ use smol::{lock::Mutex, Async}; use crate::{ die, - enums::{MessageRunProcessType, MessageType, RedirectFdDesc, RedirectFdType}, + enums::RedirectFdDesc, fs::mount_volume, - io::{ - async_read_n, async_recv_bytes, async_recv_strings_array, async_recv_u32, async_recv_u64, - async_recv_u8, async_send_response_ok, send_process_died, send_response_error, - send_response_u64, MessageHeader, - }, + io::{async_read_n, async_recv_u64}, + network::{add_network_hosts, net_if_addr, net_if_addr_to_hw_addr, net_if_hw_addr, net_route}, process::{spawn_new_process, ExitReason, NewProcessArgs, ProcessDesc}, utils::{CyclicBuffer, FdPipe, FdWrapper}, + DEV_INET, DEV_VPN, }; -async fn handle_run_process_command( - _request: &api::RunProcessRequest, - _processes: Arc>>, -) -> std::io::Result { - Ok(api::response::Command::RunProcess( - api::RunProcessResponse { process_id: 0 }, - )) -} - async fn handle_run_process( - async_fd: &mut Async, - msg_id: u64, + request: &api::RunProcessRequest, processes: Arc>>, -) -> std::io::Result<()> { - let mut done = false; - +) -> std::io::Result> { let mut new_process_args = NewProcessArgs::default(); let mut fd_desc = [ @@ -46,60 +32,63 @@ async fn handle_run_process( RedirectFdDesc::Invalid, ]; - while !done { - let cmd = async_recv_u8(async_fd).await?; - let cmd = MessageRunProcessType::from_u8(cmd); + new_process_args.bin = String::from_utf8(request.program.clone()) + .expect("Failed to convert binary name to string"); - match cmd { - MessageRunProcessType::End => { - // log::trace!(" Done"); - done = true; - } - MessageRunProcessType::Bin => { - // log::trace!(" Binary"); - let bin = async_recv_bytes(async_fd).await?; - new_process_args.bin = - String::from_utf8(bin).expect("Failed to convert binary name to string"); - log::trace!(" Binary: {}", new_process_args.bin); - } - MessageRunProcessType::Arg => { - // log::trace!(" Arg"); - new_process_args.args = async_recv_strings_array(async_fd).await?; - log::trace!(" Args: {:?}", new_process_args.args); - } - MessageRunProcessType::Env => { - // log::trace!(" Env"); - new_process_args.envp = async_recv_strings_array(async_fd).await?; - log::trace!(" Env: {:?}", new_process_args.envp); - } - MessageRunProcessType::Uid => { - // log::trace!(" Uid"); - new_process_args.uid = Some(Uid::from_raw(async_recv_u32(async_fd).await?)); - log::trace!(" Uid: {:?}", new_process_args.uid); - } - MessageRunProcessType::Gid => { - // log::trace!(" Gid"); - new_process_args.gid = Some(Gid::from_raw(async_recv_u32(async_fd).await?)); - log::trace!(" Gid: {:?}", new_process_args.gid); - } - MessageRunProcessType::Rfd => { - // log::trace!(" Rfd"); - parse_fd_redit(async_fd, &mut fd_desc).await?; - log::trace!(" Rfd: {:?}", fd_desc); - } - MessageRunProcessType::Cwd => { - // log::trace!(" Cwd"); - let buf = async_recv_bytes(async_fd).await?; - new_process_args.cwd = - String::from_utf8(buf).expect("Failed to convert cwd to string"); - log::trace!(" Cwd: {}", new_process_args.cwd); - } - MessageRunProcessType::Ent => { - // log::trace!(" Ent"); - log::trace!(" Entrypoint -> true"); - new_process_args.is_entrypoint = true; - } + new_process_args.args = request + .args + .iter() + .map(|arg| String::from_utf8(arg.clone()).unwrap()) + .collect(); + + new_process_args.envp = request + .env + .iter() + .map(|env| String::from_utf8(env.clone()).unwrap()) + .collect(); + + new_process_args.uid = Some(Uid::from_raw(request.uid())); + new_process_args.gid = Some(Gid::from_raw(request.gid())); + + for rfd in &request.rfd { + let fd = rfd.fd; + + if fd >= 3 { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "Invalid input", + )); } + + let redir_desc = if let Some(redirect) = rfd.redirect.as_ref() { + match redirect { + api::rfd::Redirect::Path(redir_path) => RedirectFdDesc::File( + String::from_utf8(redir_path.clone()) + .expect("Failed to convert path to string"), + ), + api::rfd::Redirect::PipeBlocking(size) => RedirectFdDesc::PipeBlocking(FdPipe { + cyclic_buffer: CyclicBuffer::new(*size as usize), + fds: [None, None], + }), + api::rfd::Redirect::PipeCyclic(size) => RedirectFdDesc::PipeCyclic(FdPipe { + cyclic_buffer: CyclicBuffer::new(*size as usize), + fds: [None, None], + }), + } + } else { + RedirectFdDesc::Invalid + }; + + fd_desc[fd as usize] = redir_desc; + } + + if let Some(work_dir) = &request.work_dir { + new_process_args.cwd = + String::from_utf8(work_dir.clone()).expect("Failed to convert cwd to string"); + } + + if let Some(is_entrypoint) = request.is_entrypoint { + new_process_args.is_entrypoint = is_entrypoint; } log::info!( @@ -114,99 +103,75 @@ async fn handle_run_process( ); if new_process_args.bin.is_empty() || new_process_args.args.is_empty() { - send_response_error(async_fd, msg_id, libc::EFAULT as i32).await?; - return Ok(()); + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "Invalid input", + )); } match spawn_new_process(new_process_args, fd_desc, processes).await { - Ok(proc_id) => { - log::info!(" Process spawned: {}", proc_id); - send_response_u64(async_fd, msg_id, proc_id).await? - } - Err(e) => { - log::error!(" Failed to spawn process: {:?}", e); - send_response_error(async_fd, msg_id, e.raw_os_error().unwrap_or(libc::EIO)).await? - } + Ok(process_id) => Ok(Some(api::response::Command::RunProcess( + api::RunProcessResponse { process_id }, + ))), + Err(e) => Err(e), } - - Ok(()) } -async fn parse_fd_redit( - async_fd: &mut Async, - fd_desc: &mut [RedirectFdDesc; 3], -) -> std::io::Result<()> { - let fd = async_recv_u32(async_fd).await?; - - let redir_type_u8 = async_recv_u8(async_fd).await?; +async fn handle_kill_process( + request: &api::KillProcessRequest, + processes: Arc>>, +) -> std::io::Result> { + let process_id = request.process_id; - let redir_type = RedirectFdType::from_u8(redir_type_u8); + if process_id == 0 { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "Invalid input", + )); + } - log::trace!( - " Parsing fd: {} redirect: {} -> {:?}", - fd, - redir_type_u8, - redir_type - ); + let mut processes = processes.lock().await; - let mut path = String::new(); - let mut cyclic_buffer_size = 0; + let mut i = 0; + let mut found = false; - match redir_type { - RedirectFdType::File => { - log::trace!(" File"); - let buf = async_recv_bytes(async_fd).await?; - path = String::from_utf8(buf).expect("Failed to convert path to string"); - log::trace!(" Path: {}", path); - } - RedirectFdType::PipeBlocking => { - log::trace!(" Pipe Blocking"); - cyclic_buffer_size = async_recv_u64(async_fd).await?; - log::trace!(" Pipe: {}", cyclic_buffer_size); - } - RedirectFdType::PipeCyclic => { - log::trace!(" Pipe Cyclic"); - cyclic_buffer_size = async_recv_u64(async_fd).await?; - log::trace!(" Pipe: {}", cyclic_buffer_size); - } - RedirectFdType::Invalid => { - log::trace!(" Invalid"); - todo!(); + while i < processes.len() { + if processes[i].id == process_id { + found = true; + break; } + + i += 1; } - if fd >= 3 { + if found { + let process = &mut processes[i]; + + if process.is_alive { + log::info!("Killing process: {}", process_id); + + process.child.kill()? + } else { + return Err(std::io::Error::new( + std::io::ErrorKind::NotFound, + "Process is already dead", + )); + } + } else { return Err(std::io::Error::new( std::io::ErrorKind::InvalidInput, - "Invalid input", + "Process not found", )); } - fd_desc[fd as usize] = match redir_type { - RedirectFdType::File => RedirectFdDesc::File(path), - RedirectFdType::PipeBlocking => { - let fd_pipe = FdPipe { - cyclic_buffer: CyclicBuffer::new(cyclic_buffer_size as usize), - fds: [None, None], - }; - RedirectFdDesc::PipeBlocking(fd_pipe) - } - RedirectFdType::PipeCyclic => { - let fd_pipe = FdPipe { - cyclic_buffer: CyclicBuffer::new(cyclic_buffer_size as usize), - fds: [None, None], - }; - RedirectFdDesc::PipeCyclic(fd_pipe) - } - RedirectFdType::Invalid => RedirectFdDesc::Invalid, - }; - - Ok(()) + Ok(Some(api::response::Command::KillProcess( + api::KillProcessResponse {}, + ))) } -async fn handle_mount_command( +async fn handle_mount( request: &api::MountVolumeRequest, -) -> std::io::Result { +) -> std::io::Result> { let tag = String::from_utf8(request.tag.clone()).map_err(|_| { std::io::Error::new( std::io::ErrorKind::InvalidInput, @@ -229,88 +194,151 @@ async fn handle_mount_command( mount_volume(tag, path)?; - Ok(api::response::Command::MountVolume( + Ok(Some(api::response::Command::MountVolume( api::MountVolumeResponse {}, - )) + ))) } -async fn handle_mount(async_fd: &mut Async, message_id: u64) -> std::io::Result<()> { - let mut done = false; +async fn handle_net_ctl( + request: &api::NetCtlRequest, +) -> std::io::Result> { + let address = String::from_utf8(request.addr.clone()).map_err(|_| { + std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "Failed to convert address to string", + ) + })?; + + let mask = String::from_utf8(request.mask.clone()).map_err(|_| { + std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "Failed to convert mask to string", + ) + })?; - let mut tag = String::new(); - let mut path = String::new(); + let gateway = String::from_utf8(request.gateway.clone()).map_err(|_| { + std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "Failed to convert gateway to string", + ) + })?; - while !done { - let cmd = async_recv_u8(async_fd).await?; + let if_addr = String::from_utf8(request.if_addr.clone()).map_err(|_| { + std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "Failed to convert interface address to string", + ) + })?; - match cmd { - // VOLUME_END - 0 => { - // log::trace!(" Done"); - done = true; - } - // VOLUME_TAG - 1 => { - // log::trace!(" Volume tag"); - let buf = async_recv_bytes(async_fd).await?; - tag = String::from_utf8(buf).expect("Failed to convert tag to string"); - log::trace!(" Tag: {}", tag); - } - // VOLUME_PATH - 2 => { - // log::trace!(" Volume path"); - let buf = async_recv_bytes(async_fd).await?; - path = String::from_utf8(buf).expect("Failed to convert path to string"); - log::trace!(" Path: {}", path); + let if_kind = api::IfKind::try_from(request.if_kind).map_err(|_| { + std::io::Error::new(std::io::ErrorKind::InvalidInput, "Invalid IfKind value") + })?; + + let if_name = match if_kind { + api::IfKind::IfkindVpn => DEV_VPN, + api::IfKind::IfkindInet => DEV_INET, + }; + + if !if_addr.is_empty() { + log::info!("Configuring '{}' with IP: {}", if_name, if_addr); + + if if_addr.contains(":") { + // TODO(aljen): Handle IPV6 + } else { + if mask.is_empty() { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "Invalid input", + )); } - _ => { - log::trace!(" Unknown command"); - send_response_error(async_fd, message_id, libc::EPROTONOSUPPORT as i32).await?; + + net_if_addr(if_name, &if_addr, &mask)?; + + let hw_addr = net_if_addr_to_hw_addr(&if_addr); + + let result = net_if_hw_addr(if_name, &hw_addr)?; + + if result != 0 { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "Failed to set HW address", + )); } } } - if tag.is_empty() || path.is_empty() || !path.starts_with("/") { - send_response_error(async_fd, message_id, libc::EINVAL as i32).await?; - return Err(std::io::Error::new( - std::io::ErrorKind::InvalidInput, - "Invalid input", - )); - } + if !gateway.is_empty() { + log::info!("Configuring '{}' with gateway: {}", if_name, gateway); - let result = mount_volume(tag, path); - match result { - Ok(_) => (), - Err(e) => { - send_response_error(async_fd, message_id, e.raw_os_error().unwrap_or(libc::EIO)) - .await?; - return Err(e); + if gateway.contains(":") { + // TODO(aljen): Handle IPV6 + } else { + let address = if !address.is_empty() { + Some(address.as_str()) + } else { + None + }; + let mask = if !mask.is_empty() { + Some(mask.as_str()) + } else { + None + }; + let result = net_route(if_name, address, mask, &gateway)?; + + if result != 0 { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "Failed to set route", + )); + } } } - async_send_response_ok(async_fd, message_id).await?; - - Ok(()) + Ok(Some(api::response::Command::NetCtl(api::NetCtlResponse {}))) } -async fn handle_quit(async_fd: &mut Async, message_id: u64) -> std::io::Result<()> { - log::info!("Quitting..."); - - async_send_response_ok(async_fd, message_id).await?; - - die!("Exit"); +async fn handle_net_host( + request: &api::NetHostRequest, +) -> std::io::Result> { + let hosts: Vec<(String, String)> = request + .hosts + .iter() + .map(|host| { + let ip = String::from_utf8(host.ip.clone()) + .map_err(|_| { + std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "Failed to convert IP to string", + ) + }) + .unwrap_or_default(); + + let hostname = String::from_utf8(host.hostname.clone()) + .map_err(|_| { + std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "Failed to convert hostname to string", + ) + }) + .unwrap_or_default(); + + (ip, hostname) + }) + .collect(); + + add_network_hosts(&hosts)?; + + Ok(Some(api::response::Command::NetHost( + api::NetHostResponse {}, + ))) } -async fn handle_quit_command( +async fn handle_quit( _request: &api::QuitRequest, -) -> std::io::Result { +) -> std::io::Result> { log::info!("Quitting..."); - // async_send_response_ok(async_fd, message_id).await?; - - // die!("Exit"); - - Ok(api::response::Command::Quit(api::QuitResponse {})) + Ok(Some(api::response::Command::Quit(api::QuitResponse {}))) } fn encode_status(status: i32, reason_type: i32) -> ExitReason { @@ -331,17 +359,22 @@ fn encode_status(status: i32, reason_type: i32) -> ExitReason { } pub async fn handle_sigchld( - async_fd: Arc>>, + async_sig_fd: Arc>>, processes: Arc>>, -) -> std::io::Result<()> { + request_id: &mut u64, +) -> std::io::Result> { let mut siginfo: libc::siginfo_t = unsafe { std::mem::zeroed() }; let mut buf = [0u8; std::mem::size_of::()]; - log::info!("Handling SIGCHLD"); + log::info!("handle_sigchld start"); - let mut async_fd = async_fd.lock().await; + log::info!("locking async_sig_fd"); + let mut async_sig_fd = async_sig_fd.lock().await; - let size = async_read_n(&mut async_fd, &mut buf).await?; + *request_id = 0; + + log::info!("Reading from async_sig_fd"); + let size = async_read_n(&mut async_sig_fd, &mut buf).await?; if size != std::mem::size_of::() { log::error!( @@ -364,7 +397,7 @@ pub async fn handle_sigchld( let child_pid = siginfo._pad[0]; if child_pid == -1 { log::error!("Zombie process with PID -1"); - return Ok(()); + return Ok(None); } let child_pid = Pid::from_raw(child_pid); @@ -374,7 +407,7 @@ pub async fn handle_sigchld( && siginfo.si_code != libc::CLD_DUMPED { log::error!("Child did not exit normally: {}", siginfo.si_code); - return Ok(()); + return Ok(None); } let wait_status = waitpid(child_pid, Some(WaitPidFlag::WNOHANG))?; @@ -385,7 +418,7 @@ pub async fn handle_sigchld( if pid != child_pid { log::error!("Expected PID {}, but got {}", child_pid, pid); - return Ok(()); + return Ok(None); } } @@ -415,119 +448,83 @@ pub async fn handle_sigchld( let exit_reason = encode_status(siginfo._pad[7], siginfo.si_code); println!("Exit reason: {:?}", exit_reason); - send_process_died(&mut async_fd, proc_id, exit_reason).await?; + let response = api::response::Command::ProcessDied(api::ProcessDiedNotification { + pid: proc_id, + exit_status: exit_reason.status as u32, + reason_type: exit_reason.reason_type as u32, + }); - Ok(()) + Ok(Some(response)) } pub async fn handle_message( - async_fd: Arc>>, + async_cmds_fd: Arc>>, processes: Arc>>, -) -> std::io::Result<()> { - let mut async_fd = async_fd.lock().await; + request_id: &mut u64, +) -> std::io::Result> { + let mut async_fd = async_cmds_fd.lock().await; let size = async_recv_u64(&mut async_fd).await? as usize; let mut buf = vec![0u8; size]; - println!("Reading message of size: {}", size); - async_read_n(&mut async_fd, &mut buf).await?; let request = api::Request::decode(buf.as_slice())?; - println!("Request: {:?}", request); - - // let mut buf = [0u8; 9]; - // let size = async_read_n(&mut async_fd, &mut buf).await?; - - // log::info!(" Handling message: {:?}, size: {}", buf, size); - - // let msg_header = MessageHeader::from_ne_bytes(&buf); - // log::trace!(" Message header: {:?} ({})", msg_header, size); - - // let message_type = MessageType::from_u8(msg_header.msg_type); + *request_id = request.request_id; - let response = match request.command { + match request.command { Some(api::request::Command::Quit(quit)) => { log::trace!(" Quit message"); - handle_quit_command(&quit).await + handle_quit(&quit).await } Some(api::request::Command::RunProcess(run_process)) => { log::trace!(" Run process message"); - handle_run_process_command(&run_process, processes).await + handle_run_process(&run_process, processes).await + } + Some(api::request::Command::KillProcess(kill_process)) => { + log::trace!(" Kill process message"); + handle_kill_process(&kill_process, processes).await } Some(api::request::Command::MountVolume(mount_volume)) => { log::trace!(" Mount volume message"); - handle_mount_command(&mount_volume).await + handle_mount(&mount_volume).await } - _ => { - die!(" Unknown message type"); + Some(api::request::Command::QueryOutput(_query_output)) => { + log::trace!(" Query output message"); + unimplemented!(); } - }; - - match response { - Ok(response) => { - let response = api::Response { - request_id: 0, - command: Some(response), - }; - - // let mut buf = Vec::new(); - // response.encode(&mut buf)?; - - async_send_response_ok(&mut async_fd, 0).await?; - // async_fd.write_all(&buf).await?; + Some(api::request::Command::NetCtl(net_ctl)) => { + log::trace!(" Net control message"); + handle_net_ctl(&net_ctl).await } - Err(e) => { - log::error!("Failed to handle message: {:?}", e); - send_response_error(&mut async_fd, 0, e.raw_os_error().unwrap_or(libc::EIO)).await?; + Some(api::request::Command::NetHost(net_host)) => { + log::trace!(" Net host message"); + handle_net_host(&net_host).await + } + Some(api::request::Command::UploadFile(_upload_file)) => { + log::trace!(" Upload file message"); + Err(std::io::Error::new( + std::io::ErrorKind::Unsupported, + "Not implemented", + )) + } + Some(api::request::Command::PutInput(_put_input)) => { + log::trace!(" Put input message"); + Err(std::io::Error::new( + std::io::ErrorKind::Unsupported, + "Not implemented", + )) + } + Some(api::request::Command::SyncFs(_sync_fs)) => { + log::trace!(" Sync fs message"); + Err(std::io::Error::new( + std::io::ErrorKind::Unsupported, + "Not implemented", + )) + } + _ => { + die!(" Unknown message type"); } } - - // match message_type { - // MessageType::Quit => { - // log::trace!(" Quit message"); - // handle_quit(&mut async_fd, msg_header.msg_id).await?; - // } - // MessageType::RunProcess => { - // log::trace!(" Run process message"); - // handle_run_process(&mut async_fd, msg_header.msg_id, processes).await?; - // } - // MessageType::KillProcess => { - // log::trace!(" Kill process message"); - // } - // MessageType::MountVolume => { - // log::trace!(" Mount volume message"); - // handle_mount(&mut async_fd, msg_header.msg_id).await?; - // } - // MessageType::UploadFile => { - // log::trace!(" Upload file message"); - // send_response_error( - // &mut async_fd, - // msg_header.msg_id, - // libc::EPROTONOSUPPORT as i32, - // ) - // .await?; - // } - // MessageType::QueryOutput => { - // log::trace!(" Query output message"); - // } - // MessageType::PutInput => { - // log::trace!(" Put input message"); - // } - // MessageType::SyncFs => { - // log::trace!(" Sync fs message"); - // } - // MessageType::NetCtl => { - // log::trace!(" Net control message"); - // } - // MessageType::NetHost => { - // log::trace!(" Net host message"); - // } - // _ => { - // die!(" Unknown message type"); - // } - // } - - Ok(()) } diff --git a/rinit/src/io.rs b/rinit/src/io.rs index 7c657ae..24059e4 100644 --- a/rinit/src/io.rs +++ b/rinit/src/io.rs @@ -1,47 +1,18 @@ use std::io::{Read, Write}; -use prost::Message; -use rinit_protos::rinit::api; use smol::Async; -use crate::{die, enums::Response, process::ExitReason, utils::FdWrapper}; - -#[derive(Debug)] -pub struct MessageHeader { - pub msg_id: u64, - pub msg_type: u8, -} - -impl MessageHeader { - pub fn from_ne_bytes(buf: &[u8]) -> Self { - Self { - msg_id: u64::from_le_bytes(buf[0..8].try_into().unwrap()), - msg_type: buf[8], - } - } - - pub fn to_ne_bytes(&self) -> [u8; 9] { - let mut buf = [0u8; 9]; - buf[0..8].copy_from_slice(&self.msg_id.to_le_bytes()); - buf[8] = self.msg_type; - buf - } -} +use crate::{die, utils::FdWrapper}; pub async fn async_read_n( async_fd: &mut Async, buf: &mut [u8], ) -> std::io::Result { + println!("async_read_n: {}", buf.len()); let mut total = 0; while total < buf.len() { - let n = unsafe { - async_fd.read_with_mut(|fd| { - let mut inner_buf = &mut buf[total..]; - fd.read(&mut inner_buf) - }) - } - .await?; + let n = unsafe { async_fd.read_with_mut(|fd| fd.read(&mut buf[total..])) }.await?; if n == 0 { log::info!("Waiting for host connection..."); std::thread::sleep(std::time::Duration::from_millis(1000)); @@ -70,74 +41,14 @@ pub async fn async_write_n(async_fd: &mut Async, buf: &[u8]) -> std:: Ok(total) } -pub async fn async_write_u8(async_fd: &mut Async, value: u8) -> std::io::Result { - let buf = [value]; - async_write_n(async_fd, &buf).await -} - -pub async fn async_write_u32( - async_fd: &mut Async, - value: u32, -) -> std::io::Result { - let buf = value.to_ne_bytes(); - async_write_n(async_fd, &buf).await -} - pub async fn async_write_u64( async_fd: &mut Async, value: u64, ) -> std::io::Result { - let buf = value.to_ne_bytes(); + let buf = value.to_le_bytes(); async_write_n(async_fd, &buf).await } -pub async fn async_recv_bytes(async_fd: &mut Async) -> std::io::Result> { - let size = async_recv_u64(async_fd).await?; - - let mut buf = vec![0u8; size as usize]; - async_read_n(async_fd, &mut buf[0..size as usize]).await?; - - Ok(buf) -} - -pub async fn async_recv_strings_array( - async_fd: &mut Async, -) -> std::io::Result> { - let size = async_recv_u64(async_fd).await?; - - let mut strings = Vec::with_capacity(size as usize); - - for _ in 0..size { - let string = async_recv_bytes(async_fd).await?; - let string = String::from_utf8(string).expect("Failed to convert bytes to string"); - strings.push(string); - } - - Ok(strings) -} - -pub async fn async_recv_u8(async_fd: &mut Async) -> std::io::Result { - let mut buf = [0u8; 1]; - let result = async_read_n(async_fd, &mut buf).await?; - - if result < 1 { - die!("Failed to read u8"); - } - - Ok(buf[0]) -} - -pub async fn async_recv_u32(async_fd: &mut Async) -> std::io::Result { - let mut buf = [0u8; 4]; - let result = async_read_n(async_fd, &mut buf).await?; - - if result < 4 { - die!("Failed to read u32"); - } - - Ok(u32::from_ne_bytes(buf)) -} - pub async fn async_recv_u64(async_fd: &mut Async) -> std::io::Result { let mut buf = [0u8; 8]; let result = async_read_n(async_fd, &mut buf).await?; @@ -148,93 +59,3 @@ pub async fn async_recv_u64(async_fd: &mut Async) -> std::io::Result< Ok(u64::from_be_bytes(buf)) } - -pub async fn async_send_i32(async_fd: &mut Async, value: i32) -> std::io::Result { - let buf = value.to_ne_bytes(); - let result = async_write_n(async_fd, &buf).await?; - - Ok(result) -} - -pub async fn async_send_u64(async_fd: &mut Async, value: u64) -> std::io::Result { - let buf = value.to_ne_bytes(); - let result = async_write_n(async_fd, &buf).await?; - - Ok(result) -} - -async fn send_response_header( - async_fd: &mut Async, - message_id: u64, - msg_type: Response, -) -> std::io::Result<()> { - let header = MessageHeader { - msg_id: message_id, - msg_type: msg_type as u8, - }; - - log::trace!( - " Sending response header: {:?} ({:?})", - header, - header.to_ne_bytes(), - ); - - async_write_n(async_fd, &header.to_ne_bytes()).await?; - - Ok(()) -} - -pub async fn send_response_u64( - async_fd: &mut Async, - message_id: u64, - value: u64, -) -> std::io::Result<()> { - send_response_header(async_fd, message_id, Response::OkU64).await?; - - async_send_u64(async_fd, value).await?; - - Ok(()) -} - -pub async fn async_send_response_ok( - async_fd: &mut Async, - message_id: u64, -) -> std::io::Result<()> { - send_response_header(async_fd, message_id, Response::Ok).await -} - -pub async fn send_response_error( - async_fd: &mut Async, - msg_id: u64, - err_type: i32, -) -> std::io::Result<()> { - send_response_header(async_fd, msg_id, Response::Error).await?; - - async_send_i32(async_fd, err_type).await?; - - Ok(()) -} - -pub async fn send_process_died( - async_fd: &mut Async, - proc_id: u64, - exit_reason: ExitReason, -) -> std::io::Result<()> { - send_response_header(async_fd, 0, Response::NotifyProcessDied).await?; - async_write_u64(async_fd, proc_id).await?; - async_write_u8(async_fd, exit_reason.status).await?; - async_write_u8(async_fd, exit_reason.reason_type).await?; - - Ok(()) -} - -pub async fn read_request(async_fd: &mut Async) -> std::io::Result { - let size = async_recv_u64(async_fd).await? as usize; - - let mut buf = vec![0; size as usize]; - async_read_n(async_fd, &mut buf).await?; - - let request = api::Request::decode(buf.as_slice())?; - - Ok(request) -} diff --git a/rinit/src/kernel_modules.rs b/rinit/src/kernel_modules.rs index f1ac121..b211234 100644 --- a/rinit/src/kernel_modules.rs +++ b/rinit/src/kernel_modules.rs @@ -24,7 +24,7 @@ fn load_module(module: &str) -> std::io::Result<()> { Ok(()) } -fn load_nvidia_modules() -> std::io::Result<()> { +fn load_nvidia_modules() -> std::io::Result { let nvidia_modules = [ "i2c-core.ko", "drm_panel_orientation_quirks.ko", @@ -50,10 +50,10 @@ fn load_nvidia_modules() -> std::io::Result<()> { load_module(module)?; } - Ok(()) + Ok(true) } -pub fn load_modules() -> std::io::Result<()> { +pub fn load_modules() -> std::io::Result { let modules = [ (false, "failover.ko"), (false, "virtio.ko"), @@ -95,9 +95,11 @@ pub fn load_modules() -> std::io::Result<()> { } } - if Path::new("/nvidia.ko").exists() { - load_nvidia_modules()?; - } + let nvidia_loaded = if Path::new("/nvidia.ko").exists() { + load_nvidia_modules()? + } else { + false + }; - Ok(()) + Ok(nvidia_loaded) } diff --git a/rinit/src/main.rs b/rinit/src/main.rs index bfda4d5..ef63160 100644 --- a/rinit/src/main.rs +++ b/rinit/src/main.rs @@ -6,14 +6,18 @@ use std::sync::{atomic::AtomicU32, Arc}; use async_io::Async; use futures::{future::FutureExt, pin_mut, select}; +use io::{async_read_n, async_recv_u64, async_write_n, async_write_u64}; use libc::{mode_t, prctl, PR_SET_DUMPABLE}; use nix::fcntl::{fcntl, FcntlArg, OFlag}; use nix::sys::signal::{self, sigprocmask, SigSet}; use nix::sys::signalfd::{SfdFlags, SignalFd}; +use nix::sys::stat::{mknod, Mode, SFlag}; +use prost::Message; +use rinit_protos::rinit::api; use fs::{ - chroot_to_new_root, create_directories, create_dirs, mount_core_filesystems, mount_overlay, - mount_sysroot, + chroot_to_new_root, create_directories, create_dirs, find_device_major, mount_core_filesystems, + mount_overlay, mount_sysroot, nvidia_gpu_count, write_sys, }; use handlers::{handle_message, handle_sigchld}; use initramfs::copy_initramfs; @@ -46,8 +50,6 @@ const OUTPUT_PATH_PREFIX: &str = "/var/tmp/guest_agent_private/fds"; const NONE: Option<&'static [u8]> = None; const VPORT_CMD: &str = "/dev/vport0p1"; -const VPORT_NET: &str = "/dev/vport0p2"; -const VPORT_INET: &str = "/dev/vport0p3"; const NET_MEM_DEFAULT: usize = 1048576; const NET_MEM_MAX: usize = 2097152; @@ -78,21 +80,35 @@ async fn try_main() -> std::io::Result<()> { chroot_to_new_root()?; create_directories()?; mount_core_filesystems()?; - load_modules()?; - + let nvidia_loaded = load_modules()?; let storage = scan_storage()?; mount_overlay(&storage)?; mount_sysroot()?; - // TODO(aljen): Handle 'sandbox' environment variable - // TODO(aljen): Handle 'nvidia_loaded' + let do_sandbox = if env::args().any(|arg| arg == "sandbox=yes") { + true + } else if env::args().any(|arg| arg == "sandbox=no") { + false + } else { + nvidia_loaded + }; + + if nvidia_loaded { + setup_nvidia(do_sandbox)?; + } setup_sandbox(); setup_network()?; setup_agent_directories()?; block_signals()?; - // setup_sigfd()?; + + if do_sandbox { + write_sys("/proc/sys/net/ipv4/ip_unprivileged_port_start", 0); + write_sys("/proc/sys/user/max_user_namespaces", 1); + get_namespace_fd(); + } + main_loop().await?; stop_network()?; @@ -124,34 +140,91 @@ async fn main_loop() -> std::io::Result<()> { let processes = Arc::new(Mutex::new(Vec::new())); loop { - let cmd_future = async { - let cmd_fd = async_cmds_fd.lock().await; - cmd_fd.readable().await - } - .fuse(); + let (response, request_id) = { + let cmd_future = async { + let cmd_fd = async_cmds_fd.lock().await; + cmd_fd.readable().await + } + .fuse(); - let sig_future = async { - let sig_fd = async_sig_fd.lock().await; - sig_fd.readable().await - } - .fuse(); + let sig_future = async { + let sig_fd = async_sig_fd.lock().await; + sig_fd.readable().await + } + .fuse(); + + pin_mut!(cmd_future, sig_future); - pin_mut!(cmd_future, sig_future); + let mut request_id = 0; - select! { - _ = cmd_future => { - if let Err(e) = handle_message(async_cmds_fd.clone(), processes.clone()).await { - log::error!("Error handling command message: {}", e); + let result = select! { + _ = cmd_future => { + handle_message(async_cmds_fd.clone(), processes.clone(), &mut request_id).await } - } - _ = sig_future => { - if let Err(e) = handle_sigchld(async_sig_fd.clone(), processes.clone()).await { - log::error!("Error handling command message: {}", e); + _ = sig_future => { + handle_sigchld(async_sig_fd.clone(), processes.clone(), &mut request_id).await } + }; + + (result, request_id) + }; + + let quit = match response { + Ok(Some(command)) => { + println!("Handling response command: {:?}", command); + + let quit = matches!(command, api::response::Command::Quit(_)); + + let response = api::Response { + request_id, + command: Some(command), + }; + + let mut buf = Vec::new(); + response.encode(&mut buf)?; + + log::info!("locking async_fd"); + let mut async_fd = async_cmds_fd.lock().await; + + log::info!("sending response message"); + async_write_u64(&mut async_fd, buf.len() as u64).await?; + async_write_n(&mut async_fd, &buf).await?; + + quit + } + Err(e) => { + log::error!("Error handling command message: {}", e); + + let response = api::Response { + request_id, + command: Some(api::response::Command::Error(api::Error { + code: e.raw_os_error().unwrap_or(libc::EIO) as u32, + message: e.to_string(), + })), + }; + + let mut buf = Vec::new(); + response.encode(&mut buf)?; + + log::info!("locking async_fd"); + let mut async_fd = async_cmds_fd.lock().await; + + log::info!("sending error response message"); + async_write_u64(&mut async_fd, buf.len() as u64).await?; + async_write_n(&mut async_fd, &buf).await?; + + false } + _ => false, + }; + + if quit { + break; } } + println!("Exiting main loop"); + Ok(()) } @@ -173,6 +246,10 @@ fn block_signals() -> std::io::Result<()> { Ok(()) } +fn get_namespace_fd() { + // TODO(aljen): Use C version of this function +} + fn setup_agent_directories() -> std::io::Result<()> { let sysroot = Path::new(SYSROOT); let dir = sysroot.join(&OUTPUT_PATH_PREFIX[1..]); @@ -183,6 +260,48 @@ fn setup_agent_directories() -> std::io::Result<()> { Ok(()) } +fn setup_nvidia(do_sandbox: bool) -> std::io::Result<()> { + if !do_sandbox { + log::error!("Sandboxing is disabled, refusing to enable Nvidia GPU passthrough."); + log::error!( + "Please re-run the container with sandboxing enabled or disable GPU passthrough.\n" + ); + + die!("Nvidia GPU passthrough requires sandboxing to be enabled."); + } + + let nvidia_major = find_device_major("nvidia-frontend")?; + let nvidia_count = nvidia_gpu_count(); + + for i in 0..nvidia_count { + let path = format!("/mnt/newroot/dev/nvidia{}", i); + + mknod( + Path::new(&path), + SFlag::S_IFCHR, + Mode::from_bits(0o666 & 0o777).unwrap(), + (nvidia_major << 8 | i) as u64, + )?; + } + + mknod( + Path::new("/mnt/newroot/dev/nvidiactl"), + SFlag::S_IFCHR, + Mode::from_bits(0o666 & 0o777).unwrap(), + (nvidia_major << 8 | 255) as u64, + )?; + + let nvidia_major = find_device_major("nvidia-uvm")?; + mknod( + Path::new("/mnt/newroot/dev/nvidia-uvm"), + SFlag::S_IFCHR, + Mode::from_bits(0o666 & 0o777).unwrap(), + (nvidia_major << 8) as u64, + )?; + + Ok(()) +} + fn setup_sandbox() { #[link(name = "seccomp")] extern "C" { diff --git a/rinit/src/network.rs b/rinit/src/network.rs index ad6955b..8912ab3 100644 --- a/rinit/src/network.rs +++ b/rinit/src/network.rs @@ -26,7 +26,7 @@ pub fn stop_network() -> std::io::Result<()> { Ok(()) } -pub fn add_network_hosts(entries: &[(&str, &str)]) -> std::io::Result<()> { +pub fn add_network_hosts>(entries: &[(S, S)]) -> std::io::Result<()> { let mut f = BufWriter::new( File::options() .append(true) @@ -34,7 +34,7 @@ pub fn add_network_hosts(entries: &[(&str, &str)]) -> std::io::Result<()> { ); for entry in entries.iter() { - match f.write_fmt(format_args!("{}\t{}\n", entry.0, entry.1)) { + match f.write_fmt(format_args!("{}\t{}\n", entry.0.as_ref(), entry.1.as_ref())) { Ok(_) => (), Err(e) => return Err(std::io::Error::new(std::io::ErrorKind::Other, e)), } @@ -144,7 +144,7 @@ unsafe fn net_if_alias(ifr: &mut ifreq, name: &str) -> nix::Result { } // Function to configure the network interface address and netmask -fn net_if_addr(name: &str, ip: &str, mask: &str) -> nix::Result { +pub fn net_if_addr(name: &str, ip: &str, mask: &str) -> nix::Result { log::info!( "Setting address {} and netmask {} for interface {}", ip, @@ -249,6 +249,153 @@ fn net_if_addr(name: &str, ip: &str, mask: &str) -> nix::Result { Ok(result) } +pub fn net_if_addr_to_hw_addr(ip: &str) -> [u8; 6] { + let ip = ipv4_to_u32(ip); + + let mut hw_addr = [0u8; 6]; + hw_addr[0] = 0x90; + hw_addr[1] = 0x13; + hw_addr[2] = (ip >> 24) as u8; + hw_addr[3] = (ip >> 16) as u8; + hw_addr[4] = (ip >> 8) as u8; + hw_addr[5] = ip as u8; + + hw_addr +} + +pub fn net_if_hw_addr(name: &str, mac: &[u8; 6]) -> nix::Result { + log::info!( + "Setting hardware address {:02X}:{:02X}:{:02X}:{:02X}:{:02X}:{:02X} for interface {}", + mac[0], + mac[1], + mac[2], + mac[3], + mac[4], + mac[5], + name, + ); + + // Open a socket + let fd = socket( + AddressFamily::Packet, + SockType::Raw, + SockFlag::empty(), + None, + )?; + + // Create an empty ifreq struct + let mut ifr: ifreq = unsafe { std::mem::zeroed() }; + + let c_name = CString::new(name).unwrap(); + + // Set the interface name + unsafe { + strncpy( + ifr.ifr_name.as_mut_ptr() as *mut c_char, + c_name.as_ptr(), + ifr.ifr_name.len() - 1, + ) + }; + + // Set up the sockaddr_in structure for the address + let sa: *mut sockaddr_in = unsafe { &mut ifr.ifr_ifru.ifru_addr as *mut _ as *mut sockaddr_in }; + + // Set the interface address + unsafe { + (*sa).sin_family = AF_INET as u16; + (*sa).sin_addr.s_addr = 0; + } + + // Set the hardware address + + let hw_addr: *mut libc::sockaddr = + unsafe { &mut ifr.ifr_ifru.ifru_hwaddr as *mut libc::sockaddr }; + unsafe { + (*hw_addr).sa_family = libc::ARPHRD_ETHER as u16; + for (i, byte) in mac.iter().enumerate() { + (*hw_addr).sa_data[i] = *byte as i8; + } + } + + // Apply the hardware address + let result = unsafe { + libc::ioctl( + fd.as_raw_fd(), + libc::SIOCSIFHWADDR.try_into().unwrap(), + &mut ifr, + ) + }; + + Ok(result) +} + +pub fn net_route( + name: &str, + ip: Option<&str>, + mask: Option<&str>, + via: &str, +) -> std::io::Result { + // Open a socket + let fd = socket( + AddressFamily::Inet, + SockType::Datagram, + SockFlag::empty(), + None, + )?; + + let mut rt: libc::rtentry = unsafe { std::mem::zeroed() }; + rt.rt_flags = libc::RTF_UP | libc::RTF_GATEWAY; + + let name_cstr = CString::new(name).unwrap(); + rt.rt_dev = name_cstr.as_ptr() as *mut c_char; + + let via_addr = via.parse::().unwrap(); + let via_sockaddr = sockaddr_in { + sin_family: AF_INET as u16, + sin_port: 0, + sin_addr: libc::in_addr { + s_addr: u32::from(via_addr).to_be(), + }, + sin_zero: [0; 8], + }; + rt.rt_gateway = unsafe { std::mem::transmute(via_sockaddr) }; + + let dst_addr = if let Some(ip) = ip { + ip.parse::().unwrap() + } else { + std::net::Ipv4Addr::new(0, 0, 0, 0) + }; + let dst_sockaddr = sockaddr_in { + sin_family: AF_INET as u16, + sin_port: 0, + sin_addr: libc::in_addr { + s_addr: u32::from(dst_addr).to_be(), + }, + sin_zero: [0; 8], + }; + rt.rt_dst = unsafe { std::mem::transmute(dst_sockaddr) }; + rt.rt_metric = if ip.is_some() { 101 } else { 0 }; + + let mask_addr = if let Some(mask) = mask { + mask.parse::().unwrap() + } else { + std::net::Ipv4Addr::new(0, 0, 0, 0) + }; + let mask_sockaddr = sockaddr_in { + sin_family: AF_INET as u16, + sin_port: 0, + sin_addr: libc::in_addr { + s_addr: u32::from(mask_addr).to_be(), + }, + sin_zero: [0; 8], + }; + rt.rt_genmask = unsafe { std::mem::transmute(mask_sockaddr) }; + + let ret = unsafe { libc::ioctl(fd.as_raw_fd(), libc::SIOCADDRT.try_into().unwrap(), &mut rt) }; + + Ok(ret) +} + pub fn setup_network() -> std::io::Result<()> { let hosts = [ ("127.0.0.1", "localhost"), diff --git a/runtime/init-container/Makefile b/runtime/init-container/Makefile index 3240e16..5650011 100644 --- a/runtime/init-container/Makefile +++ b/runtime/init-container/Makefile @@ -69,13 +69,13 @@ $(SRC_DIR)/seccomp.o: $(CURDIR)/$(LIBSECCOMP_SUBMODULE)/include/seccomp.h %.o: %.c $(QUIET_CC)$(CC) $(CFLAGS) -o $@ -c $< -init: $(UNPACKED_HEADERS) $(OBJECTS) $(OBJECTS_EXT) $(CURDIR)/$(LIBSECCOMP_SUBMODULE)/src/.libs/libseccomp.a +cinit: $(UNPACKED_HEADERS) $(OBJECTS) $(OBJECTS_EXT) $(CURDIR)/$(LIBSECCOMP_SUBMODULE)/src/.libs/libseccomp.a @echo cinit $(QUIET_CC)$(CC) $(CFLAGS) -static -o $@ $(wordlist 2, $(words $^), $^) @# default musl libs on some distros have debug symbols, lets strip them (and everything else) strip $@ -rinit: $(UNPACKED_HEADERS) ../../rinit/extern-libs/libseccomp.a +init: $(UNPACKED_HEADERS) ../../rinit/extern-libs/libseccomp.a @echo init cargo build --color always --target=x86_64-unknown-linux-musl --manifest-path ../../rinit/Cargo.toml -p rinit --release --bin rinit cp ../../target/x86_64-unknown-linux-musl/release/rinit init diff --git a/runtime/src/deploy.rs b/runtime/src/deploy.rs index 980f6a3..f3a8f0f 100644 --- a/runtime/src/deploy.rs +++ b/runtime/src/deploy.rs @@ -9,6 +9,7 @@ use tokio::io::{AsyncRead, AsyncReadExt, AsyncSeek, AsyncSeekExt}; use tokio_byteorder::LittleEndian; use uuid::Uuid; +// TODO(aljen): Cleanup this, copy structs use ya_client_model::activity::exe_script_command::VolumeMount; use ya_runtime_sdk::runtime_api::deploy::ContainerVolume; diff --git a/runtime/src/response_parser.rs b/runtime/src/response_parser.rs index 11598f6..a0ee5a8 100644 --- a/runtime/src/response_parser.rs +++ b/runtime/src/response_parser.rs @@ -1,3 +1,5 @@ +use prost::Message; +use rinit_protos::rinit::api; use std::convert::TryFrom; use std::io; use tokio::io::{AsyncRead, AsyncReadExt}; @@ -85,54 +87,45 @@ async fn recv_bytes(stream: &mut T) -> io::Result> pub async fn parse_one_response( stream: &mut T, ) -> io::Result { - let id = recv_u64(stream).await?; + let size = recv_u64(stream).await?; + println!("size: {}", size); - let typ = recv_u8(stream).await?; - match typ { - 0 => Ok(GuestAgentMessage::Response(ResponseWithId { - id, + let mut buf = vec![0u8; size as usize]; + stream.read_exact(buf.as_mut_slice()).await?; + + let response = api::Response::decode(buf.as_slice())?; + + println!("response: {:?}", response); + + match response.command { + Some(api::response::Command::Quit(_)) => Ok(GuestAgentMessage::Response(ResponseWithId { + id: response.request_id, resp: Response::Ok, })), - 1 => { - let val = recv_u64(stream).await?; - Ok(GuestAgentMessage::Response(ResponseWithId { - id, - resp: Response::OkU64(val), - })) - } - 2 => { - let buf = recv_bytes(stream).await?; + Some(api::response::Command::RunProcess(command)) => { Ok(GuestAgentMessage::Response(ResponseWithId { - id, - resp: Response::OkBytes(buf), + id: response.request_id, + resp: Response::OkU64(command.process_id), })) } - 3 => { - let code = recv_u32(stream).await?; + Some(api::response::Command::KillProcess(_command)) => todo!(), + Some(api::response::Command::MountVolume(_)) => { Ok(GuestAgentMessage::Response(ResponseWithId { - id, - resp: Response::Err(code), + id: response.request_id, + resp: Response::Ok, })) } - 4 => { - if id == 0 { - let proc_id = recv_u64(stream).await?; - let fd = recv_u32(stream).await?; - Ok(GuestAgentMessage::Notification( - Notification::OutputAvailable { id: proc_id, fd }, - )) - } else { - Err(io::Error::new( - io::ErrorKind::InvalidData, - "Invalid response message ID", - )) - } - } - 5 => { - if id == 0 { - let proc_id = recv_u64(stream).await?; - let status = recv_u8(stream).await?; - let type_ = ExitType::try_from(recv_u8(stream).await?)?; + Some(api::response::Command::UploadFile(_command)) => todo!(), + Some(api::response::Command::QueryOutput(_command)) => todo!(), + Some(api::response::Command::PutInput(_command)) => todo!(), + Some(api::response::Command::SyncFs(_command)) => todo!(), + Some(api::response::Command::NetCtl(_command)) => todo!(), + Some(api::response::Command::NetHost(_command)) => todo!(), + Some(api::response::Command::ProcessDied(command)) => { + if response.request_id == 0 { + let proc_id = command.pid; + let status = command.exit_status as u8; + let type_ = ExitType::try_from(command.reason_type as u8)?; Ok(GuestAgentMessage::Notification(Notification::ProcessDied { id: proc_id, reason: ExitReason { status, type_ }, @@ -144,9 +137,84 @@ pub async fn parse_one_response( )) } } - _ => Err(io::Error::new( + Some(api::response::Command::Error(command)) => { + Ok(GuestAgentMessage::Response(ResponseWithId { + id: response.request_id, + resp: Response::Err(command.code), + })) + } + None => Err(io::Error::new( io::ErrorKind::InvalidData, "Invalid response type", )), } + + // let typ = recv_u8(stream).await?; + // match typ { + // // RESP_OK + // 0 => Ok(GuestAgentMessage::Response(ResponseWithId { + // id, + // resp: Response::Ok, + // })), + // // RESP_OK_U64 + // 1 => { + // let val = recv_u64(stream).await?; + // Ok(GuestAgentMessage::Response(ResponseWithId { + // id, + // resp: Response::OkU64(val), + // })) + // } + // // RESP_OK_BYTES + // 2 => { + // let buf = recv_bytes(stream).await?; + // Ok(GuestAgentMessage::Response(ResponseWithId { + // id, + // resp: Response::OkBytes(buf), + // })) + // } + // // RESP_ERR + // 3 => { + // let code = recv_u32(stream).await?; + // Ok(GuestAgentMessage::Response(ResponseWithId { + // id, + // resp: Response::Err(code), + // })) + // } + // // RESP_NOTIFY_OUTPUT_AVAILABLE + // 4 => { + // if id == 0 { + // let proc_id = recv_u64(stream).await?; + // let fd = recv_u32(stream).await?; + // Ok(GuestAgentMessage::Notification( + // Notification::OutputAvailable { id: proc_id, fd }, + // )) + // } else { + // Err(io::Error::new( + // io::ErrorKind::InvalidData, + // "Invalid response message ID", + // )) + // } + // } + // // RESP_NOTIFY_PROCESS_DIED + // 5 => { + // if id == 0 { + // let proc_id = recv_u64(stream).await?; + // let status = recv_u8(stream).await?; + // let type_ = ExitType::try_from(recv_u8(stream).await?)?; + // Ok(GuestAgentMessage::Notification(Notification::ProcessDied { + // id: proc_id, + // reason: ExitReason { status, type_ }, + // })) + // } else { + // Err(io::Error::new( + // io::ErrorKind::InvalidData, + // "Invalid response message ID", + // )) + // } + // } + // _ => Err(io::Error::new( + // io::ErrorKind::InvalidData, + // "Invalid response type", + // )), + // } }