From 89ef696077c6e9dacfeda828c898a4fd7f9d57e7 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 22 Jan 2025 16:39:51 -0500 Subject: [PATCH] feat: introduce `is_reactant_primitive` to define additional primitives --- ext/ReactantCUDAExt.jl | 4 ++-- src/Overlay.jl | 6 +++--- src/Reactant.jl | 22 +++++++++++++++++++++- src/TracedRArray.jl | 12 +++++++----- src/TracedRNumber.jl | 9 +-------- src/TracedUtils.jl | 16 +++++++++++----- src/Tracing.jl | 39 +++++++++++++++++++++------------------ 7 files changed, 66 insertions(+), 42 deletions(-) diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index e5f6e4fc3..a8095f763 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -771,7 +771,7 @@ Base.@nospecializeinfer function Reactant.traced_type_inner( ) T = eltype(A) N = ndims(A) - if mode == Reactant.ArrayToConcrete && T <: Reactant.ReactantPrimitive + if mode == Reactant.ArrayToConcrete && Reactant.is_reactant_primitive(T) return Reactant.ConcreteRArray{T,N} else TT = Reactant.traced_type_inner(T, seen, mode, track_numbers) @@ -795,7 +795,7 @@ function Reactant.make_tracer( if haskey(seen, prev) return seen[prev] end - if mode == Reactant.ArrayToConcrete && eltype(RT) <: Reactant.ReactantPrimitive + if mode == Reactant.ArrayToConcrete && Reactant.is_reactant_primitive(eltype(RT)) return seen[prev] = Reactant.ConcreteRArray(Array(prev)) end TT = Reactant.traced_type(eltype(RT), Val(mode), track_numbers) diff --git a/src/Overlay.jl b/src/Overlay.jl index 5d9b85c83..dbd7e7e5e 100644 --- a/src/Overlay.jl +++ b/src/Overlay.jl @@ -53,7 +53,7 @@ for randfun in (:rand, :randn, :randexp) @reactant_overlay @noinline function Random.$(randfun)( rng::AbstractRNG, ::Type{T}, dims::Dims ) where {T} - if T <: ReactantPrimitive + if is_reactant_primitive(T) return TracedRandom.$(overload_randfun)(rng, T, dims) end @warn "Reactant doesn't support sampling of $(T) with the current \ @@ -70,7 +70,7 @@ for randfun in (:rand, :randn, :randexp) @reactant_overlay @noinline function Random.$(randfun)( rng::AbstractRNG, ::Type{T}, dim1::Integer, dims::Integer... ) where {T} - if T <: ReactantPrimitive + if is_reactant_primitive(T) return TracedRandom.$(overload_randfun)(rng, T, dim1, dims...) end @warn "Reactant doesn't support sampling of $(T) with the current \ @@ -82,7 +82,7 @@ for randfun in (:rand, :randn, :randexp) @reactant_overlay @noinline function Random.$(randfun)( rng::AbstractRNG, ::Type{T}=Float64 ) where {T} - if T <: ReactantPrimitive + if is_reactant_primitive(T) return TracedRandom.$(overload_randfun)(rng, T) end @warn "Reactant doesn't support sampling of $(T) with the current \ diff --git a/src/Reactant.jl b/src/Reactant.jl index 41f6ab929..13286afd6 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -58,7 +58,15 @@ const ReactantPrimitive = Union{ Base.uniontypes(ReactantComplexFloat)..., } -abstract type RNumber{T<:ReactantPrimitive} <: Number end +""" + is_reactant_primitive(::Type{T}) + +Returns `true` if `T` is a primitive type supported by Reactant. +""" +is_reactant_primitive(::Type{<:ReactantPrimitive}) = true +is_reactant_primitive(::Type) = false + +abstract type RNumber{T} <: Number end abstract type RArray{T,N} <: AbstractArray{T,N} end @@ -97,6 +105,7 @@ mutable struct TracedRNumber{T} <: RNumber{T} function TracedRNumber{T}( paths::Tuple, mlir_data::Union{Nothing,MLIR.IR.Value} ) where {T} + @assert is_reactant_primitive(T) "$(T) is not a primitive type supported by Reactant." if !isnothing(mlir_data) @assert size(MLIR.IR.type(mlir_data)) == () end @@ -114,6 +123,7 @@ mutable struct TracedRArray{T,N} <: RArray{TracedRNumber{T},N} function TracedRArray{T,N}( paths::Tuple, mlir_data::Union{Nothing,MLIR.IR.Value}, shape ) where {T,N} + @assert is_reactant_primitive(T) "$(T) is not a primitive type supported by Reactant." shape = Tuple(shape) if !isnothing(mlir_data) @assert size(MLIR.IR.type(mlir_data)) == shape "Expected: $(shape), got: $(size(MLIR.IR.type(mlir_data)))" @@ -157,6 +167,11 @@ Adapt.parent_type(::Type{XLAArray{T,N}}) where {T,N} = XLAArray{T,N} mutable struct ConcreteRNumber{T} <: RNumber{T} data::XLA.AsyncBuffer + + function ConcreteRNumber{T}(data::XLA.AsyncBuffer) where {T} + @assert is_reactant_primitive(T) "$(T) is not a primitive type supported by Reactant." + return new{T}(data) + end end @leaf ConcreteRNumber @@ -164,6 +179,11 @@ end mutable struct ConcreteRArray{T,N} <: RArray{T,N} data::XLA.AsyncBuffer shape::NTuple{N,Int} + + function ConcreteRArray{T,N}(data::XLA.AsyncBuffer, shape::NTuple{N,Int}) where {T,N} + @assert is_reactant_primitive(T) "$(T) is not a primitive type supported by Reactant." + return new{T,N}(data, shape) + end end @leaf ConcreteRArray diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 165a9c71e..b12a34a4a 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -7,7 +7,6 @@ using ..Reactant: Reactant, TracedRArray, TracedRNumber, - ReactantPrimitive, WrappedTracedRArray, AnyTracedRArray, AnyTracedRVector, @@ -16,7 +15,8 @@ using ..Reactant: ancestor, allowscalar, aos_to_soa, - unwrapped_eltype + unwrapped_eltype, + is_reactant_primitive using ..TracedUtils: TracedUtils, get_mlir_data, set_mlir_data!, materialize_traced_array using ReactantCore: ReactantCore @@ -480,14 +480,16 @@ end function Base.similar( ::Broadcasted{AbstractReactantArrayStyle{N}}, ::Type{T}, dims -) where {T<:ReactantPrimitive,N} +) where {T,N} + @assert is_reactant_primitive(T) "$(T) is not a primitive type supported by Reactant." @assert N isa Int return TracedRArray{T,length(dims)}((), nothing, map(length, dims)) end function Base.similar( ::Broadcasted{AbstractReactantArrayStyle{N}}, ::Type{TracedRNumber{T}}, dims -) where {T<:ReactantPrimitive,N} +) where {T,N} + @assert is_reactant_primitive(T) "$(T) is not a primitive type supported by Reactant." @assert N isa Int return TracedRArray{T,length(dims)}((), nothing, map(length, dims)) end @@ -507,7 +509,7 @@ end # we need to override the outer copy method to make sure we never fall back to scalar # iteration (see, e.g., CUDA.jl#145) function Broadcast.copy(bc::Broadcasted{<:AbstractReactantArrayStyle}) - fn = if bc.f isa Type && bc.f <: ReactantPrimitive + fn = if bc.f isa Type && is_reactant_primitive(bc.f) TracedUtils.TypeCast{bc.f}() else bc.f diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index aa319adcd..bbbb386f9 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -1,14 +1,7 @@ module TracedRNumberOverrides using ..Reactant: - Reactant, - TracedRNumber, - TracedRArray, - ReactantPrimitive, - TracedUtils, - Ops, - MLIR, - unwrapped_eltype + Reactant, TracedRNumber, TracedRArray, TracedUtils, Ops, MLIR, unwrapped_eltype using ReactantCore ReactantCore.is_traced(::TracedRNumber) = true diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index 8802ed083..4f270932c 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -14,8 +14,8 @@ using ..Reactant: AnyTracedRArray, MissingTracedValue, OrderedIdDict, - ReactantPrimitive, - Ops + Ops, + is_reactant_primitive using ReactantCore: MissingTracedValue materialize_traced_array(x::TracedRArray) = x @@ -283,15 +283,21 @@ function __take_region(compiled_fn) return region end -elem_apply(::Type{T}, x::TracedRArray{T}) where {T<:ReactantPrimitive} = x +elem_apply(::Type{T}, x::TracedRArray{T}) where {T} = x -struct TypeCast{T<:ReactantPrimitive} <: Function end +struct TypeCast{T} <: Function + function TypeCast{T}() where {T} + @assert is_reactant_primitive(T) "$(T) is not a primitive type supported by Reactant." + return new{T}() + end +end function (::TypeCast{T})(x::TracedRNumber{T2}) where {T,T2} return TracedUtils.promote_to(TracedRNumber{T}, x) end -function elem_apply(::Type{T}, x::TracedRArray) where {T<:ReactantPrimitive} +function elem_apply(::Type{T}, x::TracedRArray) where {T} + @assert is_reactant_primitive(T) "$(T) is not a primitive type supported by Reactant." # Special Path to prevent going down a despecialized path return elem_apply(TypeCast{T}(), x) end diff --git a/src/Tracing.jl b/src/Tracing.jl index d07c5178c..180d69b44 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -84,13 +84,10 @@ Base.@nospecializeinfer function traced_type_inner( subParms = [] for (i, SST) in enumerate(T.parameters) - if wrapped_carray && i == 1 && SST isa Type && SST <: ReactantPrimitive + if wrapped_carray && i == 1 && SST isa Type && is_reactant_primitive(SST) TrT = traced_type_inner(ConcreteRNumber{SST}, seen, mode, track_numbers) push!(subParms, TrT) - elseif wrapped_tracedarray && - i == 1 && - SST isa Type && - SST <: TracedRNumber{<:ReactantPrimitive} + elseif wrapped_tracedarray && i == 1 && SST isa Type && SST <: TracedRNumber TrT = traced_type_inner(unwrapped_eltype(SST), seen, mode, track_numbers) push!(subParms, TrT) else @@ -164,15 +161,17 @@ for T in ( end Base.@nospecializeinfer function traced_type_inner( - @nospecialize(T::Type{<:ReactantPrimitive}), + @nospecialize(T::Type{<:Number}), seen, @nospecialize(mode::TraceMode), @nospecialize(track_numbers::Type) ) - if Mode == ArrayToConcrete && T <: track_numbers - return ConcreteRNumber{T} - elseif (mode == NoStopTracedTrack || mode == TracedTrack) && T <: track_numbers - return TracedRNumber{T} + if is_reactant_primitive(T) + if Mode == ArrayToConcrete && T <: track_numbers + return ConcreteRNumber{T} + elseif (mode == NoStopTracedTrack || mode == TracedTrack) && T <: track_numbers + return TracedRNumber{T} + end end return T end @@ -395,7 +394,7 @@ Base.@nospecializeinfer function traced_type_inner( ) T = eltype(A) N = ndims(A) - if mode == ArrayToConcrete && T <: ReactantPrimitive + if mode == ArrayToConcrete && is_reactant_primitive(T) return ConcreteRArray{T,N} else return Array{traced_type_inner(T, seen, mode, track_numbers),N} @@ -904,7 +903,7 @@ function make_tracer( if mode != NoStopTracedTrack && haskey(seen, prev) return seen[prev] end - if mode == ArrayToConcrete && eltype(RT) <: ReactantPrimitive + if mode == ArrayToConcrete && is_reactant_primitive(eltype(RT)) return seen[prev] = ConcreteRArray(prev) end TT = traced_type(eltype(RT), Val(mode), track_numbers) @@ -995,17 +994,21 @@ end @nospecialize(x::ConcreteRArray), @nospecialize(track_numbers::Type) ) = x @inline function to_rarray_internal( - @nospecialize(x::Array{<:ReactantPrimitive}), @nospecialize(track_numbers::Type) -) - return ConcreteRArray(x) + @nospecialize(x::Array{T}), @nospecialize(track_numbers::Type) +) where {T} + is_reactant_primitive(T) && return ConcreteRArray(x) + return @invoke to_rarray_internal(x::Any, track_numbers::Type) end @inline to_rarray_internal( @nospecialize(x::ConcreteRNumber), @nospecialize(track_numbers::Type) ) = x @inline function to_rarray_internal( - @nospecialize(x::ReactantPrimitive), @nospecialize(track_numbers::Type) + @nospecialize(x::Number), @nospecialize(track_numbers::Type) ) - typeof(x) <: track_numbers && return ConcreteRNumber(x) - return x + if is_reactant_primitive(typeof(x)) + typeof(x) <: track_numbers && return ConcreteRNumber(x) + return x + end + return @invoke to_rarray_internal(x::Any, track_numbers::Type) end