From d68cda9dadba95cd8a173b1abff2cc0ffafc5f5f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 22 Jan 2025 20:09:46 -0500 Subject: [PATCH] feat: define primitive types --- src/PrimitiveTypes.jl | 62 +++++++++++++++-------------------- src/Tracing.jl | 18 +++++++++- src/XLA.jl | 7 ++++ src/mlir/IR/IR.jl | 1 + src/mlir/IR/Type.jl | 76 +++++++++++++++++++++++++++++++++++++++++++ src/mlir/MLIR.jl | 2 ++ 6 files changed, 129 insertions(+), 37 deletions(-) diff --git a/src/PrimitiveTypes.jl b/src/PrimitiveTypes.jl index ea3b7e2c0..f81c22b3e 100644 --- a/src/PrimitiveTypes.jl +++ b/src/PrimitiveTypes.jl @@ -1,30 +1,33 @@ # These only exist for the purpose of lowering. Since `ReactantPrimitive` is a fixed set of # types, users can use these to convert their types to the primitive types supported by # Reactant. -struct F8E5M2{T} <: AbstractFloat - val::T -end +for T in (:F8E5M2, :F8E4M3FN, :F8E4M3B11FNUZ, :F8E5M2FNUZ, :F8E4M3FNUZ) + @eval begin + primitive type $(T) <: AbstractFloat 8 end -struct F8E4M3FN{T} <: AbstractFloat - val::T -end + Base.promote_rule(::Type{$(T)}, ::Type{Float16}) = Float16 + Base.promote_rule(::Type{Float16}, ::Type{$(T)}) = Float16 -struct F8E4M3B11FNUZ{T} <: AbstractFloat - val::T -end + Base.promote_rule(::Type{$(T)}, ::Type{Float32}) = Float32 + Base.promote_rule(::Type{Float32}, ::Type{$(T)}) = Float32 -struct F8E5M2FNUZ{T} <: AbstractFloat - val::T -end + Base.promote_rule(::Type{$(T)}, ::Type{Float64}) = Float64 + Base.promote_rule(::Type{Float64}, ::Type{$(T)}) = Float64 -struct F8E4M3FNUZ{T} <: AbstractFloat - val::T -end + Base.promote_rule(::Type{$(T)}, ::Type{<:Integer}) = $(T) + Base.promote_rule(::Type{<:Integer}, ::Type{$(T)}) = $(T) -# TODO: Quantized types + @static if isdefined(Core, :BFloat16) + Base.promote_rule(::Type{$(T)}, ::Type{Core.BFloat16}) = Core.BFloat16 + Base.promote_rule(::Type{Core.BFloat16}, ::Type{$(T)}) = Core.BFloat16 + end + end +end const ReactantFloat8 = Union{F8E5M2,F8E4M3FN,F8E4M3B11FNUZ,F8E5M2FNUZ,F8E4M3FNUZ} +# TODO: Quantized types + @static if isdefined(Core, :BFloat16) const ReactantFloat = Union{ Float16,Core.BFloat16,Float32,Float64,Base.uniontypes(ReactantFloat8)... @@ -37,18 +40,7 @@ const ReactantComplexFloat = Union{[Complex{T} for T in Base.uniontypes(Reactant const ReactantInt = Union{Int8,UInt8,Int16,UInt16,Int32,UInt32,Int64,UInt64,Int128,UInt128} -const ReactantComplexInt = Union{ - Complex{Int8}, - Complex{UInt8}, - Complex{Int16}, - Complex{UInt16}, - Complex{Int32}, - Complex{UInt32}, - Complex{Int64}, - Complex{UInt64}, - Complex{Int128}, - Complex{UInt128}, -} +const ReactantComplexInt = Union{[Complex{T} for T in Base.uniontypes(ReactantInt)]...} const ReactantFloatInt = Union{ Base.uniontypes(ReactantInt)...,Base.uniontypes(ReactantFloat)... @@ -61,14 +53,12 @@ const ReactantPrimitive = Union{ Base.uniontypes(ReactantComplexFloat)..., } -""" - to_reactant_primitive(val) - -Constructs a Reactant primitive from the given value. Returns the Reactant primitive and a -function that can be used to convert the value back to the original type. -""" -to_reactant_primitive(::T) where {T} = nothing, nothing +function to_reactant_primitive(v::T) where {T} + return reinterpret(to_reactant_primitive_type(T), v) +end +to_reactant_primitive_type(::Type{T}) where {T} = nothing for T in Base.uniontypes(ReactantPrimitive) - @eval to_reactant_primitive(val::$T) = val, identity + @eval to_reactant_primitive(val::$T) = val + @eval to_reactant_primitive_type(::Type{$T}) = $T end diff --git a/src/Tracing.jl b/src/Tracing.jl index 6eb085c42..73c72b560 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -531,7 +531,7 @@ Base.@assume_effects :total @inline function traced_type( cache = Dict{Type,Type}() traced_type_cache[cache_key] = cache end - return res1 = traced_type_inner(T, cache, mode, track_numbers) + return traced_type_inner(T, cache, mode, track_numbers) end abstract type TracedTypeException <: Exception end @@ -996,6 +996,14 @@ end ) return ConcreteRArray(x) end +@inline function to_rarray_internal( + @nospecialize(x::Array{T}), @nospecialize(track_numbers::Type) +) where {T<:Number} + if to_reactant_primitive_type(T) !== nothing + return ConcreteRArray(to_reactant_primitive.(x)) + end + return @invoke to_rarray_internal(x::Any, track_numbers::Type) +end @inline to_rarray_internal( @nospecialize(x::ConcreteRNumber), @nospecialize(track_numbers::Type) @@ -1006,3 +1014,11 @@ end typeof(x) <: track_numbers && return ConcreteRNumber(x) return x end +@inline function to_rarray_internal( + @nospecialize(x::Number), @nospecialize(track_numbers::Type) +) + if to_reactant_primitive_type(typeof(x)) !== nothing + return ConcreteRArray(to_reactant_primitive(x)) + end + return @invoke to_rarray_internal(x::Any, track_numbers::Type) +end diff --git a/src/XLA.jl b/src/XLA.jl index 3aaaf87c1..1b3c2b951 100644 --- a/src/XLA.jl +++ b/src/XLA.jl @@ -1,5 +1,6 @@ module XLA +import ..Reactant import ...MLIR const XLA_REACTANT_GPU_MEM_FRACTION = Ref{Float64}(0.75) @@ -356,6 +357,12 @@ end @inline primitive_type(::Type{Float16}) = 10 @inline primitive_type(::Type{Float32}) = 11 +@inline primitive_type(::Type{Reactant.F8E5M2}) = 19 +@inline primitive_type(::Type{Reactant.F8E4M3FN}) = 20 +@inline primitive_type(::Type{Reactant.F8E4M3B11FNUZ}) = 23 +@inline primitive_type(::Type{Reactant.F8E5M2FNUZ}) = 24 +@inline primitive_type(::Type{Reactant.F8E4M3FNUZ}) = 25 + @static if isdefined(Core, :BFloat16) @inline primitive_type(::Type{Core.BFloat16}) = 16 end diff --git a/src/mlir/IR/IR.jl b/src/mlir/IR/IR.jl index 044f27e5a..8da48846f 100644 --- a/src/mlir/IR/IR.jl +++ b/src/mlir/IR/IR.jl @@ -1,5 +1,6 @@ module IR +using ..Reactant using ..API # do not export `Type`, as it is already defined in Core diff --git a/src/mlir/IR/Type.jl b/src/mlir/IR/Type.jl index bd44b4d6f..2b09cfcef 100644 --- a/src/mlir/IR/Type.jl +++ b/src/mlir/IR/Type.jl @@ -191,6 +191,51 @@ Creates a f64 type in the given context. The type is owned by the context. """ Type(::Core.Type{Float64}; context::Context=context()) = Type(API.mlirF64TypeGet(context)) +""" + Type(::Core.Type{Reactant.F8E5M2}; context=context()) + +Creates a f8e5m2 type in the given context. The type is owned by the context. +""" +function Type(::Core.Type{Reactant.F8E5M2}; context::Context=context()) + return Type(API.mlirFloat8E5M2TypeGet(context)) +end + +""" + Type(::Core.Type{Reactant.F8E4M3FN}; context=context()) + +Creates a f8e4m3fn type in the given context. The type is owned by the context. +""" +function Type(::Core.Type{Reactant.F8E4M3FN}; context::Context=context()) + return Type(API.mlirFloat8E4M3FNTypeGet(context)) +end + +""" + Type(::Core.Type{Reactant.F8E4M3B11FNUZ}; context=context()) + +Creates a f8e4m3b11fnuz type in the given context. The type is owned by the context. +""" +function Type(::Core.Type{Reactant.F8E4M3B11FNUZ}; context::Context=context()) + return Type(API.mlirFloat8E4M3B11FNUZTypeGet(context)) +end + +""" + Type(::Core.Type{Reactant.F8E5M2FNUZ}; context=context()) + +Creates a f8e5m2fnuz type in the given context. The type is owned by the context. +""" +function Type(::Core.Type{Reactant.F8E5M2FNUZ}; context::Context=context()) + return Type(API.mlirFloat8E5M2FNUZTypeGet(context)) +end + +""" + Type(::Core.Type{Reactant.F8E4M3FNUZ}; context=context()) + +Creates a f8e4m3fnuz type in the given context. The type is owned by the context. +""" +function Type(::Core.Type{Reactant.F8E4M3FNUZ}; context::Context=context()) + return Type(API.mlirFloat8E4M3FNTypeGet(context)) +end + """ isf8e5m2(type) @@ -205,6 +250,27 @@ Checks whether the given type is an f8E4M3FN type. """ isf8e4m3fn(type::Type) = API.mlirTypeIsAFloat8E4M3FN(type) +""" + isf8e4m3b11fnuz(type) + +Checks whether the given type is an f8E4M3B11FNUZ type. +""" +isf8e4m3b11fnuz(type::Type) = API.mlirTypeIsAFloat8E4M3B11FNUZ(type) + +""" + isf8e5m2fnuz(type) + +Checks whether the given type is an f8E5M2FNUZ type. +""" +isf8e5m2fnuz(type::Type) = API.mlirTypeIsAFloat8E5M2FNUZ(type) + +""" + isf8e4m3fnuz(type) + +Checks whether the given type is an f8E4M3FNUZ type. +""" +isf8e4m3fnuz(type::Type) = API.mlirTypeIsAFloat8E4M3FNUZ(type) + """ isbf16(type) @@ -738,6 +804,16 @@ function julia_type(type::Type) Float32 elseif isf64(type) Float64 + elseif isf8e5m2(type) + Reactant.F8E5M2 + elseif isf8e4m3fn(type) + Reactant.F8E4M3FN + elseif isf8e4m3b11fnuz(type) + Reactant.F8E4M3B11FNUZ + elseif isf8e5m2fnuz(type) + Reactant.F8E5M2FNUZ + elseif isf8e4m3fnuz(type) + Reactant.F8E4M3FNUZ elseif isnone(type) Nothing elseif iscomplex(type) diff --git a/src/mlir/MLIR.jl b/src/mlir/MLIR.jl index 6bbf3cad4..aec0ac27f 100644 --- a/src/mlir/MLIR.jl +++ b/src/mlir/MLIR.jl @@ -1,5 +1,7 @@ module MLIR +using ..Reactant + module API using CEnum using Preferences