Skip to content

Commit

Permalink
LibCore: Support IPv6 for TCP and UDP connection
Browse files Browse the repository at this point in the history
  • Loading branch information
xlmnxp authored and ADKaster committed Jul 5, 2024
1 parent 47aee28 commit ab82fc8
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 16 deletions.
2 changes: 2 additions & 0 deletions AK/Forward.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class Error;
class FlyString;
class GenericLexer;
class IPv4Address;
class IPv6Address;
class JsonArray;
class JsonObject;
class JsonValue;
Expand Down Expand Up @@ -167,6 +168,7 @@ using AK::GenericLexer;
using AK::HashMap;
using AK::HashTable;
using AK::IPv4Address;
using AK::IPv6Address;
using AK::JsonArray;
using AK::JsonObject;
using AK::JsonValue;
Expand Down
42 changes: 34 additions & 8 deletions Userland/Libraries/LibCore/Socket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ ErrorOr<int> Socket::create_fd(SocketDomain domain, SocketType type)
case SocketDomain::Inet:
socket_domain = AF_INET;
break;
case SocketDomain::Inet6:
socket_domain = AF_INET6;
break;
case SocketDomain::Local:
socket_domain = AF_LOCAL;
break;
Expand Down Expand Up @@ -48,7 +51,7 @@ ErrorOr<int> Socket::create_fd(SocketDomain domain, SocketType type)
#endif
}

