From 73b1229469cb1353d4cacfa1e410cc1ae8a30ae2 Mon Sep 17 00:00:00 2001 From: Aaron <10217842+byteduck@users.noreply.github.com> Date: Fri, 15 Mar 2024 00:51:09 -0700 Subject: [PATCH] Kernel: Support for sending UDP packets This kind of works. For some reason it doesn't work the first time, and works most of the time subsequent times. There might be some sort of memory corruption bug because sometimes we RSOD. Too tired to debug but this seems like a good stopping point --- kernel/api/ipv4.h | 24 +++++++++++++ kernel/api/net.h | 7 ++++ kernel/net/IPSocket.cpp | 66 +++++++++++++++++++++++++++++++---- kernel/net/IPSocket.h | 18 +++++++--- kernel/net/NetworkAdapter.cpp | 49 +++++++++++++++++++++++++- kernel/net/NetworkAdapter.h | 5 ++- kernel/net/NetworkManager.cpp | 2 +- kernel/net/Router.cpp | 2 +- kernel/net/Socket.h | 1 + kernel/net/UDPSocket.cpp | 48 ++++++++++++++++++------- kernel/net/UDPSocket.h | 3 +- kernel/syscall/socket.cpp | 20 +++++++++-- 12 files changed, 216 insertions(+), 29 deletions(-) diff --git a/kernel/api/ipv4.h b/kernel/api/ipv4.h index becf6343..cda4d62f 100644 --- a/kernel/api/ipv4.h +++ b/kernel/api/ipv4.h @@ -67,6 +67,30 @@ struct __attribute__((packed)) IPv4Packet { IPv4Address source_addr; IPv4Address dest_addr; uint8_t payload[]; + + [[nodiscard]] inline BigEndian compute_checksum() const { return __compute_checksum(this); } + + inline void set_checksum() { + checksum = 0; + checksum = compute_checksum(); + } + +private: + // Necessary to beat the alignment allegations. Scary, I know + [[nodiscard]] inline BigEndian __compute_checksum(const void* voidptr) const { + uint32_t sum = 0; + auto* ptr = (const uint16_t*) voidptr; + size_t count = sizeof(IPv4Packet); + while (count > 1) { + sum += as_big_endian(*ptr++); + if (sum & 0x80000000) + sum = (sum & 0xffff) | (sum >> 16); + count -= 2; + } + while (sum >> 16) + sum = (sum & 0xffff) + (sum >> 16); + return ~sum & 0xffff; + } }; static_assert(sizeof(IPv4Packet) == 20); diff --git a/kernel/api/net.h b/kernel/api/net.h index 66f3ce3a..04451be0 100644 --- a/kernel/api/net.h +++ b/kernel/api/net.h @@ -35,6 +35,13 @@ class __attribute__((packed)) MACAddress { return false; } + inline constexpr operator bool() const { + for (auto& val : m_data) + if (val) + return true; + return false; + } + private: uint8_t m_data[6] = {0}; }; diff --git a/kernel/net/IPSocket.cpp b/kernel/net/IPSocket.cpp index b62edea2..e6fc4773 100644 --- a/kernel/net/IPSocket.cpp +++ b/kernel/net/IPSocket.cpp @@ -22,6 +22,8 @@ ResultRet> IPSocket::make(Socket::Type type, int protocol) { } Result IPSocket::bind(SafePointer addr_ptr, socklen_t addrlen) { + LOCK(m_lock); + if (m_bound || addrlen != sizeof(sockaddr_in)) return Result(set_error(EINVAL)); @@ -29,8 +31,8 @@ Result IPSocket::bind(SafePointer addr_ptr, socklen_t addrlen) { if (addr.sin_family != AF_INET) return Result(set_error(EINVAL)); - m_port = from_big_endian(addr.sin_port); - m_addr = IPv4Address(from_big_endian(addr.sin_addr.s_addr)); + m_bound_port = from_big_endian(addr.sin_port); + m_bound_addr = IPv4Address(from_big_endian(addr.sin_addr.s_addr)); return do_bind(); } @@ -38,16 +40,22 @@ Result IPSocket::bind(SafePointer addr_ptr, socklen_t addrlen) { ssize_t IPSocket::recvfrom(FileDescriptor& fd, SafePointer buf, size_t len, int flags, SafePointer src_addr, SafePointer addrlen) { m_receive_queue_lock.acquire(); + // Verify addrlen ptr + if (addrlen && addrlen.get() != sizeof(sockaddr_in)) + return -set_error(EINVAL); + // Block until we have a packet to read while (m_receive_queue.empty()) { if (fd.nonblock()) { m_receive_queue_lock.release(); - return -EAGAIN; + return -set_error(EAGAIN); } update_blocker(); m_receive_queue_lock.release(); TaskManager::current_thread()->block(m_receive_blocker); + if (m_receive_blocker.was_interrupted()) + return -set_error(EINTR); m_receive_queue_lock.acquire(); } @@ -56,21 +64,67 @@ ssize_t IPSocket::recvfrom(FileDescriptor& fd, SafePointer buf, size_t update_blocker(); m_receive_queue_lock.release(); auto res = do_recv(packet, buf, len); + + // Write out addr + if (src_addr && addrlen) { + src_addr.as().set({ + AF_INET, + as_big_endian(packet->port), + packet->packet.source_addr.val() + }); + addrlen.set(sizeof(sockaddr_in)); + } + kfree(packet); return res; } +ssize_t IPSocket::sendto(FileDescriptor& fd, SafePointer buf, size_t len, int flags, SafePointer dest_addr, socklen_t addrlen) { + LOCK(m_lock); + if (dest_addr) { + if (addrlen != sizeof(sockaddr_in)) + return -set_error(EINVAL); + + auto addr = dest_addr.as().get(); + if (addr.sin_family != AF_INET) + return -set_error(EAFNOSUPPORT); + + if (m_type != Stream) { + m_dest_addr = addr.sin_addr.s_addr; + m_dest_port = from_big_endian(addr.sin_port); + } else { + // TODO: TCP. We want to use connect() for that + } + } + + if (!m_bound) { + // If we're not bound, bind to 0.0.0.0:0 + m_bound_port = 0; + m_bound_addr = {}; + auto res = do_bind(); + if (res.is_error()) + return -res.code(); + } + + // TODO: Adapter binding? + + auto send_res = do_send(buf, len); + if (send_res.is_error()) + return -send_res.code(); + return (ssize_t) send_res.value(); +} + Result IPSocket::recv_packet(const void* buf, size_t len) { LOCK(m_receive_queue_lock); if (m_receive_queue.size() == m_receive_queue.capacity()) { KLog::warn("IPSocket", "Dropping packet because receive queue is full"); - return Result(ENOSPC); + return Result(set_error(ENOSPC)); } auto* src_pkt = (const IPv4Packet*) buf; - auto* new_pkt = (IPv4Packet*) kmalloc(len); - memcpy(new_pkt, src_pkt, len); + auto* new_pkt = new RecvdPacket; + memcpy(&new_pkt->packet, src_pkt, len); m_receive_queue.push_back(new_pkt); update_blocker(); diff --git a/kernel/net/IPSocket.h b/kernel/net/IPSocket.h index a2fa2bcc..6d49489c 100644 --- a/kernel/net/IPSocket.h +++ b/kernel/net/IPSocket.h @@ -14,6 +14,7 @@ class IPSocket: public Socket { // Socket Result bind(SafePointer addr, socklen_t addrlen) override; ssize_t recvfrom(FileDescriptor &fd, SafePointer buf, size_t len, int flags, SafePointer src_addr, SafePointer addrlen) override; + ssize_t sendto(FileDescriptor &fd, SafePointer buf, size_t len, int flags, SafePointer dest_addr, socklen_t addrlen) override; Result recv_packet(const void* buf, size_t len) override; // File @@ -22,15 +23,24 @@ class IPSocket: public Socket { protected: IPSocket(Socket::Type type, int protocol); - virtual ssize_t do_recv(const IPv4Packet* pkt, SafePointer buf, size_t len) = 0; + struct RecvdPacket { + uint16_t port; + IPv4Packet packet; // Not actually set until we do do_recv + }; + + virtual ssize_t do_recv(RecvdPacket* pkt, SafePointer buf, size_t len) = 0; virtual Result do_bind() = 0; + virtual ResultRet do_send(SafePointer buf, size_t len) = 0; void update_blocker(); bool m_bound = false; - uint16_t m_port; - IPv4Address m_addr; - kstd::circular_queue m_receive_queue { 16 }; + uint16_t m_bound_port, m_dest_port; + IPv4Address m_bound_addr, m_dest_addr; + kstd::circular_queue m_receive_queue { 16 }; Mutex m_receive_queue_lock { "IPSocket::receive_queue" }; + Mutex m_lock { "IPSocket::lock" }; BooleanBlocker m_receive_blocker; + uint8_t m_type_of_service = 0; + uint8_t m_ttl = 64; }; diff --git a/kernel/net/NetworkAdapter.cpp b/kernel/net/NetworkAdapter.cpp index b2f0c2be..121df794 100644 --- a/kernel/net/NetworkAdapter.cpp +++ b/kernel/net/NetworkAdapter.cpp @@ -55,7 +55,7 @@ void NetworkAdapter::receive_bytes(SafePointer bytes, size_t count) { int i; for (i = 0; i < 32; i++) { - if (!m_packets[i].used) + if (!m_packets[i].used.load(MemoryOrder::Acquire)) break; } if (i == 32) { @@ -102,3 +102,50 @@ NetworkAdapter::Packet* NetworkAdapter::dequeue_packet() { m_packet_queue = m_packet_queue->next; return pkt; } + +NetworkAdapter::Packet* NetworkAdapter::alloc_packet(size_t size) { + ASSERT(size < 8192); + TaskManager::ScopedCritical crit; + int i; + for (i = 0; i < 32; i++) { + if (!m_packets[i].used.load(MemoryOrder::Acquire)) + break; + } + + if (i == 32) + return nullptr; + + auto& pkt = m_packets[i]; + pkt.size = sizeof(FrameHeader) + size; + return &m_packets[i]; +} + +IPv4Packet* NetworkAdapter::setup_ipv4_packet(Packet* packet, const MACAddress& dest, const IPv4Address& dest_addr, IPv4Proto proto, size_t payload_size, uint8_t dscp, uint8_t ttl) { + ASSERT(packet); + + auto* frame = (FrameHeader*) packet->buffer; + frame->type = EtherProto::IPv4; + frame->destination = dest; + frame->source = m_mac_addr; + + auto* ipv4 = (IPv4Packet*) (packet->buffer + sizeof(FrameHeader)); + ipv4->source_addr = m_ipv4_addr; + ipv4->dest_addr = dest_addr; + ipv4->length = payload_size + sizeof(IPv4Packet); + ipv4->dscp_ecn = dscp; + ipv4->ttl = ttl; + ipv4->proto = proto; + ipv4->identification = 1; + ipv4->version_ihl = (4 << 4) | 5; + ipv4->identification = 1; + ipv4->set_checksum(); + + return ipv4; +} + +void NetworkAdapter::send_packet(NetworkAdapter::Packet* packet) { + ASSERT(packet->size < 8192); + send_raw_packet(KernelPointer(packet->buffer), packet->size); + packet->used.store(false, MemoryOrder::Release); +} + diff --git a/kernel/net/NetworkAdapter.h b/kernel/net/NetworkAdapter.h index 21b578b9..c5f6f3ac 100644 --- a/kernel/net/NetworkAdapter.h +++ b/kernel/net/NetworkAdapter.h @@ -18,7 +18,7 @@ class NetworkAdapter: public kstd::ArcSelf { uint8_t buffer[8192]; /* TODO: We need non-constant packet sizes... */ union { size_t size; - bool used = false; + Atomic used = false; }; Packet* next = nullptr; }; @@ -32,7 +32,10 @@ class NetworkAdapter: public kstd::ArcSelf { void send_arp_packet(MACAddress dest, const ARPPacket& packet); void send_raw_packet(SafePointer bytes, size_t count); + void send_packet(Packet* packet); Packet* dequeue_packet(); + Packet* alloc_packet(size_t size); + IPv4Packet* setup_ipv4_packet(Packet* packet, const MACAddress& dest, const IPv4Address& dest_addr, IPv4Proto proto, size_t payload_size, uint8_t dscp, uint8_t ttl); [[nodiscard]] IPv4Address ipv4_address() const { return m_ipv4_addr; } [[nodiscard]] MACAddress mac_address() const { return m_mac_addr; } diff --git a/kernel/net/NetworkManager.cpp b/kernel/net/NetworkManager.cpp index fb8e422d..df222e19 100644 --- a/kernel/net/NetworkManager.cpp +++ b/kernel/net/NetworkManager.cpp @@ -31,7 +31,7 @@ void NetworkManager::do_task() { for (auto& iface : NetworkAdapter::interfaces()) { while ((packet = iface->dequeue_packet())) { handle_packet(iface, packet); - packet->used = false; + packet->used.store(false, MemoryOrder::Release); } } } diff --git a/kernel/net/Router.cpp b/kernel/net/Router.cpp index a8b56d3e..b9454299 100644 --- a/kernel/net/Router.cpp +++ b/kernel/net/Router.cpp @@ -102,7 +102,7 @@ Router::Route Router::get_route(const IPv4Address& dest, const IPv4Address& sour } // ARP lookup - KLog::dbg_if("Router", "Could not find route to {}, sending ARP request thru {} for {}", dest, adapter->name(), next_hop); + KLog::dbg_if("Router", "Could not find route to {}, looking up ARP entry for {}", dest, next_hop); auto mac = arp_lookup(next_hop, adapter); if (mac.is_error()) return {{}, {}}; diff --git a/kernel/net/Socket.h b/kernel/net/Socket.h index eea38b9f..f08efaec 100644 --- a/kernel/net/Socket.h +++ b/kernel/net/Socket.h @@ -26,6 +26,7 @@ class Socket: public File { // Socket virtual Result bind(SafePointer addr, socklen_t addrlen) = 0; virtual ssize_t recvfrom(FileDescriptor& fd, SafePointer buf, size_t len, int flags, SafePointer src_addr, SafePointer addrlen) = 0; + virtual ssize_t sendto(FileDescriptor& fd, SafePointer buf, size_t len, int flags, SafePointer dest_addr, socklen_t addrlen) = 0; virtual Result recv_packet(const void* buf, size_t len) = 0; [[nodiscard]] int error() const { return m_error; } diff --git a/kernel/net/UDPSocket.cpp b/kernel/net/UDPSocket.cpp index dd08499b..a8328480 100644 --- a/kernel/net/UDPSocket.cpp +++ b/kernel/net/UDPSocket.cpp @@ -4,6 +4,7 @@ #include "UDPSocket.h" #include "../kstd/KLog.h" #include "../api/udp.h" +#include "Router.h" #define UDP_DBG 1 @@ -16,9 +17,9 @@ UDPSocket::UDPSocket(): IPSocket(Type::Dgram, 0) { UDPSocket::~UDPSocket() { LOCK(s_sockets_lock); - if (m_bound && s_sockets.contains(m_port)) { - s_sockets.erase(m_port); - KLog::dbg_if("UDPSocket", "Unbinding from port {}", m_port); + if (m_bound && s_sockets.contains(m_bound_port)) { + s_sockets.erase(m_bound_port); + KLog::dbg_if("UDPSocket", "Unbinding from port {}", m_bound_port); } } @@ -40,12 +41,10 @@ Result UDPSocket::do_bind() { LOCK(s_sockets_lock); if (m_bound) return Result(set_error(EINVAL)); - if (s_sockets.contains(m_port)) + if (s_sockets.contains(m_bound_port)) return Result(set_error(EADDRINUSE)); - KLog::dbg_if("UDPSocket", "Binding to port {}", m_port); - - if (m_port == 0) { + if (m_bound_port == 0) { // If we didn't specify a port, we want an ephemeral port // (Range suggested by IANA and RFC 6335) uint16_t ephem; @@ -59,22 +58,47 @@ Result UDPSocket::do_bind() { return Result(set_error(EADDRINUSE)); } - m_port = ephem; + m_bound_port = ephem; } - s_sockets[m_port] = self(); + KLog::dbg_if("UDPSocket", "Binding to port {}", m_bound_port); + + s_sockets[m_bound_port] = self(); m_bound = true; return Result(SUCCESS); } -ssize_t UDPSocket::do_recv(const IPv4Packet* pkt, SafePointer buf, size_t len) { - auto* udp_pkt = (const UDPPacket*) pkt->payload; - ASSERT(pkt->length >= sizeof(IPv4Packet) + sizeof(UDPPacket)); // Should've been rejected at IP layer +ssize_t UDPSocket::do_recv(RecvdPacket* pkt, SafePointer buf, size_t len) { + auto* udp_pkt = (const UDPPacket*) pkt->packet.payload; + ASSERT(pkt->packet.length >= sizeof(IPv4Packet) + sizeof(UDPPacket)); // Should've been rejected at IP layer ASSERT(udp_pkt->len >= sizeof(UDPPacket)); // Should've been rejected in NetworkManager const size_t nread = min(len, udp_pkt->len.val() - sizeof(UDPPacket)); buf.write(udp_pkt->payload, nread); + KLog::dbg_if("UDPSocket", "Received packet from {}:{} ({} bytes)", pkt->packet.source_addr, udp_pkt->source_port, nread); + + pkt->port = udp_pkt->source_port; + return (ssize_t) nread; } + +ResultRet UDPSocket::do_send(SafePointer buf, size_t len) { + auto route = Router::get_route(m_dest_addr, {}, {}); + if (!route.mac || !route.adapter) + return Result(set_error(EHOSTUNREACH)); + + const size_t packet_len = sizeof(IPv4Packet) + sizeof(UDPPacket) + len; + auto pkt = route.adapter->alloc_packet(packet_len); + auto* ipv4_packet = route.adapter->setup_ipv4_packet(pkt, route.mac, m_dest_addr, UDP, sizeof(UDPPacket) + len, m_type_of_service, m_ttl); + auto* udp_packet = (UDPPacket*) ipv4_packet->payload; + udp_packet->source_port = m_bound_port; + udp_packet->dest_port = m_dest_port; + udp_packet->len = sizeof(UDPPacket) + len; + buf.read(udp_packet->payload, len); + + KLog::dbg_if("UDPSocket", "Sending packet to {}:{} ({} bytes)", m_dest_addr, m_dest_port, len); + route.adapter->send_packet(pkt); + return len; +} diff --git a/kernel/net/UDPSocket.h b/kernel/net/UDPSocket.h index 7e602967..a4ec2a1f 100644 --- a/kernel/net/UDPSocket.h +++ b/kernel/net/UDPSocket.h @@ -17,7 +17,8 @@ class UDPSocket: public IPSocket, public kstd::ArcSelf { UDPSocket(); Result do_bind() override; - ssize_t do_recv(const IPv4Packet *pkt, SafePointer buf, size_t len) override; + ssize_t do_recv(RecvdPacket* pkt, SafePointer buf, size_t len) override; + ResultRet do_send(SafePointer buf, size_t len) override; static kstd::map> s_sockets; static Mutex s_sockets_lock; diff --git a/kernel/syscall/socket.cpp b/kernel/syscall/socket.cpp index 06568b7c..fe52ff1b 100644 --- a/kernel/syscall/socket.cpp +++ b/kernel/syscall/socket.cpp @@ -62,6 +62,22 @@ int Process::sys_recvmsg(int sockfd, UserspacePointer msg_ptr, in return socket->recvfrom(*desc, buf, iov.iov_len, flags, addr_ptr, addrlen_ptr); } -int Process::sys_sendmsg(int sockfd, UserspacePointer msg, int flags) { - return -1; +int Process::sys_sendmsg(int sockfd, UserspacePointer msg_ptr, int flags) { + get_socket(sockfd); + + auto msg = msg_ptr.get(); + + // TODO: More than one entry in iovec + if (msg.msg_iovlen != 1) + return -EINVAL; + + auto iov = UserspacePointer(msg.msg_iov).get(); + auto addr_ptr = UserspacePointer((sockaddr*) msg.msg_name); + auto addrlen_ptr = UserspacePointer(&msg_ptr.raw()->msg_namelen); + auto buf = UserspacePointer((uint8_t*) iov.iov_base); + + // TODO: Control messages + UserspacePointer(&msg_ptr.raw()->msg_controllen).set(0); + + return socket->sendto(*desc, buf, iov.iov_len, flags, addr_ptr, addrlen_ptr.get()); } \ No newline at end of file