Skip to content

Commit

Permalink
Extract select from src/channel.cr (#14912)
Browse files Browse the repository at this point in the history
  • Loading branch information
straight-shoota authored Aug 18, 2024
1 parent 7ee895f commit 75ced20
Show file tree
Hide file tree
Showing 4 changed files with 270 additions and 262 deletions.
263 changes: 1 addition & 262 deletions src/channel.cr
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
require "fiber"
require "crystal/spin_lock"
require "crystal/pointer_linked_list"
require "channel/select"

# A `Channel` enables concurrent communication between fibers.
#
Expand All @@ -26,106 +27,15 @@ class Channel(T)
@lock = Crystal::SpinLock.new
@queue : Deque(T)?

# :nodoc:
record NotReady
# :nodoc:
record UseDefault

# :nodoc:
module SelectAction(S)
abstract def execute : DeliveryState
abstract def wait(context : SelectContext(S))
abstract def wait_result_impl(context : SelectContext(S))
abstract def unwait_impl(context : SelectContext(S))
abstract def result : S
abstract def lock_object_id
abstract def lock
abstract def unlock

def create_context_and_wait(shared_state)
context = SelectContext.new(shared_state, self)
self.wait(context)
context
end

# wait_result overload allow implementors to define
# wait_result_impl with the right type and Channel.select_impl
# to allow dispatching over unions that will not happen
def wait_result(context : SelectContext)
raise "BUG: Unexpected call to #{typeof(self)}#wait_result(context : #{typeof(context)})"
end

def wait_result(context : SelectContext(S))
wait_result_impl(context)
end

# idem wait_result/wait_result_impl
def unwait(context : SelectContext)
raise "BUG: Unexpected call to #{typeof(self)}#unwait(context : #{typeof(context)})"
end

def unwait(context : SelectContext(S))
unwait_impl(context)
end

# Implementor that returns `Channel::UseDefault` in `#execute`
# must redefine `#default_result`
def default_result
raise "Unreachable"
end
end

private enum SelectState
None = 0
Active = 1
Done = 2
end

private class SelectContextSharedState
@state : Atomic(SelectState)

def initialize(value : SelectState)
@state = Atomic(SelectState).new(value)
end

def compare_and_set(cmp : SelectState, new : SelectState) : {SelectState, Bool}
@state.compare_and_set(cmp, new)
end
end

private class SelectContext(S)
@state : SelectContextSharedState
property action : SelectAction(S)
@activated = false

def initialize(@state, @action : SelectAction(S))
end

def activated? : Bool
@activated
end

def try_trigger : Bool
_, succeed = @state.compare_and_set(:active, :done)
if succeed
@activated = true
end
succeed
end
end

class ClosedError < Exception
def initialize(msg = "Channel is closed")
super(msg)
end
end

private enum DeliveryState
None
Delivered
Closed
end

private module SenderReceiverCloseAction
def close
self.state = DeliveryState::Closed
Expand Down Expand Up @@ -398,112 +308,6 @@ class Channel(T)
nil
end

# :nodoc:
def self.select(*ops : SelectAction)
self.select ops
end

# :nodoc:
def self.select(ops : Indexable(SelectAction))
i, m = select_impl(ops, false)
raise "BUG: Blocking select returned not ready status" if m.is_a?(NotReady)
return i, m
end

# :nodoc:
def self.non_blocking_select(*ops : SelectAction)
self.non_blocking_select ops
end

# :nodoc:
def self.non_blocking_select(ops : Indexable(SelectAction))
select_impl(ops, true)
end

private def self.select_impl(ops : Indexable(SelectAction), non_blocking)
# ops_locks is a duplicate of ops that can be sorted without disturbing the
# index positions of ops
if ops.responds_to?(:unstable_sort_by!)
# If the collection type implements `unstable_sort_by!` we can dup it.
# This applies to two types:
# * `Array`: `Array#to_a` does not dup and would return the same instance,
# thus we'd be sorting ops and messing up the index positions.
# * `StaticArray`: This avoids a heap allocation because we can dup a
# static array on the stack.
ops_locks = ops.dup
elsif ops.responds_to?(:to_static_array)
# If the collection type implements `to_static_array` we can create a
# copy without allocating an array. This applies to `Tuple` types, which
# the compiler generates for `select` expressions.
ops_locks = ops.to_static_array
else
ops_locks = ops.to_a
end

# Sort the operations by the channel they contain
# This is to avoid deadlocks between concurrent `select` calls
ops_locks.unstable_sort_by!(&.lock_object_id)

each_skip_duplicates(ops_locks, &.lock)

ops.each_with_index do |op, index|
state = op.execute

case state
in .delivered?
each_skip_duplicates(ops_locks, &.unlock)
return index, op.result
in .closed?
each_skip_duplicates(ops_locks, &.unlock)
return index, op.default_result
in .none?
# do nothing
end
end

if non_blocking
each_skip_duplicates(ops_locks, &.unlock)
return ops.size, NotReady.new
end

# Because `channel#close` may clean up a long list, `select_context.try_trigger` may
# be called after the select return. In order to prevent invalid address access,
# the state is allocated in the heap.
shared_state = SelectContextSharedState.new(SelectState::Active)
contexts = ops.map &.create_context_and_wait(shared_state)

each_skip_duplicates(ops_locks, &.unlock)
Fiber.suspend

contexts.each_with_index do |context, index|
op = ops[index]
op.lock
op.unwait(context)
op.unlock
end

contexts.each_with_index do |context, index|
if context.activated?
return index, ops[index].wait_result(context)
end
end

raise "BUG: Fiber was awaken from select but no action was activated"
end

private def self.each_skip_duplicates(ops_locks, &)
# Avoid deadlocks from trying to lock the same lock twice.
# `ops_lock` is sorted by `lock_object_id`, so identical onces will be in
# a row and we skip repeats while iterating.
last_lock_id = nil
ops_locks.each do |op|
if op.lock_object_id != last_lock_id
last_lock_id = op.lock_object_id
yield op
end
end
end

# :nodoc:
def send_select_action(value : T)
SendAction.new(self, value)
Expand Down Expand Up @@ -699,69 +503,4 @@ class Channel(T)
raise ClosedError.new
end
end

# :nodoc:
class TimeoutAction
include SelectAction(Nil)

# Total amount of time to wait
@timeout : Time::Span
@select_context : SelectContext(Nil)?

def initialize(@timeout : Time::Span)
end

def execute : DeliveryState
DeliveryState::None
end

def result : Nil
nil
end

def wait(context : SelectContext(Nil)) : Nil
@select_context = context
Fiber.timeout(@timeout, self)
end

def wait_result_impl(context : SelectContext(Nil))
nil
end

def unwait_impl(context : SelectContext(Nil))
Fiber.cancel_timeout
end

def lock_object_id : UInt64
self.object_id
end

def lock
end

def unlock
end

def time_expired(fiber : Fiber) : Nil
if @select_context.try &.try_trigger
fiber.enqueue
end
end
end
end

# Timeout keyword for use in `select`.
#
# ```
# select
# when x = ch.receive
# puts "got #{x}"
# when timeout(1.seconds)
# puts "timeout"
# end
# ```
#
# NOTE: It won't trigger if the `select` has an `else` case (i.e.: a non-blocking select).
def timeout_select_action(timeout : Time::Span) : Channel::TimeoutAction
Channel::TimeoutAction.new(timeout)
end
Loading

0 comments on commit 75ced20

Please sign in to comment.