diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index a8095f763..e5f6e4fc3 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -771,7 +771,7 @@ Base.@nospecializeinfer function Reactant.traced_type_inner( ) T = eltype(A) N = ndims(A) - if mode == Reactant.ArrayToConcrete && Reactant.is_reactant_primitive(T) + if mode == Reactant.ArrayToConcrete && T <: Reactant.ReactantPrimitive return Reactant.ConcreteRArray{T,N} else TT = Reactant.traced_type_inner(T, seen, mode, track_numbers) @@ -795,7 +795,7 @@ function Reactant.make_tracer( if haskey(seen, prev) return seen[prev] end - if mode == Reactant.ArrayToConcrete && Reactant.is_reactant_primitive(eltype(RT)) + if mode == Reactant.ArrayToConcrete && eltype(RT) <: Reactant.ReactantPrimitive return seen[prev] = Reactant.ConcreteRArray(Array(prev)) end TT = Reactant.traced_type(eltype(RT), Val(mode), track_numbers) diff --git a/src/Overlay.jl b/src/Overlay.jl index dbd7e7e5e..5d9b85c83 100644 --- a/src/Overlay.jl +++ b/src/Overlay.jl @@ -53,7 +53,7 @@ for randfun in (:rand, :randn, :randexp) @reactant_overlay @noinline function Random.$(randfun)( rng::AbstractRNG, ::Type{T}, dims::Dims ) where {T} - if is_reactant_primitive(T) + if T <: ReactantPrimitive return TracedRandom.$(overload_randfun)(rng, T, dims) end @warn "Reactant doesn't support sampling of $(T) with the current \ @@ -70,7 +70,7 @@ for randfun in (:rand, :randn, :randexp) @reactant_overlay @noinline function Random.$(randfun)( rng::AbstractRNG, ::Type{T}, dim1::Integer, dims::Integer... ) where {T} - if is_reactant_primitive(T) + if T <: ReactantPrimitive return TracedRandom.$(overload_randfun)(rng, T, dim1, dims...) end @warn "Reactant doesn't support sampling of $(T) with the current \ @@ -82,7 +82,7 @@ for randfun in (:rand, :randn, :randexp) @reactant_overlay @noinline function Random.$(randfun)( rng::AbstractRNG, ::Type{T}=Float64 ) where {T} - if is_reactant_primitive(T) + if T <: ReactantPrimitive return TracedRandom.$(overload_randfun)(rng, T) end @warn "Reactant doesn't support sampling of $(T) with the current \ diff --git a/src/PrimitiveTypes.jl b/src/PrimitiveTypes.jl new file mode 100644 index 000000000..ea3b7e2c0 --- /dev/null +++ b/src/PrimitiveTypes.jl @@ -0,0 +1,74 @@ +# These only exist for the purpose of lowering. Since `ReactantPrimitive` is a fixed set of +# types, users can use these to convert their types to the primitive types supported by +# Reactant. +struct F8E5M2{T} <: AbstractFloat + val::T +end + +struct F8E4M3FN{T} <: AbstractFloat + val::T +end + +struct F8E4M3B11FNUZ{T} <: AbstractFloat + val::T +end + +struct F8E5M2FNUZ{T} <: AbstractFloat + val::T +end + +struct F8E4M3FNUZ{T} <: AbstractFloat + val::T +end + +# TODO: Quantized types + +const ReactantFloat8 = Union{F8E5M2,F8E4M3FN,F8E4M3B11FNUZ,F8E5M2FNUZ,F8E4M3FNUZ} + +@static if isdefined(Core, :BFloat16) + const ReactantFloat = Union{ + Float16,Core.BFloat16,Float32,Float64,Base.uniontypes(ReactantFloat8)... + } +else + const ReactantFloat = Union{Float16,Float32,Float64,Base.uniontypes(ReactantFloat8)...} +end + +const ReactantComplexFloat = Union{[Complex{T} for T in Base.uniontypes(ReactantFloat)]...} + +const ReactantInt = Union{Int8,UInt8,Int16,UInt16,Int32,UInt32,Int64,UInt64,Int128,UInt128} + +const ReactantComplexInt = Union{ + Complex{Int8}, + Complex{UInt8}, + Complex{Int16}, + Complex{UInt16}, + Complex{Int32}, + Complex{UInt32}, + Complex{Int64}, + Complex{UInt64}, + Complex{Int128}, + Complex{UInt128}, +} + +const ReactantFloatInt = Union{ + Base.uniontypes(ReactantInt)...,Base.uniontypes(ReactantFloat)... +} + +const ReactantPrimitive = Union{ + Bool, + Base.uniontypes(ReactantFloatInt)..., + Base.uniontypes(ReactantComplexInt)..., + Base.uniontypes(ReactantComplexFloat)..., +} + +""" + to_reactant_primitive(val) + +Constructs a Reactant primitive from the given value. Returns the Reactant primitive and a +function that can be used to convert the value back to the original type. +""" +to_reactant_primitive(::T) where {T} = nothing, nothing + +for T in Base.uniontypes(ReactantPrimitive) + @eval to_reactant_primitive(val::$T) = val, identity +end diff --git a/src/Reactant.jl b/src/Reactant.jl index 13286afd6..dc737c8dc 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -18,55 +18,9 @@ using Enzyme struct ReactantABI <: Enzyme.EnzymeCore.ABI end -@static if isdefined(Core, :BFloat16) - const ReactantFloat = Union{Float16,Core.BFloat16,Float32,Float64} -else - const ReactantFloat = Union{Float16,Float32,Float64} -end - -@static if isdefined(Core, :BFloat16) - const ReactantComplexFloat = Union{ - Complex{Float16},Complex{Core.BFloat16},Complex{Float32},Complex{Float64} - } -else - const ReactantComplexFloat = Union{Complex{Float16},Complex{Float32},Complex{Float64}} -end - -const ReactantInt = Union{Int8,UInt8,Int16,UInt16,Int32,UInt32,Int64,UInt64,Int128,UInt128} - -const ReactantComplexInt = Union{ - Complex{Int8}, - Complex{UInt8}, - Complex{Int16}, - Complex{UInt16}, - Complex{Int32}, - Complex{UInt32}, - Complex{Int64}, - Complex{UInt64}, - Complex{Int128}, - Complex{UInt128}, -} - -const ReactantFloatInt = Union{ - Base.uniontypes(ReactantInt)...,Base.uniontypes(ReactantFloat)... -} +include("PrimitiveTypes.jl") -const ReactantPrimitive = Union{ - Bool, - Base.uniontypes(ReactantFloatInt)..., - Base.uniontypes(ReactantComplexInt)..., - Base.uniontypes(ReactantComplexFloat)..., -} - -""" - is_reactant_primitive(::Type{T}) - -Returns `true` if `T` is a primitive type supported by Reactant. -""" -is_reactant_primitive(::Type{<:ReactantPrimitive}) = true -is_reactant_primitive(::Type) = false - -abstract type RNumber{T} <: Number end +abstract type RNumber{T<:ReactantPrimitive} <: Number end abstract type RArray{T,N} <: AbstractArray{T,N} end @@ -105,7 +59,6 @@ mutable struct TracedRNumber{T} <: RNumber{T} function TracedRNumber{T}( paths::Tuple, mlir_data::Union{Nothing,MLIR.IR.Value} ) where {T} - @assert is_reactant_primitive(T) "$(T) is not a primitive type supported by Reactant." if !isnothing(mlir_data) @assert size(MLIR.IR.type(mlir_data)) == () end @@ -123,7 +76,6 @@ mutable struct TracedRArray{T,N} <: RArray{TracedRNumber{T},N} function TracedRArray{T,N}( paths::Tuple, mlir_data::Union{Nothing,MLIR.IR.Value}, shape ) where {T,N} - @assert is_reactant_primitive(T) "$(T) is not a primitive type supported by Reactant." shape = Tuple(shape) if !isnothing(mlir_data) @assert size(MLIR.IR.type(mlir_data)) == shape "Expected: $(shape), got: $(size(MLIR.IR.type(mlir_data)))" @@ -167,11 +119,6 @@ Adapt.parent_type(::Type{XLAArray{T,N}}) where {T,N} = XLAArray{T,N} mutable struct ConcreteRNumber{T} <: RNumber{T} data::XLA.AsyncBuffer - - function ConcreteRNumber{T}(data::XLA.AsyncBuffer) where {T} - @assert is_reactant_primitive(T) "$(T) is not a primitive type supported by Reactant." - return new{T}(data) - end end @leaf ConcreteRNumber @@ -179,11 +126,6 @@ end mutable struct ConcreteRArray{T,N} <: RArray{T,N} data::XLA.AsyncBuffer shape::NTuple{N,Int} - - function ConcreteRArray{T,N}(data::XLA.AsyncBuffer, shape::NTuple{N,Int}) where {T,N} - @assert is_reactant_primitive(T) "$(T) is not a primitive type supported by Reactant." - return new{T,N}(data, shape) - end end @leaf ConcreteRArray diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index b12a34a4a..f71ea995c 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -15,8 +15,7 @@ using ..Reactant: ancestor, allowscalar, aos_to_soa, - unwrapped_eltype, - is_reactant_primitive + unwrapped_eltype using ..TracedUtils: TracedUtils, get_mlir_data, set_mlir_data!, materialize_traced_array using ReactantCore: ReactantCore @@ -480,16 +479,14 @@ end function Base.similar( ::Broadcasted{AbstractReactantArrayStyle{N}}, ::Type{T}, dims -) where {T,N} - @assert is_reactant_primitive(T) "$(T) is not a primitive type supported by Reactant." +) where {T<:Reactant.ReactantPrimitive,N} @assert N isa Int return TracedRArray{T,length(dims)}((), nothing, map(length, dims)) end function Base.similar( ::Broadcasted{AbstractReactantArrayStyle{N}}, ::Type{TracedRNumber{T}}, dims -) where {T,N} - @assert is_reactant_primitive(T) "$(T) is not a primitive type supported by Reactant." +) where {T<:Reactant.ReactantPrimitive,N} @assert N isa Int return TracedRArray{T,length(dims)}((), nothing, map(length, dims)) end @@ -509,7 +506,7 @@ end # we need to override the outer copy method to make sure we never fall back to scalar # iteration (see, e.g., CUDA.jl#145) function Broadcast.copy(bc::Broadcasted{<:AbstractReactantArrayStyle}) - fn = if bc.f isa Type && is_reactant_primitive(bc.f) + fn = if bc.f isa Type && bc.f <: Reactant.ReactantPrimitive TracedUtils.TypeCast{bc.f}() else bc.f diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index 4f270932c..97ced6be6 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -14,8 +14,7 @@ using ..Reactant: AnyTracedRArray, MissingTracedValue, OrderedIdDict, - Ops, - is_reactant_primitive + Ops using ReactantCore: MissingTracedValue materialize_traced_array(x::TracedRArray) = x @@ -285,20 +284,13 @@ end elem_apply(::Type{T}, x::TracedRArray{T}) where {T} = x -struct TypeCast{T} <: Function - function TypeCast{T}() where {T} - @assert is_reactant_primitive(T) "$(T) is not a primitive type supported by Reactant." - return new{T}() - end -end +struct TypeCast{T <: Reactant.ReactantPrimitive} <: Function end function (::TypeCast{T})(x::TracedRNumber{T2}) where {T,T2} return TracedUtils.promote_to(TracedRNumber{T}, x) end -function elem_apply(::Type{T}, x::TracedRArray) where {T} - @assert is_reactant_primitive(T) "$(T) is not a primitive type supported by Reactant." - # Special Path to prevent going down a despecialized path +function elem_apply(::Type{T}, x::TracedRArray) where {T <: Reactant.ReactantPrimitive} return elem_apply(TypeCast{T}(), x) end diff --git a/src/Tracing.jl b/src/Tracing.jl index 180d69b44..6eb085c42 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -84,7 +84,7 @@ Base.@nospecializeinfer function traced_type_inner( subParms = [] for (i, SST) in enumerate(T.parameters) - if wrapped_carray && i == 1 && SST isa Type && is_reactant_primitive(SST) + if wrapped_carray && i == 1 && SST isa Type && SST <: ReactantPrimitive TrT = traced_type_inner(ConcreteRNumber{SST}, seen, mode, track_numbers) push!(subParms, TrT) elseif wrapped_tracedarray && i == 1 && SST isa Type && SST <: TracedRNumber @@ -161,17 +161,15 @@ for T in ( end Base.@nospecializeinfer function traced_type_inner( - @nospecialize(T::Type{<:Number}), + @nospecialize(T::Type{<:ReactantPrimitive}), seen, @nospecialize(mode::TraceMode), @nospecialize(track_numbers::Type) ) - if is_reactant_primitive(T) - if Mode == ArrayToConcrete && T <: track_numbers - return ConcreteRNumber{T} - elseif (mode == NoStopTracedTrack || mode == TracedTrack) && T <: track_numbers - return TracedRNumber{T} - end + if Mode == ArrayToConcrete && T <: track_numbers + return ConcreteRNumber{T} + elseif (mode == NoStopTracedTrack || mode == TracedTrack) && T <: track_numbers + return TracedRNumber{T} end return T end @@ -394,7 +392,7 @@ Base.@nospecializeinfer function traced_type_inner( ) T = eltype(A) N = ndims(A) - if mode == ArrayToConcrete && is_reactant_primitive(T) + if mode == ArrayToConcrete && T <: Reactant.ReactantPrimitive return ConcreteRArray{T,N} else return Array{traced_type_inner(T, seen, mode, track_numbers),N} @@ -903,7 +901,7 @@ function make_tracer( if mode != NoStopTracedTrack && haskey(seen, prev) return seen[prev] end - if mode == ArrayToConcrete && is_reactant_primitive(eltype(RT)) + if mode == ArrayToConcrete && eltype(RT) <: Reactant.ReactantPrimitive return seen[prev] = ConcreteRArray(prev) end TT = traced_type(eltype(RT), Val(mode), track_numbers) @@ -994,21 +992,17 @@ end @nospecialize(x::ConcreteRArray), @nospecialize(track_numbers::Type) ) = x @inline function to_rarray_internal( - @nospecialize(x::Array{T}), @nospecialize(track_numbers::Type) -) where {T} - is_reactant_primitive(T) && return ConcreteRArray(x) - return @invoke to_rarray_internal(x::Any, track_numbers::Type) + @nospecialize(x::Array{<:ReactantPrimitive}), @nospecialize(track_numbers::Type) +) + return ConcreteRArray(x) end @inline to_rarray_internal( @nospecialize(x::ConcreteRNumber), @nospecialize(track_numbers::Type) ) = x @inline function to_rarray_internal( - @nospecialize(x::Number), @nospecialize(track_numbers::Type) + @nospecialize(x::ReactantPrimitive), @nospecialize(track_numbers::Type) ) - if is_reactant_primitive(typeof(x)) - typeof(x) <: track_numbers && return ConcreteRNumber(x) - return x - end - return @invoke to_rarray_internal(x::Any, track_numbers::Type) + typeof(x) <: track_numbers && return ConcreteRNumber(x) + return x end