Skip to content

Commit

Permalink
Make Crystal::IOCP::OverlappedOperation abstract (#14987)
Browse files Browse the repository at this point in the history
This allows different overlapped operations to provide their own closure data, instead of putting everything in one big class, such as in #14979 (comment).
  • Loading branch information
HertzDevil authored Sep 9, 2024
1 parent bdddae7 commit 849e0d7
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 54 deletions.
4 changes: 2 additions & 2 deletions src/crystal/system/win32/file_descriptor.cr
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ module Crystal::System::FileDescriptor
end

private def lock_file(handle, flags)
IOCP::OverlappedOperation.run(handle) do |operation|
IOCP::IOOverlappedOperation.run(handle) do |operation|
result = LibC.LockFileEx(handle, flags, 0, 0xFFFF_FFFF, 0xFFFF_FFFF, operation)

if result == 0
Expand All @@ -260,7 +260,7 @@ module Crystal::System::FileDescriptor
end

private def unlock_file(handle)
IOCP::OverlappedOperation.run(handle) do |operation|
IOCP::IOOverlappedOperation.run(handle) do |operation|
result = LibC.UnlockFileEx(handle, 0, 0xFFFF_FFFF, 0xFFFF_FFFF, operation)

if result == 0
Expand Down
121 changes: 73 additions & 48 deletions src/crystal/system/win32/iocp.cr
Original file line number Diff line number Diff line change
Expand Up @@ -78,39 +78,66 @@ module Crystal::IOCP
end
end

class OverlappedOperation
abstract class OverlappedOperation
enum State
STARTED
DONE
end

abstract def wait_for_result(timeout, & : WinError ->)
private abstract def try_cancel : Bool

@overlapped = LibC::OVERLAPPED.new
@fiber = Fiber.current
@state : State = :started

def initialize(@handle : LibC::HANDLE)
def self.run(*args, **opts, &)
operation_storage = uninitialized ReferenceStorage(self)
operation = unsafe_construct(pointerof(operation_storage), *args, **opts)
yield operation
end

def initialize(handle : LibC::SOCKET)
@handle = LibC::HANDLE.new(handle)
def self.unbox(overlapped : LibC::OVERLAPPED*) : self
start = overlapped.as(Pointer(UInt8)) - offsetof(self, @overlapped)
Box(self).unbox(start.as(Pointer(Void)))
end

def self.run(handle, &)
operation_storage = uninitialized ReferenceStorage(OverlappedOperation)
operation = OverlappedOperation.unsafe_construct(pointerof(operation_storage), handle)
yield operation
def to_unsafe
pointerof(@overlapped)
end

def self.unbox(overlapped : LibC::OVERLAPPED*)
start = overlapped.as(Pointer(UInt8)) - offsetof(OverlappedOperation, @overlapped)
Box(OverlappedOperation).unbox(start.as(Pointer(Void)))
protected def schedule(&)
done!
yield @fiber
end

def to_unsafe
pointerof(@overlapped)
private def done!
@fiber.cancel_timeout
@state = :done
end

def wait_for_result(timeout, &)
private def wait_for_completion(timeout)
if timeout
sleep timeout
else
Fiber.suspend
end

unless @state.done?
if try_cancel
# Wait for cancellation to complete. We must not free the operation
# until it's completed.
Fiber.suspend
end
end
end
end

class IOOverlappedOperation < OverlappedOperation
def initialize(@handle : LibC::HANDLE)
end

def wait_for_result(timeout, & : WinError ->)
wait_for_completion(timeout)

result = LibC.GetOverlappedResult(@handle, self, out bytes, 0)
Expand All @@ -124,11 +151,35 @@ module Crystal::IOCP
bytes
end

def wait_for_wsa_result(timeout, &)
private def try_cancel : Bool
# Microsoft documentation:
# The application must not free or reuse the OVERLAPPED structure
# associated with the canceled I/O operations until they have completed
# (this does not apply to asynchronous operations that finished
# synchronously, as nothing would be queued to the IOCP)
ret = LibC.CancelIoEx(@handle, self)
if ret.zero?
case error = WinError.value
when .error_not_found?
# Operation has already completed, do nothing
return false
else
raise RuntimeError.from_os_error("CancelIoEx", os_error: error)
end
end
true
end
end

class WSAOverlappedOperation < OverlappedOperation
def initialize(@handle : LibC::SOCKET)
end

def wait_for_result(timeout, & : WinError ->)
wait_for_completion(timeout)

flags = 0_u32
result = LibC.WSAGetOverlappedResult(LibC::SOCKET.new(@handle.address), self, out bytes, false, pointerof(flags))
result = LibC.WSAGetOverlappedResult(@handle, self, out bytes, false, pointerof(flags))
if result.zero?
error = WinError.wsa_value
yield error
Expand All @@ -139,57 +190,31 @@ module Crystal::IOCP
bytes
end

protected def schedule(&)
done!
yield @fiber
end

def done!
@fiber.cancel_timeout
@state = :done
end

def try_cancel : Bool
private def try_cancel : Bool
# Microsoft documentation:
# The application must not free or reuse the OVERLAPPED structure
# associated with the canceled I/O operations until they have completed
# (this does not apply to asynchronous operations that finished
# synchronously, as nothing would be queued to the IOCP)
ret = LibC.CancelIoEx(@handle, self)
ret = LibC.CancelIoEx(Pointer(Void).new(@handle), self)
if ret.zero?
case error = WinError.value
when .error_not_found?
# Operation has already completed, do nothing
return false
else
raise RuntimeError.from_os_error("CancelIOEx", os_error: error)
raise RuntimeError.from_os_error("CancelIoEx", os_error: error)
end
end
true
end

def wait_for_completion(timeout)
if timeout
sleep timeout
else
Fiber.suspend
end

unless @state.done?
if try_cancel
# Wait for cancellation to complete. We must not free the operation
# until it's completed.
Fiber.suspend
end
end
end
end

def self.overlapped_operation(file_descriptor, method, timeout, *, offset = nil, writing = false, &)
handle = file_descriptor.windows_handle
seekable = LibC.SetFilePointerEx(handle, 0, out original_offset, IO::Seek::Current) != 0

OverlappedOperation.run(handle) do |operation|
IOOverlappedOperation.run(handle) do |operation|
overlapped = operation.to_unsafe
if seekable
start_offset = offset || original_offset
Expand Down Expand Up @@ -243,7 +268,7 @@ module Crystal::IOCP
end

def self.wsa_overlapped_operation(target, socket, method, timeout, connreset_is_error = true, &)
OverlappedOperation.run(socket) do |operation|
WSAOverlappedOperation.run(socket) do |operation|
result, value = yield operation

if result == LibC::SOCKET_ERROR
Expand All @@ -257,7 +282,7 @@ module Crystal::IOCP
return value
end

operation.wait_for_wsa_result(timeout) do |error|
operation.wait_for_result(timeout) do |error|
case error
when .wsa_io_incomplete?, .error_operation_aborted?
raise IO::TimeoutError.new("#{method} timed out")
Expand Down
8 changes: 4 additions & 4 deletions src/crystal/system/win32/socket.cr
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ module Crystal::System::Socket

# :nodoc:
def overlapped_connect(socket, method, timeout, &)
IOCP::OverlappedOperation.run(socket) do |operation|
IOCP::WSAOverlappedOperation.run(socket) do |operation|
result = yield operation

if result == 0
Expand All @@ -145,7 +145,7 @@ module Crystal::System::Socket
return nil
end

operation.wait_for_wsa_result(timeout) do |error|
operation.wait_for_result(timeout) do |error|
case error
when .wsa_io_incomplete?, .wsaeconnrefused?
return ::Socket::ConnectError.from_os_error(method, error)
Expand Down Expand Up @@ -192,7 +192,7 @@ module Crystal::System::Socket
end

def overlapped_accept(socket, method, &)
IOCP::OverlappedOperation.run(socket) do |operation|
IOCP::WSAOverlappedOperation.run(socket) do |operation|
result = yield operation

if result == 0
Expand All @@ -206,7 +206,7 @@ module Crystal::System::Socket
return true
end

operation.wait_for_wsa_result(read_timeout) do |error|
operation.wait_for_result(read_timeout) do |error|
case error
when .wsa_io_incomplete?, .wsaenotsock?
return false
Expand Down

0 comments on commit 849e0d7

Please sign in to comment.