Skip to content

Commit

Permalink
feat: define primitive types
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jan 23, 2025
1 parent 1a873d0 commit d68cda9
Show file tree
Hide file tree
Showing 6 changed files with 129 additions and 37 deletions.
62 changes: 26 additions & 36 deletions src/PrimitiveTypes.jl
Original file line number Diff line number Diff line change
@@ -1,30 +1,33 @@
# These only exist for the purpose of lowering. Since `ReactantPrimitive` is a fixed set of
# types, users can use these to convert their types to the primitive types supported by
# Reactant.
struct F8E5M2{T} <: AbstractFloat
val::T
end
for T in (:F8E5M2, :F8E4M3FN, :F8E4M3B11FNUZ, :F8E5M2FNUZ, :F8E4M3FNUZ)
@eval begin
primitive type $(T) <: AbstractFloat 8 end

struct F8E4M3FN{T} <: AbstractFloat
val::T
end
Base.promote_rule(::Type{$(T)}, ::Type{Float16}) = Float16
Base.promote_rule(::Type{Float16}, ::Type{$(T)}) = Float16

struct F8E4M3B11FNUZ{T} <: AbstractFloat
val::T
end
Base.promote_rule(::Type{$(T)}, ::Type{Float32}) = Float32
Base.promote_rule(::Type{Float32}, ::Type{$(T)}) = Float32

struct F8E5M2FNUZ{T} <: AbstractFloat
val::T
end
Base.promote_rule(::Type{$(T)}, ::Type{Float64}) = Float64
Base.promote_rule(::Type{Float64}, ::Type{$(T)}) = Float64

struct F8E4M3FNUZ{T} <: AbstractFloat
val::T
end
Base.promote_rule(::Type{$(T)}, ::Type{<:Integer}) = $(T)
Base.promote_rule(::Type{<:Integer}, ::Type{$(T)}) = $(T)

# TODO: Quantized types
@static if isdefined(Core, :BFloat16)
Base.promote_rule(::Type{$(T)}, ::Type{Core.BFloat16}) = Core.BFloat16
Base.promote_rule(::Type{Core.BFloat16}, ::Type{$(T)}) = Core.BFloat16
end
end
end

const ReactantFloat8 = Union{F8E5M2,F8E4M3FN,F8E4M3B11FNUZ,F8E5M2FNUZ,F8E4M3FNUZ}

# TODO: Quantized types