ErrorOr<IPv4Address> Socket::resolve_host(ByteString const& host, SocketType type)
ErrorOr<Variant<IPv4Address, IPv6Address>> Socket::resolve_host(ByteString const& host, SocketType type)
{
int socket_type;
switch (type) {
Expand All @@ -71,14 +74,21 @@ ErrorOr<IPv4Address> Socket::resolve_host(ByteString const& host, SocketType typ
auto const results = TRY(Core::System::getaddrinfo(host.characters(), nullptr, hints));

for (auto const& result : results.addresses()) {
if (result.ai_family == AF_INET6) {
auto* socket_address = bit_cast<struct sockaddr_in6*>(result.ai_addr);
auto address = IPv6Address { socket_address->sin6_addr.s6_addr };

return address;
}

if (result.ai_family == AF_INET) {
auto* socket_address = bit_cast<struct sockaddr_in*>(result.ai_addr);
NetworkOrdered<u32> const network_ordered_address { socket_address->sin_addr.s_addr };
return IPv4Address { network_ordered_address };
}
}

return Error::from_string_literal("Could not resolve to IPv4 address");
return Error::from_string_literal("Could not resolve to IPv4 or IPv6 address");
}

ErrorOr<void> Socket::connect_local(int fd, ByteString const& path)
Expand All @@ -96,8 +106,13 @@ ErrorOr<void> Socket::connect_local(int fd, ByteString const& path)

ErrorOr<void> Socket::connect_inet(int fd, SocketAddress const& address)
{
auto addr = address.to_sockaddr_in();
return System::connect(fd, bit_cast<struct sockaddr*>(&addr), sizeof(addr));
if (address.type() == SocketAddress::Type::IPv6) {
auto addr = address.to_sockaddr_in6();
return System::connect(fd, bit_cast<struct sockaddr*>(&addr), sizeof(addr));
} else {
auto addr = address.to_sockaddr_in();
return System::connect(fd, bit_cast<struct sockaddr*>(&addr), sizeof(addr));
}
}

ErrorOr<Bytes> PosixSocketHelper::read(Bytes buffer, int flags)
Expand Down Expand Up @@ -200,14 +215,19 @@ void PosixSocketHelper::setup_notifier()
ErrorOr<NonnullOwnPtr<TCPSocket>> TCPSocket::connect(ByteString const& host, u16 port)
{
auto ip_address = TRY(resolve_host(host, SocketType::Stream));
return connect(SocketAddress { ip_address, port });

return ip_address.visit([port](auto address) { return connect(SocketAddress { address, port }); });
}

ErrorOr<NonnullOwnPtr<TCPSocket>> TCPSocket::connect(SocketAddress const& address)
{
auto socket = TRY(adopt_nonnull_own_or_enomem(new (nothrow) TCPSocket()));

auto fd = TRY(create_fd(SocketDomain::Inet, SocketType::Stream));
auto socket_domain = SocketDomain::Inet6;
if (address.type() == SocketAddress::Type::IPv4)
socket_domain = SocketDomain::Inet;

auto fd = TRY(create_fd(socket_domain, SocketType::Stream));
socket->m_helper.set_fd(fd);

TRY(connect_inet(fd, address));
Expand Down Expand Up @@ -242,14 +262,19 @@ ErrorOr<size_t> PosixSocketHelper::pending_bytes() const
ErrorOr<NonnullOwnPtr<UDPSocket>> UDPSocket::connect(ByteString const& host, u16 port, Optional<Duration> timeout)
{
auto ip_address = TRY(resolve_host(host, SocketType::Datagram));
return connect(SocketAddress { ip_address, port }, timeout);

return ip_address.visit([port, timeout](auto address) { return connect(SocketAddress { address, port }, timeout); });
}

ErrorOr<NonnullOwnPtr<UDPSocket>> UDPSocket::connect(SocketAddress const& address, Optional<Duration> timeout)
{
auto socket = TRY(adopt_nonnull_own_or_enomem(new (nothrow) UDPSocket()));

auto fd = TRY(create_fd(SocketDomain::Inet, SocketType::Datagram));
auto socket_domain = SocketDomain::Inet6;
if (address.type() == SocketAddress::Type::IPv4)
socket_domain = SocketDomain::Inet;

auto fd = TRY(create_fd(socket_domain, SocketType::Datagram));
socket->m_helper.set_fd(fd);
if (timeout.has_value()) {
TRY(socket->m_helper.set_receive_timeout(timeout.value()));
Expand All @@ -258,6 +283,7 @@ ErrorOr<NonnullOwnPtr<UDPSocket>> UDPSocket::connect(SocketAddress const& addres
TRY(connect_inet(fd, address));

socket->setup_notifier();

return socket;
}

Expand Down
3 changes: 2 additions & 1 deletion Userland/Libraries/LibCore/Socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,15 @@ class Socket : public Stream {

// FIXME: This will need to be updated when IPv6 socket arrives. Perhaps a
// base class for all address types is appropriate.
static ErrorOr<IPv4Address> resolve_host(ByteString const&, SocketType);
static ErrorOr<Variant<IPv4Address, IPv6Address>> resolve_host(ByteString const&, SocketType);

Function<void()> on_ready_to_read;

protected:
enum class SocketDomain {
Local,
Inet,
Inet6,
};

explicit Socket(PreventSIGPIPE prevent_sigpipe = PreventSIGPIPE::Yes)
Expand Down
47 changes: 40 additions & 7 deletions Userland/Libraries/LibCore/SocketAddress.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#pragma once

#include <AK/IPv4Address.h>
#include <AK/IPv6Address.h>
#include <arpa/inet.h>
#include <netinet/in.h>
#include <string.h>
Expand All @@ -21,19 +22,33 @@ class SocketAddress {
enum class Type {
Invalid,
IPv4,
IPv6,
Local
};

SocketAddress() = default;
SocketAddress(IPv4Address const& address)
: m_type(Type::IPv4)
, m_ipv4_address(address)
, m_ip_address { address }
{
}

SocketAddress(IPv6Address const& address)
: m_type(Type::IPv6)
, m_ip_address { address }
{
}

SocketAddress(IPv4Address const& address, u16 port)
: m_type(Type::IPv4)
, m_ipv4_address(address)
, m_ip_address { address }
, m_port(port)
{
}

SocketAddress(IPv6Address const& address, u16 port)
: m_type(Type::IPv6)
, m_ip_address { address }
, m_port(port)
{
}
Expand All @@ -48,14 +63,18 @@ class SocketAddress {

Type type() const { return m_type; }
bool is_valid() const { return m_type != Type::Invalid; }
IPv4Address ipv4_address() const { return m_ipv4_address; }

IPv4Address ipv4_address() const { return m_ip_address.get<IPv4Address>(); }
IPv6Address ipv6_address() const { return m_ip_address.get<IPv6Address>(); }
u16 port() const { return m_port; }

ByteString to_byte_string() const
{
switch (m_type) {
case Type::IPv4:
return ByteString::formatted("{}:{}", m_ipv4_address, m_port);
return ByteString::formatted("{}:{}", m_ip_address.get<IPv4Address>(), m_port);
case Type::IPv6:
return ByteString::formatted("[{}]:{}", m_ip_address.get<IPv6Address>(), m_port);
case Type::Local:
return m_local_address;
default:
Expand All @@ -74,13 +93,25 @@ class SocketAddress {
return address;
}

sockaddr_in6 to_sockaddr_in6() const
{
VERIFY(type() == Type::IPv6);
sockaddr_in6 address {};
memset(&address, 0, sizeof(address));
address.sin6_family = AF_INET6;
address.sin6_port = htons(port());
auto ipv6_addr = ipv6_address();
memcpy(&address.sin6_addr, &ipv6_addr.to_in6_addr_t(), sizeof(address.sin6_addr));
return address;
}

sockaddr_in to_sockaddr_in() const
{
VERIFY(type() == Type::IPv4);
sockaddr_in address {};
address.sin_family = AF_INET;
address.sin_addr.s_addr = m_ipv4_address.to_in_addr_t();
address.sin_port = htons(m_port);
address.sin_port = htons(port());
address.sin_addr.s_addr = ipv4_address().to_in_addr_t();
return address;
}

Expand All @@ -89,7 +120,9 @@ class SocketAddress {

private:
Type m_type { Type::Invalid };
IPv4Address m_ipv4_address;

Variant<IPv4Address, IPv6Address> m_ip_address = IPv4Address();

u16 m_port { 0 };
ByteString m_local_address;
};
Expand Down

0 comments on commit ab82fc8

Please sign in to comment.