Skip to content

Commit

Permalink
Async DNS resolution on Windows (#14979)
Browse files Browse the repository at this point in the history
Uses `GetAddrInfoExW` with IOCP
  • Loading branch information
HertzDevil authored Sep 13, 2024
1 parent 8b1a391 commit dea39ad
Show file tree
Hide file tree
Showing 12 changed files with 233 additions and 31 deletions.
32 changes: 28 additions & 4 deletions spec/std/socket/addrinfo_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,20 @@ describe Socket::Addrinfo, tags: "network" do
end
end
end

it "raises helpful message on getaddrinfo failure" do
expect_raises(Socket::Addrinfo::Error, "Hostname lookup for badhostname failed: ") do
Socket::Addrinfo.resolve("badhostname", 80, type: Socket::Type::DGRAM)
end
end

{% if flag?(:win32) %}
it "raises timeout error" do
expect_raises(IO::TimeoutError) do
Socket::Addrinfo.resolve("badhostname", 80, type: Socket::Type::STREAM, timeout: 0.milliseconds)
end
end
{% end %}
end

describe ".tcp" do
Expand All @@ -37,11 +51,13 @@ describe Socket::Addrinfo, tags: "network" do
end
end

it "raises helpful message on getaddrinfo failure" do
expect_raises(Socket::Addrinfo::Error, "Hostname lookup for badhostname failed: ") do
Socket::Addrinfo.resolve("badhostname", 80, type: Socket::Type::DGRAM)
{% if flag?(:win32) %}
it "raises timeout error" do
expect_raises(IO::TimeoutError) do
Socket::Addrinfo.tcp("badhostname", 80, timeout: 0.milliseconds)
end
end
end
{% end %}
end

describe ".udp" do
Expand All @@ -56,6 +72,14 @@ describe Socket::Addrinfo, tags: "network" do
typeof(addrinfo).should eq(Socket::Addrinfo)
end
end

{% if flag?(:win32) %}
it "raises timeout error" do
expect_raises(IO::TimeoutError) do
Socket::Addrinfo.udp("badhostname", 80, timeout: 0.milliseconds)
end
end
{% end %}
end

describe "#ip_address" do
Expand Down
6 changes: 5 additions & 1 deletion src/crystal/system/addrinfo.cr
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,11 @@ end
{% elsif flag?(:unix) %}
require "./unix/addrinfo"
{% elsif flag?(:win32) %}
require "./win32/addrinfo"
{% if flag?(:win7) %}
require "./win32/addrinfo_win7"
{% else %}
require "./win32/addrinfo"
{% end %}
{% else %}
{% raise "No Crystal::System::Addrinfo implementation available" %}
{% end %}
43 changes: 35 additions & 8 deletions src/crystal/system/win32/addrinfo.cr
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
module Crystal::System::Addrinfo
alias Handle = LibC::Addrinfo*
alias Handle = LibC::ADDRINFOEXW*

@addr : LibC::SockaddrIn6

Expand Down Expand Up @@ -30,7 +30,7 @@ module Crystal::System::Addrinfo
end

def self.getaddrinfo(domain, service, family, type, protocol, timeout) : Handle
hints = LibC::Addrinfo.new
hints = LibC::ADDRINFOEXW.new
hints.ai_family = (family || ::Socket::Family::UNSPEC).to_i32
hints.ai_socktype = type
hints.ai_protocol = protocol
Expand All @@ -43,19 +43,46 @@ module Crystal::System::Addrinfo
end
end

ret = LibC.getaddrinfo(domain, service.to_s, pointerof(hints), out ptr)
unless ret.zero?
error = WinError.new(ret.to_u32!)
raise ::Socket::Addrinfo::Error.from_os_error(nil, error, domain: domain, type: type, protocol: protocol, service: service)
Crystal::IOCP::GetAddrInfoOverlappedOperation.run(Crystal::EventLoop.current.iocp) do |operation|
completion_routine = LibC::LPLOOKUPSERVICE_COMPLETION_ROUTINE.new do |dwError, dwBytes, lpOverlapped|
orig_operation = Crystal::IOCP::GetAddrInfoOverlappedOperation.unbox(lpOverlapped)
LibC.PostQueuedCompletionStatus(orig_operation.iocp, 0, 0, lpOverlapped)
end

# NOTE: we handle the timeout ourselves so we don't pass a `LibC::Timeval`
# to Win32 here
result = LibC.GetAddrInfoExW(
Crystal::System.to_wstr(domain), Crystal::System.to_wstr(service.to_s), LibC::NS_DNS, nil, pointerof(hints),
out addrinfos, nil, operation, completion_routine, out cancel_handle)

if result == 0
return addrinfos
else
case error = WinError.new(result.to_u32!)
when .wsa_io_pending?
# used in `Crystal::IOCP::OverlappedOperation#try_cancel_getaddrinfo`
operation.cancel_handle = cancel_handle
else
raise ::Socket::Addrinfo::Error.from_os_error("GetAddrInfoExW", error, domain: domain, type: type, protocol: protocol, service: service)
end
end

operation.wait_for_result(timeout) do |error|
case error
when .wsa_e_cancelled?
raise IO::TimeoutError.new("GetAddrInfoExW timed out")
else
raise ::Socket::Addrinfo::Error.from_os_error("GetAddrInfoExW", error, domain: domain, type: type, protocol: protocol, service: service)
end
end
end
ptr
end

def self.next_addrinfo(addrinfo : Handle) : Handle
addrinfo.value.ai_next
end

def self.free_addrinfo(addrinfo : Handle)
LibC.freeaddrinfo(addrinfo)
LibC.FreeAddrInfoExW(addrinfo)
end
end
61 changes: 61 additions & 0 deletions src/crystal/system/win32/addrinfo_win7.cr
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
module Crystal::System::Addrinfo
alias Handle = LibC::Addrinfo*

@addr : LibC::SockaddrIn6

protected def initialize(addrinfo : Handle)
@family = ::Socket::Family.from_value(addrinfo.value.ai_family)
@type = ::Socket::Type.from_value(addrinfo.value.ai_socktype)
@protocol = ::Socket::Protocol.from_value(addrinfo.value.ai_protocol)
@size = addrinfo.value.ai_addrlen.to_i

@addr = uninitialized LibC::SockaddrIn6

case @family
when ::Socket::Family::INET6
addrinfo.value.ai_addr.as(LibC::SockaddrIn6*).copy_to(pointerof(@addr).as(LibC::SockaddrIn6*), 1)
when ::Socket::Family::INET
addrinfo.value.ai_addr.as(LibC::SockaddrIn*).copy_to(pointerof(@addr).as(LibC::SockaddrIn*), 1)
else
# TODO: (asterite) UNSPEC and UNIX unsupported?
end
end

def system_ip_address : ::Socket::IPAddress
::Socket::IPAddress.from(to_unsafe, size)
end

def to_unsafe
pointerof(@addr).as(LibC::Sockaddr*)
end

def self.getaddrinfo(domain, service, family, type, protocol, timeout) : Handle
hints = LibC::Addrinfo.new
hints.ai_family = (family || ::Socket::Family::UNSPEC).to_i32
hints.ai_socktype = type
hints.ai_protocol = protocol
hints.ai_flags = 0

if service.is_a?(Int)
hints.ai_flags |= LibC::AI_NUMERICSERV
if service < 0
raise ::Socket::Addrinfo::Error.from_os_error(nil, WinError::WSATYPE_NOT_FOUND, domain: domain, type: type, protocol: protocol, service: service)
end
end

ret = LibC.getaddrinfo(domain, service.to_s, pointerof(hints), out ptr)
unless ret.zero?
error = WinError.new(ret.to_u32!)
raise ::Socket::Addrinfo::Error.from_os_error(nil, error, domain: domain, type: type, protocol: protocol, service: service)
end
ptr
end

def self.next_addrinfo(addrinfo : Handle) : Handle
addrinfo.value.ai_next
end

def self.free_addrinfo(addrinfo : Handle)
LibC.freeaddrinfo(addrinfo)
end
end
36 changes: 36 additions & 0 deletions src/crystal/system/win32/iocp.cr
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,42 @@ module Crystal::IOCP
end
end

class GetAddrInfoOverlappedOperation < OverlappedOperation
getter iocp
setter cancel_handle : LibC::HANDLE = LibC::INVALID_HANDLE_VALUE

def initialize(@iocp : LibC::HANDLE)
end

def wait_for_result(timeout, & : WinError ->)
wait_for_completion(timeout)

result = LibC.GetAddrInfoExOverlappedResult(self)
unless result.zero?
error = WinError.new(result.to_u32!)
yield error

raise Socket::Addrinfo::Error.from_os_error("GetAddrInfoExOverlappedResult", error)
end

@overlapped.union.pointer.as(LibC::ADDRINFOEXW**).value
end

private def try_cancel : Bool
ret = LibC.GetAddrInfoExCancel(pointerof(@cancel_handle))
unless ret.zero?
case error = WinError.new(ret.to_u32!)
when .wsa_invalid_handle?
# Operation has already completed, do nothing
return false
else
raise Socket::Addrinfo::Error.from_os_error("GetAddrInfoExCancel", error)
end
end
true
end
end

def self.overlapped_operation(file_descriptor, method, timeout, *, offset = nil, writing = false, &)
handle = file_descriptor.windows_handle
seekable = LibC.SetFilePointerEx(handle, 0, out original_offset, IO::Seek::Current) != 0
Expand Down
8 changes: 4 additions & 4 deletions src/http/client.cr
Original file line number Diff line number Diff line change
Expand Up @@ -343,10 +343,10 @@ class HTTP::Client
# ```
setter connect_timeout : Time::Span?

# **This method has no effect right now**
#
# Sets the number of seconds to wait when resolving a name, before raising an `IO::TimeoutError`.
#
# NOTE: *dns_timeout* is currently only supported on Windows.
#
# ```
# require "http/client"
#
Expand All @@ -363,10 +363,10 @@ class HTTP::Client
self.dns_timeout = dns_timeout.seconds
end

# **This method has no effect right now**
#
# Sets the number of seconds to wait when resolving a name with a `Time::Span`, before raising an `IO::TimeoutError`.
#
# NOTE: *dns_timeout* is currently only supported on Windows.
#
# ```
# require "http/client"
#
Expand Down
7 changes: 7 additions & 0 deletions src/lib_c/x86_64-windows-msvc/c/winsock2.cr
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ lib LibC
lpVendorInfo : Char*
end

NS_DNS = 12_u32

INVALID_SOCKET = ~SOCKET.new(0)
SOCKET_ERROR = -1

Expand Down Expand Up @@ -111,6 +113,11 @@ lib LibC

alias WSAOVERLAPPED_COMPLETION_ROUTINE = Proc(DWORD, DWORD, WSAOVERLAPPED*, DWORD, Void)

struct Timeval
tv_sec : Long
tv_usec : Long
end

struct Linger
l_onoff : UShort
l_linger : UShort
Expand Down
14 changes: 14 additions & 0 deletions src/lib_c/x86_64-windows-msvc/c/ws2def.cr
Original file line number Diff line number Diff line change
Expand Up @@ -208,4 +208,18 @@ lib LibC
ai_addr : Sockaddr*
ai_next : Addrinfo*
end

struct ADDRINFOEXW
ai_flags : Int
ai_family : Int
ai_socktype : Int
ai_protocol : Int
ai_addrlen : SizeT
ai_canonname : LPWSTR
ai_addr : Sockaddr*
ai_blob : Void*
ai_bloblen : SizeT
ai_provider : GUID*
ai_next : ADDRINFOEXW*
end
end
20 changes: 20 additions & 0 deletions src/lib_c/x86_64-windows-msvc/c/ws2tcpip.cr
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,24 @@ lib LibC
fun getaddrinfo(pNodeName : Char*, pServiceName : Char*, pHints : Addrinfo*, ppResult : Addrinfo**) : Int
fun inet_ntop(family : Int, pAddr : Void*, pStringBuf : Char*, stringBufSize : SizeT) : Char*
fun inet_pton(family : Int, pszAddrString : Char*, pAddrBuf : Void*) : Int

fun FreeAddrInfoExW(pAddrInfoEx : ADDRINFOEXW*)

alias LPLOOKUPSERVICE_COMPLETION_ROUTINE = DWORD, DWORD, WSAOVERLAPPED* ->

fun GetAddrInfoExW(
pName : LPWSTR,
pServiceName : LPWSTR,
dwNameSpace : DWORD,
lpNspId : GUID*,
hints : ADDRINFOEXW*,
ppResult : ADDRINFOEXW**,
timeout : Timeval*,
lpOverlapped : OVERLAPPED*,
lpCompletionRoutine : LPLOOKUPSERVICE_COMPLETION_ROUTINE,
lpHandle : HANDLE*,
) : Int

fun GetAddrInfoExOverlappedResult(lpOverlapped : OVERLAPPED*) : Int
fun GetAddrInfoExCancel(lpHandle : HANDLE*) : Int
end
Loading

0 comments on commit dea39ad

Please sign in to comment.