Skip to content

Commit

Permalink
Reserve certain prefixes
Browse files Browse the repository at this point in the history
  • Loading branch information
cretz committed Feb 11, 2025
1 parent 7162d3e commit 7b6e8c7
Show file tree
Hide file tree
Showing 11 changed files with 283 additions and 35 deletions.
3 changes: 3 additions & 0 deletions temporalio/lib/temporalio/activity/definition.rb
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# frozen_string_literal: true

require 'temporalio/internal/proto_utils'

module Temporalio
module Activity
# Base class for all activities.
Expand Down Expand Up @@ -182,6 +184,7 @@ def initialize(
@executor = executor
@cancel_raise = cancel_raise
@raw_args = raw_args
Internal::ProtoUtils.assert_non_reserved_name(name)
end
end
end
Expand Down
16 changes: 16 additions & 0 deletions temporalio/lib/temporalio/internal/proto_utils.rb
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,22 @@ def self.convert_to_payload_array(converter, values)
converter.to_payloads(values).payloads.to_ary
end

def self.assert_non_reserved_name(name)
name = name&.to_s # In case it's a symbol or not present
return unless name
raise "'#{name}' cannot start with '__temporal_'" if name.start_with?('__temporal_')
# Might as well disable __stack_trace and __enhanced_stack_trace everywhere even though technically it's only
# reserved for queries
raise "'#{name}' name invalid" if name == '__stack_trace' || name == '__enhanced_stack_trace'
end

def self.reserved_name?(name)
name = name&.to_s # In case it's a symbol or not present
return false unless name

name.start_with?('__temporal_') || name == '__stack_trace' || name == '__enhanced_stack_trace'
end

class LazyMemo
def initialize(raw_memo, converter)
@raw_memo = raw_memo
Expand Down
6 changes: 4 additions & 2 deletions temporalio/lib/temporalio/internal/worker/activity_worker.rb
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,10 @@ def handle_task(task)
def handle_start_task(task_token, start)
set_running_activity(task_token, nil)

# Find activity definition, falling back to dynamic if present
defn = @activities[start.activity_type] || @activities[nil]
# Find activity definition, falling back to dynamic if not found and not reserved name
defn = @activities[start.activity_type]
defn = @activities[nil] if !defn && !Internal::ProtoUtils.reserved_name?(start.activity_type)

if defn.nil?
raise Error::ApplicationError.new(
"Activity #{start.activity_type} for workflow #{start.workflow_execution.workflow_id} " \
Expand Down
69 changes: 38 additions & 31 deletions temporalio/lib/temporalio/internal/worker/workflow_instance.rb
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,10 @@ def apply(job)
end

def apply_signal(job)
defn = signal_handlers[job.signal_name] || signal_handlers[nil]
# Get signal definition, falling back to dynamic if not present and not reserved
defn = signal_handlers[job.signal_name]
defn = signal_handlers[nil] if !defn && !Internal::ProtoUtils.reserved_name?(job.signal_name)

handler_exec =
if defn
HandlerExecution.new(name: job.signal_name, update_id: nil, unfinished_policy: defn.unfinished_policy)
Expand Down Expand Up @@ -381,37 +384,38 @@ def apply_signal(job)
end

def apply_query(job)
# TODO(cretz): __temporal_workflow_metadata
defn = case job.query_type
when '__stack_trace'
Workflow::Definition::Query.new(
name: '__stack_trace',
to_invoke: proc { scheduler.stack_trace }
)
else
query_handlers[job.query_type] || query_handlers[nil]
end
schedule do
unless defn
raise "Query handler for #{job.query_type} expected but not found, " \
"known queries: [#{query_handlers.keys.compact.sort.join(', ')}]"
end
# If it's a built-in, run it without interceptors, otherwise do normal behavior
# TODO(cretz): __temporal_workflow_metadata
result = if job.query_type == '__stack_trace'
scheduler.stack_trace
else
# Get query definition, falling back to dynamic if not present and not reserved
defn = query_handlers[job.query_type]
defn = query_handlers[nil] if !defn && !Internal::ProtoUtils.reserved_name?(job.query_type)

unless defn
raise "Query handler for #{job.query_type} expected but not found, " \
"known queries: [#{query_handlers.keys.compact.sort.join(', ')}]"
end

with_context_frozen do
@inbound.handle_query(
Temporalio::Worker::Interceptor::Workflow::HandleQueryInput.new(
id: job.query_id,
query: job.query_type,
args: begin
convert_handler_args(payload_array: job.arguments, defn:)
rescue StandardError => e
raise "Failed converting query input arguments: #{e}"
end,
definition: defn,
headers: ProtoUtils.headers_from_proto_map(job.headers, @payload_converter) || {}
)
)
end
end

result = with_context_frozen do
@inbound.handle_query(
Temporalio::Worker::Interceptor::Workflow::HandleQueryInput.new(
id: job.query_id,
query: job.query_type,
args: begin
convert_handler_args(payload_array: job.arguments, defn:)
rescue StandardError => e
raise "Failed converting query input arguments: #{e}"
end,
definition: defn,
headers: ProtoUtils.headers_from_proto_map(job.headers, @payload_converter) || {}
)
)
end
add_command(
Bridge::Api::WorkflowCommands::WorkflowCommand.new(
respond_to_query: Bridge::Api::WorkflowCommands::QueryResult.new(
Expand All @@ -435,7 +439,10 @@ def apply_query(job)
end

def apply_update(job)
defn = update_handlers[job.name] || update_handlers[nil]
# Get update definition, falling back to dynamic if not present and not reserved
defn = update_handlers[job.name]
defn = update_handlers[nil] if !defn && !Internal::ProtoUtils.reserved_name?(job.name)

handler_exec =
(HandlerExecution.new(name: job.name, update_id: job.id, unfinished_policy: defn.unfinished_policy) if defn)
schedule(handler_exec:) do
Expand Down
3 changes: 3 additions & 0 deletions temporalio/lib/temporalio/worker.rb
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
require 'temporalio/error'
require 'temporalio/internal/bridge'
require 'temporalio/internal/bridge/worker'
require 'temporalio/internal/proto_utils'
require 'temporalio/internal/worker/activity_worker'
require 'temporalio/internal/worker/multi_runner'
require 'temporalio/internal/worker/workflow_instance'
Expand Down Expand Up @@ -381,6 +382,8 @@ def initialize(
)
raise ArgumentError, 'Must have at least one activity or workflow' if activities.empty? && workflows.empty?

Internal::ProtoUtils.assert_non_reserved_name(task_queue)

@options = Options.new(
client:,
task_queue:,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,12 @@ def create_instance(initial_activation, worker_state)
raise 'Missing initialize job in initial activation' unless init_job

# Obtain definition
definition = worker_state.workflow_definitions[init_job.workflow_type] ||
worker_state.workflow_definitions[nil]
definition = worker_state.workflow_definitions[init_job.workflow_type]
# If not present and not reserved, try dynamic
if !definition && !Internal::ProtoUtils.reserved_name?(init_job.workflow_type)
definition = worker_state.workflow_definitions[nil]
end

unless definition
raise Error::ApplicationError.new(
"Workflow type #{init_job.workflow_type} is not registered on this worker, available workflows: " +
Expand Down
5 changes: 5 additions & 0 deletions temporalio/lib/temporalio/workflow/definition.rb
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# frozen_string_literal: true

require 'temporalio/internal/proto_utils'
require 'temporalio/workflow'
require 'temporalio/workflow/handler_unfinished_policy'

Expand Down Expand Up @@ -430,6 +431,7 @@ def initialize(
@signals = signals.dup.freeze
@queries = queries.dup.freeze
@updates = updates.dup.freeze
Internal::ProtoUtils.assert_non_reserved_name(name)
end

# @return [String] Workflow name.
Expand Down Expand Up @@ -473,6 +475,7 @@ def initialize(
@to_invoke = to_invoke
@raw_args = raw_args
@unfinished_policy = unfinished_policy
Internal::ProtoUtils.assert_non_reserved_name(name)
end
end

Expand Down Expand Up @@ -507,6 +510,7 @@ def initialize(
@name = name
@to_invoke = to_invoke
@raw_args = raw_args
Internal::ProtoUtils.assert_non_reserved_name(name)
end
end

Expand Down Expand Up @@ -548,6 +552,7 @@ def initialize(
@raw_args = raw_args
@unfinished_policy = unfinished_policy
@validator_to_invoke = validator_to_invoke
Internal::ProtoUtils.assert_non_reserved_name(name)
end

# @!visibility private
Expand Down
3 changes: 3 additions & 0 deletions temporalio/sig/temporalio/internal/proto_utils.rbs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ module Temporalio
Array[Object?] values
) -> Array[untyped]

def self.assert_non_reserved_name: (String | Symbol | nil name) -> void
def self.reserved_name?: (String | Symbol | nil name) -> bool

class LazyMemo
def initialize: (
untyped? raw_memo,
Expand Down
11 changes: 11 additions & 0 deletions temporalio/test/worker_activity_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -914,6 +914,17 @@ def test_client_access
assert_equal 'ClientAccessActivity', execute_activity(ClientAccessActivity)
end

class ReservedNameActivity < Temporalio::Activity::Definition
activity_name '__temporal_activity'

def execute; end
end

def test_reserved_name
err = assert_raises { Temporalio::Activity::Definition::Info.from_activity(ReservedNameActivity) }
assert_includes err.message, "'__temporal_activity' cannot start with '__temporal_'"
end

# steep:ignore
def execute_activity(
activity,
Expand Down
Loading

0 comments on commit 7b6e8c7

Please sign in to comment.