diff --git a/src/Ops.jl b/src/Ops.jl index b953e1c59..46283fcde 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -76,6 +76,25 @@ end return TracedRArray{T,N}((), res, size(x)) end +## This is somewhat a hack because I can't seem to find the corresponding mlir +## DenseElementsAttribute functions (also our optimizations will run a pass converting this +## to a single operation) +for T in (:F8E5M2, :F8E4M3FN, :F8E4M3B11FNUZ, :F8E5M2FNUZ, :F8E4M3FNUZ) + @eval @noinline function constant( + x::DenseArray{<:Reactant.$(T),N}; + location=mlir_stacktrace("constant", @__FILE__, @__LINE__), + ) where {N} + value = MLIR.IR.DenseElementsAttribute( + map(Float16 ∘ Base.Fix2(getproperty, :val), x) + ) + output = mlir_type(TracedRArray{Float16,N}, size(x)) + res = MLIR.IR.result(stablehlo.constant(; output, value, location)) + return convert( + TracedRArray{eltype(x),N}, TracedRArray{Float16,N}((), res, size(x)); location + ) + end +end + @noinline function constant( x::T; location=mlir_stacktrace("constant", @__FILE__, @__LINE__) ) where {T<:Number} diff --git a/src/PrimitiveTypes.jl b/src/PrimitiveTypes.jl index b547156ad..a60c4a62d 100644 --- a/src/PrimitiveTypes.jl +++ b/src/PrimitiveTypes.jl @@ -1,25 +1,29 @@ +# The types listed in this file are the ones present in StableHLO specification. + # These only exist for the purpose of lowering. Since `ReactantPrimitive` is a fixed set of # types, users can use these to convert their types to the primitive types supported by # Reactant. for T in (:F8E5M2, :F8E4M3FN, :F8E4M3B11FNUZ, :F8E5M2FNUZ, :F8E4M3FNUZ) @eval begin - primitive type $(T) <: AbstractFloat 8 end + struct $(T){inT} <: AbstractFloat + val::inT + end - Base.promote_rule(::Type{$(T)}, ::Type{Float16}) = Float16 - Base.promote_rule(::Type{Float16}, ::Type{$(T)}) = Float16 + Base.promote_rule(::Type{<:$(T)}, ::Type{Float16}) = Float16 + Base.promote_rule(::Type{Float16}, ::Type{<:$(T)}) = Float16 - Base.promote_rule(::Type{$(T)}, ::Type{Float32}) = Float32 - Base.promote_rule(::Type{Float32}, ::Type{$(T)}) = Float32 + Base.promote_rule(::Type{<:$(T)}, ::Type{Float32}) = Float32 + Base.promote_rule(::Type{Float32}, ::Type{<:$(T)}) = Float32 - Base.promote_rule(::Type{$(T)}, ::Type{Float64}) = Float64 - Base.promote_rule(::Type{Float64}, ::Type{$(T)}) = Float64 + Base.promote_rule(::Type{<:$(T)}, ::Type{Float64}) = Float64 + Base.promote_rule(::Type{Float64}, ::Type{<:$(T)}) = Float64 - Base.promote_rule(::Type{$(T)}, ::Type{<:Integer}) = $(T) - Base.promote_rule(::Type{<:Integer}, ::Type{$(T)}) = $(T) + Base.promote_rule(::Type{<:$(T){inT}}, ::Type{<:Integer}) where {inT} = $(T){inT} + Base.promote_rule(::Type{<:Integer}, ::Type{<:$(T){inT}}) where {inT} = $(T){inT} @static if isdefined(Core, :BFloat16) - Base.promote_rule(::Type{$(T)}, ::Type{Core.BFloat16}) = Core.BFloat16 - Base.promote_rule(::Type{Core.BFloat16}, ::Type{$(T)}) = Core.BFloat16 + Base.promote_rule(::Type{<:$(T)}, ::Type{Core.BFloat16}) = Core.BFloat16 + Base.promote_rule(::Type{Core.BFloat16}, ::Type{<:$(T)}) = Core.BFloat16 end end end @@ -36,9 +40,9 @@ else const ReactantFloat = Union{Float16,Float32,Float64,Base.uniontypes(ReactantFloat8)...} end -const ReactantComplexFloat = Union{[Complex{T} for T in Base.uniontypes(ReactantFloat)]...} +const ReactantComplexFloat = Union{Complex{Float32}, Complex{Float64}} -const ReactantInt = Union{Int8,UInt8,Int16,UInt16,Int32,UInt32,Int64,UInt64,Int128,UInt128} +const ReactantInt = Union{Int8,UInt8,Int16,UInt16,Int32,UInt32,Int64,UInt64} const ReactantComplexInt = Union{[Complex{T} for T in Base.uniontypes(ReactantInt)]...} @@ -53,7 +57,7 @@ const ReactantPrimitive = Union{ Base.uniontypes(ReactantComplexFloat)..., } -to_reactant_primitive(v::T) where {T} = reinterpret(reactant_primitive(T), v) +to_reactant_primitive(v::T) where {T} = reactant_primitive(T)(v) reactant_primitive(::Type{T}) where {T} = nothing for T in Base.uniontypes(ReactantPrimitive) diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index bbbb386f9..d81fe1c97 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -12,7 +12,7 @@ Base.zero(::TracedRNumber{T}) where {T} = TracedUtils.promote_to(TracedRNumber{T Base.one(::TracedRNumber{T}) where {T} = TracedUtils.promote_to(TracedRNumber{T}, one(T)) Base.collect(x::TracedRNumber{T}) where {T} = TracedRArray{T,0}((), x.mlir_data, ()) -function Base.eps(::Type{TracedRNumber{T}}) where {T} +function Base.eps(::Type{<:TracedRNumber{T}}) where {T} return TracedUtils.promote_to(TracedRNumber{T}, eps(T)) end @@ -36,24 +36,26 @@ end Base.only(A::TracedRNumber{T}) where {T} = A -function Base.promote_rule(::Type{TracedRNumber{T}}, ::Type{TracedRNumber{S}}) where {T,S} +function Base.promote_rule( + ::Type{<:TracedRNumber{T}}, ::Type{<:TracedRNumber{S}} +) where {T,S} return TracedRNumber{Base.promote_type(T, S)} end # Bool has special promotion rules in Base -function Base.promote_rule(::Type{Bool}, ::Type{TracedRNumber{T}}) where {T} +function Base.promote_rule(::Type{Bool}, ::Type{<:TracedRNumber{T}}) where {T} return TracedRNumber{T} end -function Base.promote_rule(::Type{TracedRNumber{T}}, ::Type{Bool}) where {T} +function Base.promote_rule(::Type{<:TracedRNumber{T}}, ::Type{Bool}) where {T} return TracedRNumber{T} end -function Base.promote_rule(::Type{T}, ::Type{TracedRNumber{S}}) where {T,S} +function Base.promote_rule(::Type{T}, ::Type{<:TracedRNumber{S}}) where {T,S} return TracedRNumber{Base.promote_type(T, S)} end -function Base.promote_rule(::Type{TracedRNumber{T}}, ::Type{S}) where {T,S} +function Base.promote_rule(::Type{<:TracedRNumber{T}}, ::Type{S}) where {T,S} return TracedRNumber{Base.promote_type(T, S)} end @@ -67,7 +69,7 @@ function TracedRNumber{T}(x::Number) where {T} return TracedUtils.promote_to(TracedRNumber{unwrapped_eltype(T)}, x) end -function TracedUtils.promote_to(::Type{TracedRNumber{T}}, rhs) where {T} +function TracedUtils.promote_to(::Type{<:TracedRNumber{T}}, rhs) where {T} if rhs isa TracedRNumber rhs isa TracedRNumber{T} && return rhs return Ops.convert(TracedRNumber{T}, rhs) diff --git a/src/XLA.jl b/src/XLA.jl index 1b3c2b951..ef217d2f5 100644 --- a/src/XLA.jl +++ b/src/XLA.jl @@ -357,11 +357,11 @@ end @inline primitive_type(::Type{Float16}) = 10 @inline primitive_type(::Type{Float32}) = 11 -@inline primitive_type(::Type{Reactant.F8E5M2}) = 19 -@inline primitive_type(::Type{Reactant.F8E4M3FN}) = 20 -@inline primitive_type(::Type{Reactant.F8E4M3B11FNUZ}) = 23 -@inline primitive_type(::Type{Reactant.F8E5M2FNUZ}) = 24 -@inline primitive_type(::Type{Reactant.F8E4M3FNUZ}) = 25 +@inline primitive_type(::Type{<:Reactant.F8E5M2}) = 19 +@inline primitive_type(::Type{<:Reactant.F8E4M3FN}) = 20 +@inline primitive_type(::Type{<:Reactant.F8E4M3B11FNUZ}) = 23 +@inline primitive_type(::Type{<:Reactant.F8E5M2FNUZ}) = 24 +@inline primitive_type(::Type{<:Reactant.F8E4M3FNUZ}) = 25 @static if isdefined(Core, :BFloat16) @inline primitive_type(::Type{Core.BFloat16}) = 16 diff --git a/src/mlir/IR/Type.jl b/src/mlir/IR/Type.jl index 2b09cfcef..4d4508dde 100644 --- a/src/mlir/IR/Type.jl +++ b/src/mlir/IR/Type.jl @@ -196,7 +196,7 @@ Type(::Core.Type{Float64}; context::Context=context()) = Type(API.mlirF64TypeGet Creates a f8e5m2 type in the given context. The type is owned by the context. """ -function Type(::Core.Type{Reactant.F8E5M2}; context::Context=context()) +function Type(::Core.Type{<:Reactant.F8E5M2}; context::Context=context()) return Type(API.mlirFloat8E5M2TypeGet(context)) end @@ -205,7 +205,7 @@ end Creates a f8e4m3fn type in the given context. The type is owned by the context. """ -function Type(::Core.Type{Reactant.F8E4M3FN}; context::Context=context()) +function Type(::Core.Type{<:Reactant.F8E4M3FN}; context::Context=context()) return Type(API.mlirFloat8E4M3FNTypeGet(context)) end @@ -214,7 +214,7 @@ end Creates a f8e4m3b11fnuz type in the given context. The type is owned by the context. """ -function Type(::Core.Type{Reactant.F8E4M3B11FNUZ}; context::Context=context()) +function Type(::Core.Type{<:Reactant.F8E4M3B11FNUZ}; context::Context=context()) return Type(API.mlirFloat8E4M3B11FNUZTypeGet(context)) end @@ -223,7 +223,7 @@ end Creates a f8e5m2fnuz type in the given context. The type is owned by the context. """ -function Type(::Core.Type{Reactant.F8E5M2FNUZ}; context::Context=context()) +function Type(::Core.Type{<:Reactant.F8E5M2FNUZ}; context::Context=context()) return Type(API.mlirFloat8E5M2FNUZTypeGet(context)) end @@ -232,7 +232,7 @@ end Creates a f8e4m3fnuz type in the given context. The type is owned by the context. """ -function Type(::Core.Type{Reactant.F8E4M3FNUZ}; context::Context=context()) +function Type(::Core.Type{<:Reactant.F8E4M3FNUZ}; context::Context=context()) return Type(API.mlirFloat8E4M3FNTypeGet(context)) end