@static if isdefined(Core, :BFloat16)
const ReactantFloat = Union{
Float16,Core.BFloat16,Float32,Float64,Base.uniontypes(ReactantFloat8)...
Expand All @@ -37,18 +40,7 @@ const ReactantComplexFloat = Union{[Complex{T} for T in Base.uniontypes(Reactant

const ReactantInt = Union{Int8,UInt8,Int16,UInt16,Int32,UInt32,Int64,UInt64,Int128,UInt128}

const ReactantComplexInt = Union{
Complex{Int8},
Complex{UInt8},
Complex{Int16},
Complex{UInt16},
Complex{Int32},
Complex{UInt32},
Complex{Int64},
Complex{UInt64},
Complex{Int128},
Complex{UInt128},
}
const ReactantComplexInt = Union{[Complex{T} for T in Base.uniontypes(ReactantInt)]...}

const ReactantFloatInt = Union{
Base.uniontypes(ReactantInt)...,Base.uniontypes(ReactantFloat)...
Expand All @@ -61,14 +53,12 @@ const ReactantPrimitive = Union{
Base.uniontypes(ReactantComplexFloat)...,
}

"""
to_reactant_primitive(val)
Constructs a Reactant primitive from the given value. Returns the Reactant primitive and a
function that can be used to convert the value back to the original type.
"""
to_reactant_primitive(::T) where {T} = nothing, nothing
function to_reactant_primitive(v::T) where {T}
return reinterpret(to_reactant_primitive_type(T), v)
end
to_reactant_primitive_type(::Type{T}) where {T} = nothing

for T in Base.uniontypes(ReactantPrimitive)
@eval to_reactant_primitive(val::$T) = val, identity
@eval to_reactant_primitive(val::$T) = val
@eval to_reactant_primitive_type(::Type{$T}) = $T
end
18 changes: 17 additions & 1 deletion src/Tracing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -531,7 +531,7 @@ Base.@assume_effects :total @inline function traced_type(
cache = Dict{Type,Type}()
traced_type_cache[cache_key] = cache
end
return res1 = traced_type_inner(T, cache, mode, track_numbers)
return traced_type_inner(T, cache, mode, track_numbers)
end

abstract type TracedTypeException <: Exception end
Expand Down Expand Up @@ -996,6 +996,14 @@ end
)
return ConcreteRArray(x)
end
@inline function to_rarray_internal(
@nospecialize(x::Array{T}), @nospecialize(track_numbers::Type)
) where {T<:Number}
if to_reactant_primitive_type(T) !== nothing
return ConcreteRArray(to_reactant_primitive.(x))
end
return @invoke to_rarray_internal(x::Any, track_numbers::Type)
end

@inline to_rarray_internal(
@nospecialize(x::ConcreteRNumber), @nospecialize(track_numbers::Type)
Expand All @@ -1006,3 +1014,11 @@ end
typeof(x) <: track_numbers && return ConcreteRNumber(x)
return x
end
@inline function to_rarray_internal(
@nospecialize(x::Number), @nospecialize(track_numbers::Type)
)
if to_reactant_primitive_type(typeof(x)) !== nothing
return ConcreteRArray(to_reactant_primitive(x))
end
return @invoke to_rarray_internal(x::Any, track_numbers::Type)
end
7 changes: 7 additions & 0 deletions src/XLA.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module XLA

import ..Reactant
import ...MLIR

const XLA_REACTANT_GPU_MEM_FRACTION = Ref{Float64}(0.75)
Expand Down Expand Up @@ -356,6 +357,12 @@ end
@inline primitive_type(::Type{Float16}) = 10
@inline primitive_type(::Type{Float32}) = 11

@inline primitive_type(::Type{Reactant.F8E5M2}) = 19
@inline primitive_type(::Type{Reactant.F8E4M3FN}) = 20
@inline primitive_type(::Type{Reactant.F8E4M3B11FNUZ}) = 23
@inline primitive_type(::Type{Reactant.F8E5M2FNUZ}) = 24
@inline primitive_type(::Type{Reactant.F8E4M3FNUZ}) = 25

@static if isdefined(Core, :BFloat16)
@inline primitive_type(::Type{Core.BFloat16}) = 16
end
Expand Down
1 change: 1 addition & 0 deletions src/mlir/IR/IR.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module IR

using ..Reactant
using ..API

# do not export `Type`, as it is already defined in Core
Expand Down
76 changes: 76 additions & 0 deletions src/mlir/IR/Type.jl
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,51 @@ Creates a f64 type in the given context. The type is owned by the context.
"""
Type(::Core.Type{Float64}; context::Context=context()) = Type(API.mlirF64TypeGet(context))

"""
Type(::Core.Type{Reactant.F8E5M2}; context=context())
Creates a f8e5m2 type in the given context. The type is owned by the context.
"""
function Type(::Core.Type{Reactant.F8E5M2}; context::Context=context())
return Type(API.mlirFloat8E5M2TypeGet(context))
end

"""
Type(::Core.Type{Reactant.F8E4M3FN}; context=context())
Creates a f8e4m3fn type in the given context. The type is owned by the context.
"""
function Type(::Core.Type{Reactant.F8E4M3FN}; context::Context=context())
return Type(API.mlirFloat8E4M3FNTypeGet(context))
end

"""
Type(::Core.Type{Reactant.F8E4M3B11FNUZ}; context=context())
Creates a f8e4m3b11fnuz type in the given context. The type is owned by the context.
"""
function Type(::Core.Type{Reactant.F8E4M3B11FNUZ}; context::Context=context())
return Type(API.mlirFloat8E4M3B11FNUZTypeGet(context))
end

"""
Type(::Core.Type{Reactant.F8E5M2FNUZ}; context=context())
Creates a f8e5m2fnuz type in the given context. The type is owned by the context.
"""
function Type(::Core.Type{Reactant.F8E5M2FNUZ}; context::Context=context())
return Type(API.mlirFloat8E5M2FNUZTypeGet(context))
end

"""
Type(::Core.Type{Reactant.F8E4M3FNUZ}; context=context())
Creates a f8e4m3fnuz type in the given context. The type is owned by the context.
"""
function Type(::Core.Type{Reactant.F8E4M3FNUZ}; context::Context=context())
return Type(API.mlirFloat8E4M3FNTypeGet(context))
end

"""
isf8e5m2(type)
Expand All @@ -205,6 +250,27 @@ Checks whether the given type is an f8E4M3FN type.
"""
isf8e4m3fn(type::Type) = API.mlirTypeIsAFloat8E4M3FN(type)

"""
isf8e4m3b11fnuz(type)
Checks whether the given type is an f8E4M3B11FNUZ type.
"""
isf8e4m3b11fnuz(type::Type) = API.mlirTypeIsAFloat8E4M3B11FNUZ(type)

"""
isf8e5m2fnuz(type)
Checks whether the given type is an f8E5M2FNUZ type.
"""
isf8e5m2fnuz(type::Type) = API.mlirTypeIsAFloat8E5M2FNUZ(type)

"""
isf8e4m3fnuz(type)
Checks whether the given type is an f8E4M3FNUZ type.
"""
isf8e4m3fnuz(type::Type) = API.mlirTypeIsAFloat8E4M3FNUZ(type)

"""
isbf16(type)
Expand Down Expand Up @@ -738,6 +804,16 @@ function julia_type(type::Type)
Float32
elseif isf64(type)
Float64
elseif isf8e5m2(type)
Reactant.F8E5M2
elseif isf8e4m3fn(type)
Reactant.F8E4M3FN
elseif isf8e4m3b11fnuz(type)
Reactant.F8E4M3B11FNUZ
elseif isf8e5m2fnuz(type)
Reactant.F8E5M2FNUZ
elseif isf8e4m3fnuz(type)
Reactant.F8E4M3FNUZ
elseif isnone(type)
Nothing
elseif iscomplex(type)
Expand Down
2 changes: 2 additions & 0 deletions src/mlir/MLIR.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
module MLIR

using ..Reactant

module API
using CEnum
using Preferences
Expand Down

0 comments on commit d68cda9

Please sign in to comment.