From 18c28f9b2f643604d839d7c03508b6420b62f65e Mon Sep 17 00:00:00 2001 From: Quinton Miller Date: Wed, 4 Sep 2024 01:59:18 +0800 Subject: [PATCH] Support non-blocking `File#read_at` on Windows (#14958) --- spec/std/file_spec.cr | 20 +++++++------ src/crystal/system/unix/file_descriptor.cr | 6 ++-- src/crystal/system/win32/file_descriptor.cr | 33 +++++++++++---------- src/crystal/system/win32/iocp.cr | 14 +++++---- src/file/preader.cr | 2 +- 5 files changed, 42 insertions(+), 33 deletions(-) diff --git a/spec/std/file_spec.cr b/spec/std/file_spec.cr index 55a7b5d76494..eb740885cd69 100644 --- a/spec/std/file_spec.cr +++ b/spec/std/file_spec.cr @@ -1295,17 +1295,19 @@ describe "File" do it "reads at offset" do filename = datapath("test_file.txt") - File.open(filename) do |file| - file.read_at(6, 100) do |io| - io.gets_to_end.should eq("World\nHello World\nHello World\nHello World\nHello World\nHello World\nHello World\nHello World\nHello Worl") - end + {true, false}.each do |blocking| + File.open(filename, blocking: blocking) do |file| + file.read_at(6, 100) do |io| + io.gets_to_end.should eq("World\nHello World\nHello World\nHello World\nHello World\nHello World\nHello World\nHello World\nHello Worl") + end - file.read_at(0, 240) do |io| - io.gets_to_end.should eq(File.read(filename)) - end + file.read_at(0, 240) do |io| + io.gets_to_end.should eq(File.read(filename)) + end - file.read_at(6_i64, 5_i64) do |io| - io.gets_to_end.should eq("World") + file.read_at(6_i64, 5_i64) do |io| + io.gets_to_end.should eq("World") + end end end end diff --git a/src/crystal/system/unix/file_descriptor.cr b/src/crystal/system/unix/file_descriptor.cr index d235114849b4..fc8839ac9e83 100644 --- a/src/crystal/system/unix/file_descriptor.cr +++ b/src/crystal/system/unix/file_descriptor.cr @@ -219,11 +219,11 @@ module Crystal::System::FileDescriptor {r, w} end - def self.pread(fd, buffer, offset) - bytes_read = LibC.pread(fd, buffer, buffer.size, offset).to_i64 + def self.pread(file, buffer, offset) + bytes_read = LibC.pread(file.fd, buffer, buffer.size, offset).to_i64 if bytes_read == -1 - raise IO::Error.from_errno "Error reading file" + raise IO::Error.from_errno("Error reading file", target: file) end bytes_read diff --git a/src/crystal/system/win32/file_descriptor.cr b/src/crystal/system/win32/file_descriptor.cr index 37813307191f..f4e9200a0488 100644 --- a/src/crystal/system/win32/file_descriptor.cr +++ b/src/crystal/system/win32/file_descriptor.cr @@ -120,10 +120,6 @@ module Crystal::System::FileDescriptor end protected def windows_handle - FileDescriptor.windows_handle(fd) - end - - def self.windows_handle(fd) LibC::HANDLE.new(fd) end @@ -278,19 +274,26 @@ module Crystal::System::FileDescriptor {r, w} end - def self.pread(fd, buffer, offset) - handle = windows_handle(fd) + def self.pread(file, buffer, offset) + handle = file.windows_handle - overlapped = LibC::OVERLAPPED.new - overlapped.union.offset.offset = LibC::DWORD.new!(offset) - overlapped.union.offset.offsetHigh = LibC::DWORD.new!(offset >> 32) - if LibC.ReadFile(handle, buffer, buffer.size, out bytes_read, pointerof(overlapped)) == 0 - error = WinError.value - return 0_i64 if error == WinError::ERROR_HANDLE_EOF - raise IO::Error.from_os_error "Error reading file", error, target: self - end + if file.system_blocking? + overlapped = LibC::OVERLAPPED.new + overlapped.union.offset.offset = LibC::DWORD.new!(offset) + overlapped.union.offset.offsetHigh = LibC::DWORD.new!(offset >> 32) + if LibC.ReadFile(handle, buffer, buffer.size, out bytes_read, pointerof(overlapped)) == 0 + error = WinError.value + return 0_i64 if error == WinError::ERROR_HANDLE_EOF + raise IO::Error.from_os_error "Error reading file", error, target: file + end - bytes_read.to_i64 + bytes_read.to_i64 + else + IOCP.overlapped_operation(file, "ReadFile", file.read_timeout, offset: offset) do |overlapped| + ret = LibC.ReadFile(handle, buffer, buffer.size, out byte_count, overlapped) + {ret, byte_count} + end.to_i64 + end end def self.from_stdio(fd) diff --git a/src/crystal/system/win32/iocp.cr b/src/crystal/system/win32/iocp.cr index af8f778290f3..6f5746954277 100644 --- a/src/crystal/system/win32/iocp.cr +++ b/src/crystal/system/win32/iocp.cr @@ -168,15 +168,16 @@ module Crystal::IOCP end end - def self.overlapped_operation(file_descriptor, method, timeout, *, writing = false, &) + def self.overlapped_operation(file_descriptor, method, timeout, *, offset = nil, writing = false, &) handle = file_descriptor.windows_handle seekable = LibC.SetFilePointerEx(handle, 0, out original_offset, IO::Seek::Current) != 0 OverlappedOperation.run(handle) do |operation| overlapped = operation.to_unsafe if seekable - overlapped.value.union.offset.offset = LibC::DWORD.new!(original_offset) - overlapped.value.union.offset.offsetHigh = LibC::DWORD.new!(original_offset >> 32) + start_offset = offset || original_offset + overlapped.value.union.offset.offset = LibC::DWORD.new!(start_offset) + overlapped.value.union.offset.offsetHigh = LibC::DWORD.new!(start_offset >> 32) end result, value = yield operation @@ -215,8 +216,11 @@ module Crystal::IOCP # operation completed asynchronously; seek to the original file position # plus the number of bytes read or written (other operations might have - # moved the file pointer so we don't use `IO::Seek::Current` here) - LibC.SetFilePointerEx(handle, original_offset + byte_count, nil, IO::Seek::Set) if seekable + # moved the file pointer so we don't use `IO::Seek::Current` here), unless + # we are calling `Crystal::System::FileDescriptor.pread` + if seekable && !offset + LibC.SetFilePointerEx(handle, original_offset + byte_count, nil, IO::Seek::Set) + end byte_count end end diff --git a/src/file/preader.cr b/src/file/preader.cr index d366457314ce..9f7d09643305 100644 --- a/src/file/preader.cr +++ b/src/file/preader.cr @@ -20,7 +20,7 @@ class File::PReader < IO count = slice.size count = Math.min(count, @bytesize - @pos) - bytes_read = Crystal::System::FileDescriptor.pread(@file.fd, slice[0, count], @offset + @pos) + bytes_read = Crystal::System::FileDescriptor.pread(@file, slice[0, count], @offset + @pos) @pos += bytes_read