Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extract select from src/channel.cr #14912

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
263 changes: 1 addition & 262 deletions src/channel.cr
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
require "fiber"
require "crystal/spin_lock"
require "crystal/pointer_linked_list"
require "channel/select"

# A `Channel` enables concurrent communication between fibers.
#
Expand All @@ -26,106 +27,15 @@ class Channel(T)
@lock = Crystal::SpinLock.new
@queue : Deque(T)?

# :nodoc:
record NotReady
# :nodoc:
record UseDefault

# :nodoc:
module SelectAction(S)
abstract def execute : DeliveryState
abstract def wait(context : SelectContext(S))
abstract def wait_result_impl(context : SelectContext(S))
abstract def unwait_impl(context : SelectContext(S))
abstract def result : S
abstract def lock_object_id
abstract def lock
abstract def unlock

def create_context_and_wait(shared_state)
context = SelectContext.new(shared_state, self)
self.wait(context)
context
end

# wait_result overload allow implementors to define
# wait_result_impl with the right type and Channel.select_impl
# to allow dispatching over unions that will not happen
def wait_result(context : SelectContext)
raise "BUG: Unexpected call to #{typeof(self)}#wait_result(context : #{typeof(context)})"
end

def wait_result(context : SelectContext(S))
wait_result_impl(context)
end

# idem wait_result/wait_result_impl
def unwait(context : SelectContext)
raise "BUG: Unexpected call to #{typeof(self)}#unwait(context : #{typeof(context)})"
end

def unwait(context : SelectContext(S))
unwait_impl(context)
end

# Implementor that returns `Channel::UseDefault` in `#execute`
# must redefine `#default_result`
def default_result
raise "Unreachable"
end
end

private enum SelectState
None = 0
Active = 1
Done = 2
end

private class SelectContextSharedState
@state : Atomic(SelectState)

def initialize(value : SelectState)
@state = Atomic(SelectState).new(value)
end

def compare_and_set(cmp : SelectState, new : SelectState) : {SelectState, Bool}
@state.compare_and_set(cmp, new)
end
end

private class SelectContext(S)
@state : SelectContextSharedState
property action : SelectAction(S)
@activated = false

def initialize(@state, @action : SelectAction(S))
end

def activated? : Bool
@activated
end

def try_trigger : Bool
_, succeed = @state.compare_and_set(:active, :done)
if succeed
@activated = true
end
succeed
end
end

class ClosedError < Exception
def initialize(msg = "Channel is closed")
super(msg)
end
end

private enum DeliveryState
None
Delivered
Closed
end

private module SenderReceiverCloseAction
def close
self.state = DeliveryState::Closed
Expand Down Expand Up @@ -398,112 +308,6 @@ class Channel(T)
nil
end

# :nodoc:
def self.select(*ops : SelectAction)
self.select ops
end

# :nodoc:
def self.select(ops : Indexable(SelectAction))
i, m = select_impl(ops, false)
raise "BUG: Blocking select returned not ready status" if m.is_a?(NotReady)
return i, m
end

# :nodoc:
def self.non_blocking_select(*ops : SelectAction)
self.non_blocking_select ops
end

# :nodoc:
def self.non_blocking_select(ops : Indexable(SelectAction))
select_impl(ops, true)
end

private def self.select_impl(ops : Indexable(SelectAction), non_blocking)
# ops_locks is a duplicate of ops that can be sorted without disturbing the
# index positions of ops
if ops.responds_to?(:unstable_sort_by!)
# If the collection type implements `unstable_sort_by!` we can dup it.
# This applies to two types:
# * `Array`: `Array#to_a` does not dup and would return the same instance,
# thus we'd be sorting ops and messing up the index positions.
# * `StaticArray`: This avoids a heap allocation because we can dup a
# static array on the stack.
ops_locks = ops.dup
elsif ops.responds_to?(:to_static_array)
# If the collection type implements `to_static_array` we can create a
# copy without allocating an array. This applies to `Tuple` types, which
# the compiler generates for `select` expressions.
ops_locks = ops.to_static_array
else
ops_locks = ops.to_a
end

# Sort the operations by the channel they contain
# This is to avoid deadlocks between concurrent `select` calls
ops_locks.unstable_sort_by!(&.lock_object_id)

each_skip_duplicates(ops_locks, &.lock)

ops.each_with_index do |op, index|
state = op.execute

case state
in .delivered?
each_skip_duplicates(ops_locks, &.unlock)
return index, op.result
in .closed?
each_skip_duplicates(ops_locks, &.unlock)
return index, op.default_result
in .none?
# do nothing
end
end

if non_blocking
each_skip_duplicates(ops_locks, &.unlock)
return ops.size, NotReady.new
end

# Because `channel#close` may clean up a long list, `select_context.try_trigger` may
# be called after the select return. In order to prevent invalid address access,
# the state is allocated in the heap.
shared_state = SelectContextSharedState.new(SelectState::Active)
contexts = ops.map &.create_context_and_wait(shared_state)

each_skip_duplicates(ops_locks, &.unlock)
Fiber.suspend

contexts.each_with_index do |context, index|
op = ops[index]
op.lock
op.unwait(context)
op.unlock
end

contexts.each_with_index do |context, index|
if context.activated?
return index, ops[index].wait_result(context)
end
end

raise "BUG: Fiber was awaken from select but no action was activated"
end

private def self.each_skip_duplicates(ops_locks, &)
# Avoid deadlocks from trying to lock the same lock twice.
# `ops_lock` is sorted by `lock_object_id`, so identical onces will be in
# a row and we skip repeats while iterating.
last_lock_id = nil
ops_locks.each do |op|
if op.lock_object_id != last_lock_id
last_lock_id = op.lock_object_id
yield op
end
end
end

# :nodoc:
def send_select_action(value : T)
SendAction.new(self, value)
Expand Down Expand Up @@ -699,69 +503,4 @@ class Channel(T)
raise ClosedError.new
end
end

# :nodoc:
class TimeoutAction
include SelectAction(Nil)

# Total amount of time to wait
@timeout : Time::Span
@select_context : SelectContext(Nil)?

def initialize(@timeout : Time::Span)
end

def execute : DeliveryState
DeliveryState::None
end

def result : Nil
nil
end

def wait(context : SelectContext(Nil)) : Nil
@select_context = context
Fiber.timeout(@timeout, self)
end

def wait_result_impl(context : SelectContext(Nil))
nil
end

def unwait_impl(context : SelectContext(Nil))
Fiber.cancel_timeout
end

def lock_object_id : UInt64
self.object_id
end

def lock
end

def unlock
end

def time_expired(fiber : Fiber) : Nil
if @select_context.try &.try_trigger
fiber.enqueue
end
end
end
end

# Timeout keyword for use in `select`.
#
# ```
# select
# when x = ch.receive
# puts "got #{x}"
# when timeout(1.seconds)
# puts "timeout"
# end
# ```
#
# NOTE: It won't trigger if the `select` has an `else` case (i.e.: a non-blocking select).
def timeout_select_action(timeout : Time::Span) : Channel::TimeoutAction
Channel::TimeoutAction.new(timeout)
end
Loading
Loading