From 89ef696077c6e9dacfeda828c898a4fd7f9d57e7 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 22 Jan 2025 16:39:51 -0500 Subject: [PATCH 1/5] feat: introduce `is_reactant_primitive` to define additional primitives --- ext/ReactantCUDAExt.jl | 4 ++-- src/Overlay.jl | 6 +++--- src/Reactant.jl | 22 +++++++++++++++++++++- src/TracedRArray.jl | 12 +++++++----- src/TracedRNumber.jl | 9 +-------- src/TracedUtils.jl | 16 +++++++++++----- src/Tracing.jl | 39 +++++++++++++++++++++------------------ 7 files changed, 66 insertions(+), 42 deletions(-) diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index e5f6e4fc3..a8095f763 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -771,7 +771,7 @@ Base.@nospecializeinfer function Reactant.traced_type_inner( ) T = eltype(A) N = ndims(A) - if mode == Reactant.ArrayToConcrete && T <: Reactant.ReactantPrimitive + if mode == Reactant.ArrayToConcrete && Reactant.is_reactant_primitive(T) return Reactant.ConcreteRArray{T,N} else TT = Reactant.traced_type_inner(T, seen, mode, track_numbers) @@ -795,7 +795,7 @@ function Reactant.make_tracer( if haskey(seen, prev) return seen[prev] end - if mode == Reactant.ArrayToConcrete && eltype(RT) <: Reactant.ReactantPrimitive + if mode == Reactant.ArrayToConcrete && Reactant.is_reactant_primitive(eltype(RT)) return seen[prev] = Reactant.ConcreteRArray(Array(prev)) end TT = Reactant.traced_type(eltype(RT), Val(mode), track_numbers) diff --git a/src/Overlay.jl b/src/Overlay.jl index 5d9b85c83..dbd7e7e5e 100644 --- a/src/Overlay.jl +++ b/src/Overlay.jl @@ -53,7 +53,7 @@ for randfun in (:rand, :randn, :randexp) @reactant_overlay @noinline function Random.$(randfun)( rng::AbstractRNG, ::Type{T}, dims::Dims ) where {T} - if T <: ReactantPrimitive + if is_reactant_primitive(T) return TracedRandom.$(overload_randfun)(rng, T, dims) end @warn "Reactant doesn't support sampling of $(T) with the current \ @@ -70,7 +70,7 @@ for randfun in (:rand, :randn, :randexp) @reactant_overlay @noinline function Random.$(randfun)( rng::AbstractRNG, ::Type{T}, dim1::Integer, dims::Integer... ) where {T} - if T <: ReactantPrimitive + if is_reactant_primitive(T) return TracedRandom.$(overload_randfun)(rng, T, dim1, dims...) end @warn "Reactant doesn't support sampling of $(T) with the current \ @@ -82,7 +82,7 @@ for randfun in (:rand, :randn, :randexp) @reactant_overlay @noinline function Random.$(randfun)( rng::AbstractRNG, ::Type{T}=Float64 ) where {T} - if T <: ReactantPrimitive + if is_reactant_primitive(T) return TracedRandom.$(overload_randfun)(rng, T) end @warn "Reactant doesn't support sampling of $(T) with the current \ diff --git a/src/Reactant.jl b/src/Reactant.jl index 41f6ab929..13286afd6 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -58,7 +58,15 @@ const ReactantPrimitive = Union{ Base.uniontypes(ReactantComplexFloat)..., } -abstract type RNumber{T<:ReactantPrimitive} <: Number end +""" + is_reactant_primitive(::Type{T}) + +Returns `true` if `T` is a primitive type supported by Reactant. +""" +is_reactant_primitive(::Type{<:ReactantPrimitive}) = true +is_reactant_primitive(::Type) = false + +abstract type RNumber{T} <: Number end abstract type RArray{T,N} <: AbstractArray{T,N} end @@ -97,6 +105,7 @@ mutable struct TracedRNumber{T} <: RNumber{T} function TracedRNumber{T}( paths::Tuple, mlir_data::Union{Nothing,MLIR.IR.Value} ) where {T} + @assert is_reactant_primitive(T) "$(T) is not a primitive type supported by Reactant." if !isnothing(mlir_data) @assert size(MLIR.IR.type(mlir_data)) == () end @@ -114,6 +123,7 @@ mutable struct TracedRArray{T,N} <: RArray{TracedRNumber{T},N} function TracedRArray{T,N}( paths::Tuple, mlir_data::Union{Nothing,MLIR.IR.Value}, shape ) where {T,N} + @assert is_reactant_primitive(T) "$(T) is not a primitive type supported by Reactant." shape = Tuple(shape) if !isnothing(mlir_data) @assert size(MLIR.IR.type(mlir_data)) == shape "Expected: $(shape), got: $(size(MLIR.IR.type(mlir_data)))" @@ -157,6 +167,11 @@ Adapt.parent_type(::Type{XLAArray{T,N}}) where {T,N} = XLAArray{T,N} mutable struct ConcreteRNumber{T} <: RNumber{T} data::XLA.AsyncBuffer + + function ConcreteRNumber{T}(data::XLA.AsyncBuffer) where {T} + @assert is_reactant_primitive(T) "$(T) is not a primitive type supported by Reactant." + return new{T}(data) + end end @leaf ConcreteRNumber @@ -164,6 +179,11 @@ end mutable struct ConcreteRArray{T,N} <: RArray{T,N} data::XLA.AsyncBuffer shape::NTuple{N,Int} + + function ConcreteRArray{T,N}(data::XLA.AsyncBuffer, shape::NTuple{N,Int}) where {T,N} + @assert is_reactant_primitive(T) "$(T) is not a primitive type supported by Reactant." + return new{T,N}(data, shape) + end end @leaf ConcreteRArray diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 165a9c71e..b12a34a4a 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -7,7 +7,6 @@ using ..Reactant: Reactant, TracedRArray, TracedRNumber, - ReactantPrimitive, WrappedTracedRArray, AnyTracedRArray, AnyTracedRVector, @@ -16,7 +15,8 @@ using ..Reactant: ancestor, allowscalar, aos_to_soa, - unwrapped_eltype + unwrapped_eltype, + is_reactant_primitive using ..TracedUtils: TracedUtils, get_mlir_data, set_mlir_data!, materialize_traced_array using ReactantCore: ReactantCore @@ -480,14 +480,16 @@ end function Base.similar( ::Broadcasted{AbstractReactantArrayStyle{N}}, ::Type{T}, dims -) where {T<:ReactantPrimitive,N} +) where {T,N} + @assert is_reactant_primitive(T) "$(T) is not a primitive type supported by Reactant." @assert N isa Int return TracedRArray{T,length(dims)}((), nothing, map(length, dims)) end function Base.similar( ::Broadcasted{AbstractReactantArrayStyle{N}}, ::Type{TracedRNumber{T}}, dims -) where {T<:ReactantPrimitive,N} +) where {T,N} + @assert is_reactant_primitive(T) "$(T) is not a primitive type supported by Reactant." @assert N isa Int return TracedRArray{T,length(dims)}((), nothing, map(length, dims)) end @@ -507,7 +509,7 @@ end # we need to override the outer copy method to make sure we never fall back to scalar # iteration (see, e.g., CUDA.jl#145) function Broadcast.copy(bc::Broadcasted{<:AbstractReactantArrayStyle}) - fn = if bc.f isa Type && bc.f <: ReactantPrimitive + fn = if bc.f isa Type && is_reactant_primitive(bc.f) TracedUtils.TypeCast{bc.f}() else bc.f diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index aa319adcd..bbbb386f9 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -1,14 +1,7 @@ module TracedRNumberOverrides using ..Reactant: - Reactant, - TracedRNumber, - TracedRArray, - ReactantPrimitive, - TracedUtils, - Ops, - MLIR, - unwrapped_eltype + Reactant, TracedRNumber, TracedRArray, TracedUtils, Ops, MLIR, unwrapped_eltype using ReactantCore ReactantCore.is_traced(::TracedRNumber) = true diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index 8802ed083..4f270932c 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -14,8 +14,8 @@ using ..Reactant: AnyTracedRArray, MissingTracedValue, OrderedIdDict, - ReactantPrimitive, - Ops + Ops, + is_reactant_primitive using ReactantCore: MissingTracedValue materialize_traced_array(x::TracedRArray) = x @@ -283,15 +283,21 @@ function __take_region(compiled_fn) return region end -elem_apply(::Type{T}, x::TracedRArray{T}) where {T<:ReactantPrimitive} = x +elem_apply(::Type{T}, x::TracedRArray{T}) where {T} = x -struct TypeCast{T<:ReactantPrimitive} <: Function end +struct TypeCast{T} <: Function + function TypeCast{T}() where {T} + @assert is_reactant_primitive(T) "$(T) is not a primitive type supported by Reactant." + return new{T}() + end +end function (::TypeCast{T})(x::TracedRNumber{T2}) where {T,T2} return TracedUtils.promote_to(TracedRNumber{T}, x) end -function elem_apply(::Type{T}, x::TracedRArray) where {T<:ReactantPrimitive} +function elem_apply(::Type{T}, x::TracedRArray) where {T} + @assert is_reactant_primitive(T) "$(T) is not a primitive type supported by Reactant." # Special Path to prevent going down a despecialized path return elem_apply(TypeCast{T}(), x) end diff --git a/src/Tracing.jl b/src/Tracing.jl index d07c5178c..180d69b44 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -84,13 +84,10 @@ Base.@nospecializeinfer function traced_type_inner( subParms = [] for (i, SST) in enumerate(T.parameters) - if wrapped_carray && i == 1 && SST isa Type && SST <: ReactantPrimitive + if wrapped_carray && i == 1 && SST isa Type && is_reactant_primitive(SST) TrT = traced_type_inner(ConcreteRNumber{SST}, seen, mode, track_numbers) push!(subParms, TrT) - elseif wrapped_tracedarray && - i == 1 && - SST isa Type && - SST <: TracedRNumber{<:ReactantPrimitive} + elseif wrapped_tracedarray && i == 1 && SST isa Type && SST <: TracedRNumber TrT = traced_type_inner(unwrapped_eltype(SST), seen, mode, track_numbers) push!(subParms, TrT) else @@ -164,15 +161,17 @@ for T in ( end Base.@nospecializeinfer function traced_type_inner( - @nospecialize(T::Type{<:ReactantPrimitive}), + @nospecialize(T::Type{<:Number}), seen, @nospecialize(mode::TraceMode), @nospecialize(track_numbers::Type) ) - if Mode == ArrayToConcrete && T <: track_numbers - return ConcreteRNumber{T} - elseif (mode == NoStopTracedTrack || mode == TracedTrack) && T <: track_numbers - return TracedRNumber{T} + if is_reactant_primitive(T) + if Mode == ArrayToConcrete && T <: track_numbers + return ConcreteRNumber{T} + elseif (mode == NoStopTracedTrack || mode == TracedTrack) && T <: track_numbers + return TracedRNumber{T} + end end return T end @@ -395,7 +394,7 @@ Base.@nospecializeinfer function traced_type_inner( ) T = eltype(A) N = ndims(A) - if mode == ArrayToConcrete && T <: ReactantPrimitive + if mode == ArrayToConcrete && is_reactant_primitive(T) return ConcreteRArray{T,N} else return Array{traced_type_inner(T, seen, mode, track_numbers),N} @@ -904,7 +903,7 @@ function make_tracer( if mode != NoStopTracedTrack && haskey(seen, prev) return seen[prev] end - if mode == ArrayToConcrete && eltype(RT) <: ReactantPrimitive + if mode == ArrayToConcrete && is_reactant_primitive(eltype(RT)) return seen[prev] = ConcreteRArray(prev) end TT = traced_type(eltype(RT), Val(mode), track_numbers) @@ -995,17 +994,21 @@ end @nospecialize(x::ConcreteRArray), @nospecialize(track_numbers::Type) ) = x @inline function to_rarray_internal( - @nospecialize(x::Array{<:ReactantPrimitive}), @nospecialize(track_numbers::Type) -) - return ConcreteRArray(x) + @nospecialize(x::Array{T}), @nospecialize(track_numbers::Type) +) where {T} + is_reactant_primitive(T) && return ConcreteRArray(x) + return @invoke to_rarray_internal(x::Any, track_numbers::Type) end @inline to_rarray_internal( @nospecialize(x::ConcreteRNumber), @nospecialize(track_numbers::Type) ) = x @inline function to_rarray_internal( - @nospecialize(x::ReactantPrimitive), @nospecialize(track_numbers::Type) + @nospecialize(x::Number), @nospecialize(track_numbers::Type) ) - typeof(x) <: track_numbers && return ConcreteRNumber(x) - return x + if is_reactant_primitive(typeof(x)) + typeof(x) <: track_numbers && return ConcreteRNumber(x) + return x + end + return @invoke to_rarray_internal(x::Any, track_numbers::Type) end From 1a873d02b9af8bf7e157081818ebb22eaeb04366 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 22 Jan 2025 17:29:03 -0500 Subject: [PATCH 2/5] refactor: move to a different api --- ext/ReactantCUDAExt.jl | 4 +-- src/Overlay.jl | 6 ++-- src/PrimitiveTypes.jl | 74 ++++++++++++++++++++++++++++++++++++++++++ src/Reactant.jl | 62 ++--------------------------------- src/TracedRArray.jl | 11 +++---- src/TracedUtils.jl | 14 ++------ src/Tracing.jl | 34 ++++++++----------- 7 files changed, 102 insertions(+), 103 deletions(-) create mode 100644 src/PrimitiveTypes.jl diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index a8095f763..e5f6e4fc3 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -771,7 +771,7 @@ Base.@nospecializeinfer function Reactant.traced_type_inner( ) T = eltype(A) N = ndims(A) - if mode == Reactant.ArrayToConcrete && Reactant.is_reactant_primitive(T) + if mode == Reactant.ArrayToConcrete && T <: Reactant.ReactantPrimitive return Reactant.ConcreteRArray{T,N} else TT = Reactant.traced_type_inner(T, seen, mode, track_numbers) @@ -795,7 +795,7 @@ function Reactant.make_tracer( if haskey(seen, prev) return seen[prev] end - if mode == Reactant.ArrayToConcrete && Reactant.is_reactant_primitive(eltype(RT)) + if mode == Reactant.ArrayToConcrete && eltype(RT) <: Reactant.ReactantPrimitive return seen[prev] = Reactant.ConcreteRArray(Array(prev)) end TT = Reactant.traced_type(eltype(RT), Val(mode), track_numbers) diff --git a/src/Overlay.jl b/src/Overlay.jl index dbd7e7e5e..5d9b85c83 100644 --- a/src/Overlay.jl +++ b/src/Overlay.jl @@ -53,7 +53,7 @@ for randfun in (:rand, :randn, :randexp) @reactant_overlay @noinline function Random.$(randfun)( rng::AbstractRNG, ::Type{T}, dims::Dims ) where {T} - if is_reactant_primitive(T) + if T <: ReactantPrimitive return TracedRandom.$(overload_randfun)(rng, T, dims) end @warn "Reactant doesn't support sampling of $(T) with the current \ @@ -70,7 +70,7 @@ for randfun in (:rand, :randn, :randexp) @reactant_overlay @noinline function Random.$(randfun)( rng::AbstractRNG, ::Type{T}, dim1::Integer, dims::Integer... ) where {T} - if is_reactant_primitive(T) + if T <: ReactantPrimitive return TracedRandom.$(overload_randfun)(rng, T, dim1, dims...) end @warn "Reactant doesn't support sampling of $(T) with the current \ @@ -82,7 +82,7 @@ for randfun in (:rand, :randn, :randexp) @reactant_overlay @noinline function Random.$(randfun)( rng::AbstractRNG, ::Type{T}=Float64 ) where {T} - if is_reactant_primitive(T) + if T <: ReactantPrimitive return TracedRandom.$(overload_randfun)(rng, T) end @warn "Reactant doesn't support sampling of $(T) with the current \ diff --git a/src/PrimitiveTypes.jl b/src/PrimitiveTypes.jl new file mode 100644 index 000000000..ea3b7e2c0 --- /dev/null +++ b/src/PrimitiveTypes.jl @@ -0,0 +1,74 @@ +# These only exist for the purpose of lowering. Since `ReactantPrimitive` is a fixed set of +# types, users can use these to convert their types to the primitive types supported by +# Reactant. +struct F8E5M2{T} <: AbstractFloat + val::T +end + +struct F8E4M3FN{T} <: AbstractFloat + val::T +end + +struct F8E4M3B11FNUZ{T} <: AbstractFloat + val::T +end + +struct F8E5M2FNUZ{T} <: AbstractFloat + val::T +end + +struct F8E4M3FNUZ{T} <: AbstractFloat + val::T +end + +# TODO: Quantized types + +const ReactantFloat8 = Union{F8E5M2,F8E4M3FN,F8E4M3B11FNUZ,F8E5M2FNUZ,F8E4M3FNUZ} + +@static if isdefined(Core, :BFloat16) + const ReactantFloat = Union{ + Float16,Core.BFloat16,Float32,Float64,Base.uniontypes(ReactantFloat8)... + } +else + const ReactantFloat = Union{Float16,Float32,Float64,Base.uniontypes(ReactantFloat8)...} +end + +const ReactantComplexFloat = Union{[Complex{T} for T in Base.uniontypes(ReactantFloat)]...} + +const ReactantInt = Union{Int8,UInt8,Int16,UInt16,Int32,UInt32,Int64,UInt64,Int128,UInt128} + +const ReactantComplexInt = Union{ + Complex{Int8}, + Complex{UInt8}, + Complex{Int16}, + Complex{UInt16}, + Complex{Int32}, + Complex{UInt32}, + Complex{Int64}, + Complex{UInt64}, + Complex{Int128}, + Complex{UInt128}, +} + +const ReactantFloatInt = Union{ + Base.uniontypes(ReactantInt)...,Base.uniontypes(ReactantFloat)... +} + +const ReactantPrimitive = Union{ + Bool, + Base.uniontypes(ReactantFloatInt)..., + Base.uniontypes(ReactantComplexInt)..., + Base.uniontypes(ReactantComplexFloat)..., +} + +""" + to_reactant_primitive(val) + +Constructs a Reactant primitive from the given value. Returns the Reactant primitive and a +function that can be used to convert the value back to the original type. +""" +to_reactant_primitive(::T) where {T} = nothing, nothing + +for T in Base.uniontypes(ReactantPrimitive) + @eval to_reactant_primitive(val::$T) = val, identity +end diff --git a/src/Reactant.jl b/src/Reactant.jl index 13286afd6..dc737c8dc 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -18,55 +18,9 @@ using Enzyme struct ReactantABI <: Enzyme.EnzymeCore.ABI end -@static if isdefined(Core, :BFloat16) - const ReactantFloat = Union{Float16,Core.BFloat16,Float32,Float64} -else - const ReactantFloat = Union{Float16,Float32,Float64} -end - -@static if isdefined(Core, :BFloat16) - const ReactantComplexFloat = Union{ - Complex{Float16},Complex{Core.BFloat16},Complex{Float32},Complex{Float64} - } -else - const ReactantComplexFloat = Union{Complex{Float16},Complex{Float32},Complex{Float64}} -end - -const ReactantInt = Union{Int8,UInt8,Int16,UInt16,Int32,UInt32,Int64,UInt64,Int128,UInt128} - -const ReactantComplexInt = Union{ - Complex{Int8}, - Complex{UInt8}, - Complex{Int16}, - Complex{UInt16}, - Complex{Int32}, - Complex{UInt32}, - Complex{Int64}, - Complex{UInt64}, - Complex{Int128}, - Complex{UInt128}, -} - -const ReactantFloatInt = Union{ - Base.uniontypes(ReactantInt)...,Base.uniontypes(ReactantFloat)... -} +include("PrimitiveTypes.jl") -const ReactantPrimitive = Union{ - Bool, - Base.uniontypes(ReactantFloatInt)..., - Base.uniontypes(ReactantComplexInt)..., - Base.uniontypes(ReactantComplexFloat)..., -} - -""" - is_reactant_primitive(::Type{T}) - -Returns `true` if `T` is a primitive type supported by Reactant. -""" -is_reactant_primitive(::Type{<:ReactantPrimitive}) = true -is_reactant_primitive(::Type) = false - -abstract type RNumber{T} <: Number end +abstract type RNumber{T<:ReactantPrimitive} <: Number end abstract type RArray{T,N} <: AbstractArray{T,N} end @@ -105,7 +59,6 @@ mutable struct TracedRNumber{T} <: RNumber{T} function TracedRNumber{T}( paths::Tuple, mlir_data::Union{Nothing,MLIR.IR.Value} ) where {T} - @assert is_reactant_primitive(T) "$(T) is not a primitive type supported by Reactant." if !isnothing(mlir_data) @assert size(MLIR.IR.type(mlir_data)) == () end @@ -123,7 +76,6 @@ mutable struct TracedRArray{T,N} <: RArray{TracedRNumber{T},N} function TracedRArray{T,N}( paths::Tuple, mlir_data::Union{Nothing,MLIR.IR.Value}, shape ) where {T,N} - @assert is_reactant_primitive(T) "$(T) is not a primitive type supported by Reactant." shape = Tuple(shape) if !isnothing(mlir_data) @assert size(MLIR.IR.type(mlir_data)) == shape "Expected: $(shape), got: $(size(MLIR.IR.type(mlir_data)))" @@ -167,11 +119,6 @@ Adapt.parent_type(::Type{XLAArray{T,N}}) where {T,N} = XLAArray{T,N} mutable struct ConcreteRNumber{T} <: RNumber{T} data::XLA.AsyncBuffer - - function ConcreteRNumber{T}(data::XLA.AsyncBuffer) where {T} - @assert is_reactant_primitive(T) "$(T) is not a primitive type supported by Reactant." - return new{T}(data) - end end @leaf ConcreteRNumber @@ -179,11 +126,6 @@ end mutable struct ConcreteRArray{T,N} <: RArray{T,N} data::XLA.AsyncBuffer shape::NTuple{N,Int} - - function ConcreteRArray{T,N}(data::XLA.AsyncBuffer, shape::NTuple{N,Int}) where {T,N} - @assert is_reactant_primitive(T) "$(T) is not a primitive type supported by Reactant." - return new{T,N}(data, shape) - end end @leaf ConcreteRArray diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index b12a34a4a..f71ea995c 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -15,8 +15,7 @@ using ..Reactant: ancestor, allowscalar, aos_to_soa, - unwrapped_eltype, - is_reactant_primitive + unwrapped_eltype using ..TracedUtils: TracedUtils, get_mlir_data, set_mlir_data!, materialize_traced_array using ReactantCore: ReactantCore @@ -480,16 +479,14 @@ end function Base.similar( ::Broadcasted{AbstractReactantArrayStyle{N}}, ::Type{T}, dims -) where {T,N} - @assert is_reactant_primitive(T) "$(T) is not a primitive type supported by Reactant." +) where {T<:Reactant.ReactantPrimitive,N} @assert N isa Int return TracedRArray{T,length(dims)}((), nothing, map(length, dims)) end function Base.similar( ::Broadcasted{AbstractReactantArrayStyle{N}}, ::Type{TracedRNumber{T}}, dims -) where {T,N} - @assert is_reactant_primitive(T) "$(T) is not a primitive type supported by Reactant." +) where {T<:Reactant.ReactantPrimitive,N} @assert N isa Int return TracedRArray{T,length(dims)}((), nothing, map(length, dims)) end @@ -509,7 +506,7 @@ end # we need to override the outer copy method to make sure we never fall back to scalar # iteration (see, e.g., CUDA.jl#145) function Broadcast.copy(bc::Broadcasted{<:AbstractReactantArrayStyle}) - fn = if bc.f isa Type && is_reactant_primitive(bc.f) + fn = if bc.f isa Type && bc.f <: Reactant.ReactantPrimitive TracedUtils.TypeCast{bc.f}() else bc.f diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index 4f270932c..97ced6be6 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -14,8 +14,7 @@ using ..Reactant: AnyTracedRArray, MissingTracedValue, OrderedIdDict, - Ops, - is_reactant_primitive + Ops using ReactantCore: MissingTracedValue materialize_traced_array(x::TracedRArray) = x @@ -285,20 +284,13 @@ end elem_apply(::Type{T}, x::TracedRArray{T}) where {T} = x -struct TypeCast{T} <: Function - function TypeCast{T}() where {T} - @assert is_reactant_primitive(T) "$(T) is not a primitive type supported by Reactant." - return new{T}() - end -end +struct TypeCast{T <: Reactant.ReactantPrimitive} <: Function end function (::TypeCast{T})(x::TracedRNumber{T2}) where {T,T2} return TracedUtils.promote_to(TracedRNumber{T}, x) end -function elem_apply(::Type{T}, x::TracedRArray) where {T} - @assert is_reactant_primitive(T) "$(T) is not a primitive type supported by Reactant." - # Special Path to prevent going down a despecialized path +function elem_apply(::Type{T}, x::TracedRArray) where {T <: Reactant.ReactantPrimitive} return elem_apply(TypeCast{T}(), x) end diff --git a/src/Tracing.jl b/src/Tracing.jl index 180d69b44..6eb085c42 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -84,7 +84,7 @@ Base.@nospecializeinfer function traced_type_inner( subParms = [] for (i, SST) in enumerate(T.parameters) - if wrapped_carray && i == 1 && SST isa Type && is_reactant_primitive(SST) + if wrapped_carray && i == 1 && SST isa Type && SST <: ReactantPrimitive TrT = traced_type_inner(ConcreteRNumber{SST}, seen, mode, track_numbers) push!(subParms, TrT) elseif wrapped_tracedarray && i == 1 && SST isa Type && SST <: TracedRNumber @@ -161,17 +161,15 @@ for T in ( end Base.@nospecializeinfer function traced_type_inner( - @nospecialize(T::Type{<:Number}), + @nospecialize(T::Type{<:ReactantPrimitive}), seen, @nospecialize(mode::TraceMode), @nospecialize(track_numbers::Type) ) - if is_reactant_primitive(T) - if Mode == ArrayToConcrete && T <: track_numbers - return ConcreteRNumber{T} - elseif (mode == NoStopTracedTrack || mode == TracedTrack) && T <: track_numbers - return TracedRNumber{T} - end + if Mode == ArrayToConcrete && T <: track_numbers + return ConcreteRNumber{T} + elseif (mode == NoStopTracedTrack || mode == TracedTrack) && T <: track_numbers + return TracedRNumber{T} end return T end @@ -394,7 +392,7 @@ Base.@nospecializeinfer function traced_type_inner( ) T = eltype(A) N = ndims(A) - if mode == ArrayToConcrete && is_reactant_primitive(T) + if mode == ArrayToConcrete && T <: Reactant.ReactantPrimitive return ConcreteRArray{T,N} else return Array{traced_type_inner(T, seen, mode, track_numbers),N} @@ -903,7 +901,7 @@ function make_tracer( if mode != NoStopTracedTrack && haskey(seen, prev) return seen[prev] end - if mode == ArrayToConcrete && is_reactant_primitive(eltype(RT)) + if mode == ArrayToConcrete && eltype(RT) <: Reactant.ReactantPrimitive return seen[prev] = ConcreteRArray(prev) end TT = traced_type(eltype(RT), Val(mode), track_numbers) @@ -994,21 +992,17 @@ end @nospecialize(x::ConcreteRArray), @nospecialize(track_numbers::Type) ) = x @inline function to_rarray_internal( - @nospecialize(x::Array{T}), @nospecialize(track_numbers::Type) -) where {T} - is_reactant_primitive(T) && return ConcreteRArray(x) - return @invoke to_rarray_internal(x::Any, track_numbers::Type) + @nospecialize(x::Array{<:ReactantPrimitive}), @nospecialize(track_numbers::Type) +) + return ConcreteRArray(x) end @inline to_rarray_internal( @nospecialize(x::ConcreteRNumber), @nospecialize(track_numbers::Type) ) = x @inline function to_rarray_internal( - @nospecialize(x::Number), @nospecialize(track_numbers::Type) + @nospecialize(x::ReactantPrimitive), @nospecialize(track_numbers::Type) ) - if is_reactant_primitive(typeof(x)) - typeof(x) <: track_numbers && return ConcreteRNumber(x) - return x - end - return @invoke to_rarray_internal(x::Any, track_numbers::Type) + typeof(x) <: track_numbers && return ConcreteRNumber(x) + return x end From 028f639bd463146d04e694f0c8fff6f83a6f14f9 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 22 Jan 2025 20:09:46 -0500 Subject: [PATCH 3/5] feat: define primitive types --- src/PrimitiveTypes.jl | 60 ++++++++++++++-------------------- src/Tracing.jl | 18 +++++++++- src/XLA.jl | 7 ++++ src/mlir/IR/IR.jl | 1 + src/mlir/IR/Type.jl | 76 +++++++++++++++++++++++++++++++++++++++++++ src/mlir/MLIR.jl | 2 ++ 6 files changed, 127 insertions(+), 37 deletions(-) diff --git a/src/PrimitiveTypes.jl b/src/PrimitiveTypes.jl index ea3b7e2c0..b547156ad 100644 --- a/src/PrimitiveTypes.jl +++ b/src/PrimitiveTypes.jl @@ -1,30 +1,33 @@ # These only exist for the purpose of lowering. Since `ReactantPrimitive` is a fixed set of # types, users can use these to convert their types to the primitive types supported by # Reactant. -struct F8E5M2{T} <: AbstractFloat - val::T -end +for T in (:F8E5M2, :F8E4M3FN, :F8E4M3B11FNUZ, :F8E5M2FNUZ, :F8E4M3FNUZ) + @eval begin + primitive type $(T) <: AbstractFloat 8 end -struct F8E4M3FN{T} <: AbstractFloat - val::T -end + Base.promote_rule(::Type{$(T)}, ::Type{Float16}) = Float16 + Base.promote_rule(::Type{Float16}, ::Type{$(T)}) = Float16 -struct F8E4M3B11FNUZ{T} <: AbstractFloat - val::T -end + Base.promote_rule(::Type{$(T)}, ::Type{Float32}) = Float32 + Base.promote_rule(::Type{Float32}, ::Type{$(T)}) = Float32 -struct F8E5M2FNUZ{T} <: AbstractFloat - val::T -end + Base.promote_rule(::Type{$(T)}, ::Type{Float64}) = Float64 + Base.promote_rule(::Type{Float64}, ::Type{$(T)}) = Float64 -struct F8E4M3FNUZ{T} <: AbstractFloat - val::T -end + Base.promote_rule(::Type{$(T)}, ::Type{<:Integer}) = $(T) + Base.promote_rule(::Type{<:Integer}, ::Type{$(T)}) = $(T) -# TODO: Quantized types + @static if isdefined(Core, :BFloat16) + Base.promote_rule(::Type{$(T)}, ::Type{Core.BFloat16}) = Core.BFloat16 + Base.promote_rule(::Type{Core.BFloat16}, ::Type{$(T)}) = Core.BFloat16 + end + end +end const ReactantFloat8 = Union{F8E5M2,F8E4M3FN,F8E4M3B11FNUZ,F8E5M2FNUZ,F8E4M3FNUZ} +# TODO: Quantized types + @static if isdefined(Core, :BFloat16) const ReactantFloat = Union{ Float16,Core.BFloat16,Float32,Float64,Base.uniontypes(ReactantFloat8)... @@ -37,18 +40,7 @@ const ReactantComplexFloat = Union{[Complex{T} for T in Base.uniontypes(Reactant const ReactantInt = Union{Int8,UInt8,Int16,UInt16,Int32,UInt32,Int64,UInt64,Int128,UInt128} -const ReactantComplexInt = Union{ - Complex{Int8}, - Complex{UInt8}, - Complex{Int16}, - Complex{UInt16}, - Complex{Int32}, - Complex{UInt32}, - Complex{Int64}, - Complex{UInt64}, - Complex{Int128}, - Complex{UInt128}, -} +const ReactantComplexInt = Union{[Complex{T} for T in Base.uniontypes(ReactantInt)]...} const ReactantFloatInt = Union{ Base.uniontypes(ReactantInt)...,Base.uniontypes(ReactantFloat)... @@ -61,14 +53,10 @@ const ReactantPrimitive = Union{ Base.uniontypes(ReactantComplexFloat)..., } -""" - to_reactant_primitive(val) - -Constructs a Reactant primitive from the given value. Returns the Reactant primitive and a -function that can be used to convert the value back to the original type. -""" -to_reactant_primitive(::T) where {T} = nothing, nothing +to_reactant_primitive(v::T) where {T} = reinterpret(reactant_primitive(T), v) +reactant_primitive(::Type{T}) where {T} = nothing for T in Base.uniontypes(ReactantPrimitive) - @eval to_reactant_primitive(val::$T) = val, identity + @eval to_reactant_primitive(val::$T) = val + @eval reactant_primitive(::Type{$T}) = $T end diff --git a/src/Tracing.jl b/src/Tracing.jl index 6eb085c42..9616b0130 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -531,7 +531,7 @@ Base.@assume_effects :total @inline function traced_type( cache = Dict{Type,Type}() traced_type_cache[cache_key] = cache end - return res1 = traced_type_inner(T, cache, mode, track_numbers) + return traced_type_inner(T, cache, mode, track_numbers) end abstract type TracedTypeException <: Exception end @@ -996,6 +996,14 @@ end ) return ConcreteRArray(x) end +@inline function to_rarray_internal( + @nospecialize(x::Array{T}), @nospecialize(track_numbers::Type) +) where {T<:Number} + if reactant_primitive(T) !== nothing + return ConcreteRArray(to_reactant_primitive.(x)) + end + return @invoke to_rarray_internal(x::Any, track_numbers::Type) +end @inline to_rarray_internal( @nospecialize(x::ConcreteRNumber), @nospecialize(track_numbers::Type) @@ -1006,3 +1014,11 @@ end typeof(x) <: track_numbers && return ConcreteRNumber(x) return x end +@inline function to_rarray_internal( + @nospecialize(x::Number), @nospecialize(track_numbers::Type) +) + if reactant_primitive(typeof(x)) !== nothing + return ConcreteRArray(to_reactant_primitive(x)) + end + return @invoke to_rarray_internal(x::Any, track_numbers::Type) +end diff --git a/src/XLA.jl b/src/XLA.jl index 3aaaf87c1..1b3c2b951 100644 --- a/src/XLA.jl +++ b/src/XLA.jl @@ -1,5 +1,6 @@ module XLA +import ..Reactant import ...MLIR const XLA_REACTANT_GPU_MEM_FRACTION = Ref{Float64}(0.75) @@ -356,6 +357,12 @@ end @inline primitive_type(::Type{Float16}) = 10 @inline primitive_type(::Type{Float32}) = 11 +@inline primitive_type(::Type{Reactant.F8E5M2}) = 19 +@inline primitive_type(::Type{Reactant.F8E4M3FN}) = 20 +@inline primitive_type(::Type{Reactant.F8E4M3B11FNUZ}) = 23 +@inline primitive_type(::Type{Reactant.F8E5M2FNUZ}) = 24 +@inline primitive_type(::Type{Reactant.F8E4M3FNUZ}) = 25 + @static if isdefined(Core, :BFloat16) @inline primitive_type(::Type{Core.BFloat16}) = 16 end diff --git a/src/mlir/IR/IR.jl b/src/mlir/IR/IR.jl index 044f27e5a..8da48846f 100644 --- a/src/mlir/IR/IR.jl +++ b/src/mlir/IR/IR.jl @@ -1,5 +1,6 @@ module IR +using ..Reactant using ..API # do not export `Type`, as it is already defined in Core diff --git a/src/mlir/IR/Type.jl b/src/mlir/IR/Type.jl index bd44b4d6f..2b09cfcef 100644 --- a/src/mlir/IR/Type.jl +++ b/src/mlir/IR/Type.jl @@ -191,6 +191,51 @@ Creates a f64 type in the given context. The type is owned by the context. """ Type(::Core.Type{Float64}; context::Context=context()) = Type(API.mlirF64TypeGet(context)) +""" + Type(::Core.Type{Reactant.F8E5M2}; context=context()) + +Creates a f8e5m2 type in the given context. The type is owned by the context. +""" +function Type(::Core.Type{Reactant.F8E5M2}; context::Context=context()) + return Type(API.mlirFloat8E5M2TypeGet(context)) +end + +""" + Type(::Core.Type{Reactant.F8E4M3FN}; context=context()) + +Creates a f8e4m3fn type in the given context. The type is owned by the context. +""" +function Type(::Core.Type{Reactant.F8E4M3FN}; context::Context=context()) + return Type(API.mlirFloat8E4M3FNTypeGet(context)) +end + +""" + Type(::Core.Type{Reactant.F8E4M3B11FNUZ}; context=context()) + +Creates a f8e4m3b11fnuz type in the given context. The type is owned by the context. +""" +function Type(::Core.Type{Reactant.F8E4M3B11FNUZ}; context::Context=context()) + return Type(API.mlirFloat8E4M3B11FNUZTypeGet(context)) +end + +""" + Type(::Core.Type{Reactant.F8E5M2FNUZ}; context=context()) + +Creates a f8e5m2fnuz type in the given context. The type is owned by the context. +""" +function Type(::Core.Type{Reactant.F8E5M2FNUZ}; context::Context=context()) + return Type(API.mlirFloat8E5M2FNUZTypeGet(context)) +end + +""" + Type(::Core.Type{Reactant.F8E4M3FNUZ}; context=context()) + +Creates a f8e4m3fnuz type in the given context. The type is owned by the context. +""" +function Type(::Core.Type{Reactant.F8E4M3FNUZ}; context::Context=context()) + return Type(API.mlirFloat8E4M3FNTypeGet(context)) +end + """ isf8e5m2(type) @@ -205,6 +250,27 @@ Checks whether the given type is an f8E4M3FN type. """ isf8e4m3fn(type::Type) = API.mlirTypeIsAFloat8E4M3FN(type) +""" + isf8e4m3b11fnuz(type) + +Checks whether the given type is an f8E4M3B11FNUZ type. +""" +isf8e4m3b11fnuz(type::Type) = API.mlirTypeIsAFloat8E4M3B11FNUZ(type) + +""" + isf8e5m2fnuz(type) + +Checks whether the given type is an f8E5M2FNUZ type. +""" +isf8e5m2fnuz(type::Type) = API.mlirTypeIsAFloat8E5M2FNUZ(type) + +""" + isf8e4m3fnuz(type) + +Checks whether the given type is an f8E4M3FNUZ type. +""" +isf8e4m3fnuz(type::Type) = API.mlirTypeIsAFloat8E4M3FNUZ(type) + """ isbf16(type) @@ -738,6 +804,16 @@ function julia_type(type::Type) Float32 elseif isf64(type) Float64 + elseif isf8e5m2(type) + Reactant.F8E5M2 + elseif isf8e4m3fn(type) + Reactant.F8E4M3FN + elseif isf8e4m3b11fnuz(type) + Reactant.F8E4M3B11FNUZ + elseif isf8e5m2fnuz(type) + Reactant.F8E5M2FNUZ + elseif isf8e4m3fnuz(type) + Reactant.F8E4M3FNUZ elseif isnone(type) Nothing elseif iscomplex(type) diff --git a/src/mlir/MLIR.jl b/src/mlir/MLIR.jl index 6bbf3cad4..aec0ac27f 100644 --- a/src/mlir/MLIR.jl +++ b/src/mlir/MLIR.jl @@ -1,5 +1,7 @@ module MLIR +using ..Reactant + module API using CEnum using Preferences From e60e7cfc6c872fe529c485486ed0481dce86e80b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 22 Jan 2025 21:18:04 -0500 Subject: [PATCH 4/5] feat: support custom user types --- src/Ops.jl | 19 +++++++++++++++++++ src/PrimitiveTypes.jl | 32 ++++++++++++++++++-------------- src/TracedRNumber.jl | 16 +++++++++------- src/XLA.jl | 10 +++++----- src/mlir/IR/Type.jl | 10 +++++----- 5 files changed, 56 insertions(+), 31 deletions(-) diff --git a/src/Ops.jl b/src/Ops.jl index b953e1c59..46283fcde 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -76,6 +76,25 @@ end return TracedRArray{T,N}((), res, size(x)) end +## This is somewhat a hack because I can't seem to find the corresponding mlir +## DenseElementsAttribute functions (also our optimizations will run a pass converting this +## to a single operation) +for T in (:F8E5M2, :F8E4M3FN, :F8E4M3B11FNUZ, :F8E5M2FNUZ, :F8E4M3FNUZ) + @eval @noinline function constant( + x::DenseArray{<:Reactant.$(T),N}; + location=mlir_stacktrace("constant", @__FILE__, @__LINE__), + ) where {N} + value = MLIR.IR.DenseElementsAttribute( + map(Float16 ∘ Base.Fix2(getproperty, :val), x) + ) + output = mlir_type(TracedRArray{Float16,N}, size(x)) + res = MLIR.IR.result(stablehlo.constant(; output, value, location)) + return convert( + TracedRArray{eltype(x),N}, TracedRArray{Float16,N}((), res, size(x)); location + ) + end +end + @noinline function constant( x::T; location=mlir_stacktrace("constant", @__FILE__, @__LINE__) ) where {T<:Number} diff --git a/src/PrimitiveTypes.jl b/src/PrimitiveTypes.jl index b547156ad..a60c4a62d 100644 --- a/src/PrimitiveTypes.jl +++ b/src/PrimitiveTypes.jl @@ -1,25 +1,29 @@ +# The types listed in this file are the ones present in StableHLO specification. + # These only exist for the purpose of lowering. Since `ReactantPrimitive` is a fixed set of # types, users can use these to convert their types to the primitive types supported by # Reactant. for T in (:F8E5M2, :F8E4M3FN, :F8E4M3B11FNUZ, :F8E5M2FNUZ, :F8E4M3FNUZ) @eval begin - primitive type $(T) <: AbstractFloat 8 end + struct $(T){inT} <: AbstractFloat + val::inT + end - Base.promote_rule(::Type{$(T)}, ::Type{Float16}) = Float16 - Base.promote_rule(::Type{Float16}, ::Type{$(T)}) = Float16 + Base.promote_rule(::Type{<:$(T)}, ::Type{Float16}) = Float16 + Base.promote_rule(::Type{Float16}, ::Type{<:$(T)}) = Float16 - Base.promote_rule(::Type{$(T)}, ::Type{Float32}) = Float32 - Base.promote_rule(::Type{Float32}, ::Type{$(T)}) = Float32 + Base.promote_rule(::Type{<:$(T)}, ::Type{Float32}) = Float32 + Base.promote_rule(::Type{Float32}, ::Type{<:$(T)}) = Float32 - Base.promote_rule(::Type{$(T)}, ::Type{Float64}) = Float64 - Base.promote_rule(::Type{Float64}, ::Type{$(T)}) = Float64 + Base.promote_rule(::Type{<:$(T)}, ::Type{Float64}) = Float64 + Base.promote_rule(::Type{Float64}, ::Type{<:$(T)}) = Float64 - Base.promote_rule(::Type{$(T)}, ::Type{<:Integer}) = $(T) - Base.promote_rule(::Type{<:Integer}, ::Type{$(T)}) = $(T) + Base.promote_rule(::Type{<:$(T){inT}}, ::Type{<:Integer}) where {inT} = $(T){inT} + Base.promote_rule(::Type{<:Integer}, ::Type{<:$(T){inT}}) where {inT} = $(T){inT} @static if isdefined(Core, :BFloat16) - Base.promote_rule(::Type{$(T)}, ::Type{Core.BFloat16}) = Core.BFloat16 - Base.promote_rule(::Type{Core.BFloat16}, ::Type{$(T)}) = Core.BFloat16 + Base.promote_rule(::Type{<:$(T)}, ::Type{Core.BFloat16}) = Core.BFloat16 + Base.promote_rule(::Type{Core.BFloat16}, ::Type{<:$(T)}) = Core.BFloat16 end end end @@ -36,9 +40,9 @@ else const ReactantFloat = Union{Float16,Float32,Float64,Base.uniontypes(ReactantFloat8)...} end -const ReactantComplexFloat = Union{[Complex{T} for T in Base.uniontypes(ReactantFloat)]...} +const ReactantComplexFloat = Union{Complex{Float32}, Complex{Float64}} -const ReactantInt = Union{Int8,UInt8,Int16,UInt16,Int32,UInt32,Int64,UInt64,Int128,UInt128} +const ReactantInt = Union{Int8,UInt8,Int16,UInt16,Int32,UInt32,Int64,UInt64} const ReactantComplexInt = Union{[Complex{T} for T in Base.uniontypes(ReactantInt)]...} @@ -53,7 +57,7 @@ const ReactantPrimitive = Union{ Base.uniontypes(ReactantComplexFloat)..., } -to_reactant_primitive(v::T) where {T} = reinterpret(reactant_primitive(T), v) +to_reactant_primitive(v::T) where {T} = reactant_primitive(T)(v) reactant_primitive(::Type{T}) where {T} = nothing for T in Base.uniontypes(ReactantPrimitive) diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index bbbb386f9..d81fe1c97 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -12,7 +12,7 @@ Base.zero(::TracedRNumber{T}) where {T} = TracedUtils.promote_to(TracedRNumber{T Base.one(::TracedRNumber{T}) where {T} = TracedUtils.promote_to(TracedRNumber{T}, one(T)) Base.collect(x::TracedRNumber{T}) where {T} = TracedRArray{T,0}((), x.mlir_data, ()) -function Base.eps(::Type{TracedRNumber{T}}) where {T} +function Base.eps(::Type{<:TracedRNumber{T}}) where {T} return TracedUtils.promote_to(TracedRNumber{T}, eps(T)) end @@ -36,24 +36,26 @@ end Base.only(A::TracedRNumber{T}) where {T} = A -function Base.promote_rule(::Type{TracedRNumber{T}}, ::Type{TracedRNumber{S}}) where {T,S} +function Base.promote_rule( + ::Type{<:TracedRNumber{T}}, ::Type{<:TracedRNumber{S}} +) where {T,S} return TracedRNumber{Base.promote_type(T, S)} end # Bool has special promotion rules in Base -function Base.promote_rule(::Type{Bool}, ::Type{TracedRNumber{T}}) where {T} +function Base.promote_rule(::Type{Bool}, ::Type{<:TracedRNumber{T}}) where {T} return TracedRNumber{T} end -function Base.promote_rule(::Type{TracedRNumber{T}}, ::Type{Bool}) where {T} +function Base.promote_rule(::Type{<:TracedRNumber{T}}, ::Type{Bool}) where {T} return TracedRNumber{T} end -function Base.promote_rule(::Type{T}, ::Type{TracedRNumber{S}}) where {T,S} +function Base.promote_rule(::Type{T}, ::Type{<:TracedRNumber{S}}) where {T,S} return TracedRNumber{Base.promote_type(T, S)} end -function Base.promote_rule(::Type{TracedRNumber{T}}, ::Type{S}) where {T,S} +function Base.promote_rule(::Type{<:TracedRNumber{T}}, ::Type{S}) where {T,S} return TracedRNumber{Base.promote_type(T, S)} end @@ -67,7 +69,7 @@ function TracedRNumber{T}(x::Number) where {T} return TracedUtils.promote_to(TracedRNumber{unwrapped_eltype(T)}, x) end -function TracedUtils.promote_to(::Type{TracedRNumber{T}}, rhs) where {T} +function TracedUtils.promote_to(::Type{<:TracedRNumber{T}}, rhs) where {T} if rhs isa TracedRNumber rhs isa TracedRNumber{T} && return rhs return Ops.convert(TracedRNumber{T}, rhs) diff --git a/src/XLA.jl b/src/XLA.jl index 1b3c2b951..ef217d2f5 100644 --- a/src/XLA.jl +++ b/src/XLA.jl @@ -357,11 +357,11 @@ end @inline primitive_type(::Type{Float16}) = 10 @inline primitive_type(::Type{Float32}) = 11 -@inline primitive_type(::Type{Reactant.F8E5M2}) = 19 -@inline primitive_type(::Type{Reactant.F8E4M3FN}) = 20 -@inline primitive_type(::Type{Reactant.F8E4M3B11FNUZ}) = 23 -@inline primitive_type(::Type{Reactant.F8E5M2FNUZ}) = 24 -@inline primitive_type(::Type{Reactant.F8E4M3FNUZ}) = 25 +@inline primitive_type(::Type{<:Reactant.F8E5M2}) = 19 +@inline primitive_type(::Type{<:Reactant.F8E4M3FN}) = 20 +@inline primitive_type(::Type{<:Reactant.F8E4M3B11FNUZ}) = 23 +@inline primitive_type(::Type{<:Reactant.F8E5M2FNUZ}) = 24 +@inline primitive_type(::Type{<:Reactant.F8E4M3FNUZ}) = 25 @static if isdefined(Core, :BFloat16) @inline primitive_type(::Type{Core.BFloat16}) = 16 diff --git a/src/mlir/IR/Type.jl b/src/mlir/IR/Type.jl index 2b09cfcef..4d4508dde 100644 --- a/src/mlir/IR/Type.jl +++ b/src/mlir/IR/Type.jl @@ -196,7 +196,7 @@ Type(::Core.Type{Float64}; context::Context=context()) = Type(API.mlirF64TypeGet Creates a f8e5m2 type in the given context. The type is owned by the context. """ -function Type(::Core.Type{Reactant.F8E5M2}; context::Context=context()) +function Type(::Core.Type{<:Reactant.F8E5M2}; context::Context=context()) return Type(API.mlirFloat8E5M2TypeGet(context)) end @@ -205,7 +205,7 @@ end Creates a f8e4m3fn type in the given context. The type is owned by the context. """ -function Type(::Core.Type{Reactant.F8E4M3FN}; context::Context=context()) +function Type(::Core.Type{<:Reactant.F8E4M3FN}; context::Context=context()) return Type(API.mlirFloat8E4M3FNTypeGet(context)) end @@ -214,7 +214,7 @@ end Creates a f8e4m3b11fnuz type in the given context. The type is owned by the context. """ -function Type(::Core.Type{Reactant.F8E4M3B11FNUZ}; context::Context=context()) +function Type(::Core.Type{<:Reactant.F8E4M3B11FNUZ}; context::Context=context()) return Type(API.mlirFloat8E4M3B11FNUZTypeGet(context)) end @@ -223,7 +223,7 @@ end Creates a f8e5m2fnuz type in the given context. The type is owned by the context. """ -function Type(::Core.Type{Reactant.F8E5M2FNUZ}; context::Context=context()) +function Type(::Core.Type{<:Reactant.F8E5M2FNUZ}; context::Context=context()) return Type(API.mlirFloat8E5M2FNUZTypeGet(context)) end @@ -232,7 +232,7 @@ end Creates a f8e4m3fnuz type in the given context. The type is owned by the context. """ -function Type(::Core.Type{Reactant.F8E4M3FNUZ}; context::Context=context()) +function Type(::Core.Type{<:Reactant.F8E4M3FNUZ}; context::Context=context()) return Type(API.mlirFloat8E4M3FNTypeGet(context)) end From 37039305d918e75b0bd49e3bad5169745ac78bbe Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 22 Jan 2025 21:21:34 -0500 Subject: [PATCH 5/5] chore: run formatter --- src/PrimitiveTypes.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/PrimitiveTypes.jl b/src/PrimitiveTypes.jl index a60c4a62d..77b2bdcf2 100644 --- a/src/PrimitiveTypes.jl +++ b/src/PrimitiveTypes.jl @@ -40,7 +40,7 @@ else const ReactantFloat = Union{Float16,Float32,Float64,Base.uniontypes(ReactantFloat8)...} end -const ReactantComplexFloat = Union{Complex{Float32}, Complex{Float64}} +const ReactantComplexFloat = Union{Complex{Float32},Complex{Float64}} const ReactantInt = Union{Int8,UInt8,Int16,UInt16,Int32,UInt32,Int64,UInt64}