Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support lowering custom fp types #596

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions src/Ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,25 @@ end
return TracedRArray{T,N}((), res, size(x))
end

## This is somewhat a hack because I can't seem to find the corresponding mlir
## DenseElementsAttribute functions (also our optimizations will run a pass converting this
## to a single operation)
for T in (:F8E5M2, :F8E4M3FN, :F8E4M3B11FNUZ, :F8E5M2FNUZ, :F8E4M3FNUZ)
@eval @noinline function constant(
x::DenseArray{<:Reactant.$(T),N};
location=mlir_stacktrace("constant", @__FILE__, @__LINE__),
) where {N}
value = MLIR.IR.DenseElementsAttribute(
map(Float16 ∘ Base.Fix2(getproperty, :val), x)
)
output = mlir_type(TracedRArray{Float16,N}, size(x))
res = MLIR.IR.result(stablehlo.constant(; output, value, location))
return convert(
TracedRArray{eltype(x),N}, TracedRArray{Float16,N}((), res, size(x)); location
)
end
end

@noinline function constant(
x::T; location=mlir_stacktrace("constant", @__FILE__, @__LINE__)
) where {T<:Number}
Expand Down
66 changes: 66 additions & 0 deletions src/PrimitiveTypes.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# The types listed in this file are the ones present in StableHLO specification.

# These only exist for the purpose of lowering. Since `ReactantPrimitive` is a fixed set of
# types, users can use these to convert their types to the primitive types supported by
# Reactant.
for T in (:F8E5M2, :F8E4M3FN, :F8E4M3B11FNUZ, :F8E5M2FNUZ, :F8E4M3FNUZ)
@eval begin
struct $(T){inT} <: AbstractFloat
val::inT
end

Base.promote_rule(::Type{<:$(T)}, ::Type{Float16}) = Float16
Base.promote_rule(::Type{Float16}, ::Type{<:$(T)}) = Float16

Base.promote_rule(::Type{<:$(T)}, ::Type{Float32}) = Float32
Base.promote_rule(::Type{Float32}, ::Type{<:$(T)}) = Float32

Base.promote_rule(::Type{<:$(T)}, ::Type{Float64}) = Float64
Base.promote_rule(::Type{Float64}, ::Type{<:$(T)}) = Float64

Base.promote_rule(::Type{<:$(T){inT}}, ::Type{<:Integer}) where {inT} = $(T){inT}
Base.promote_rule(::Type{<:Integer}, ::Type{<:$(T){inT}}) where {inT} = $(T){inT}

@static if isdefined(Core, :BFloat16)
Base.promote_rule(::Type{<:$(T)}, ::Type{Core.BFloat16}) = Core.BFloat16
Base.promote_rule(::Type{Core.BFloat16}, ::Type{<:$(T)}) = Core.BFloat16
end
end
end

const ReactantFloat8 = Union{F8E5M2,F8E4M3FN,F8E4M3B11FNUZ,F8E5M2FNUZ,F8E4M3FNUZ}

# TODO: Quantized types

@static if isdefined(Core, :BFloat16)
const ReactantFloat = Union{
Float16,Core.BFloat16,Float32,Float64,Base.uniontypes(ReactantFloat8)...
}
else
const ReactantFloat = Union{Float16,Float32,Float64,Base.uniontypes(ReactantFloat8)...}
end

const ReactantComplexFloat = Union{Complex{Float32},Complex{Float64}}

const ReactantInt = Union{Int8,UInt8,Int16,UInt16,Int32,UInt32,Int64,UInt64}

const ReactantComplexInt = Union{[Complex{T} for T in Base.uniontypes(ReactantInt)]...}

const ReactantFloatInt = Union{
Base.uniontypes(ReactantInt)...,Base.uniontypes(ReactantFloat)...
}

const ReactantPrimitive = Union{
Bool,
Base.uniontypes(ReactantFloatInt)...,
Base.uniontypes(ReactantComplexInt)...,
Base.uniontypes(ReactantComplexFloat)...,
}

to_reactant_primitive(v::T) where {T} = reactant_primitive(T)(v)
reactant_primitive(::Type{T}) where {T} = nothing

for T in Base.uniontypes(ReactantPrimitive)
@eval to_reactant_primitive(val::$T) = val
@eval reactant_primitive(::Type{$T}) = $T
end
40 changes: 1 addition & 39 deletions src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,45 +18,7 @@ using Enzyme

struct ReactantABI <: Enzyme.EnzymeCore.ABI end

@static if isdefined(Core, :BFloat16)
const ReactantFloat = Union{Float16,Core.BFloat16,Float32,Float64}
else
const ReactantFloat = Union{Float16,Float32,Float64}
end

@static if isdefined(Core, :BFloat16)
const ReactantComplexFloat = Union{
Complex{Float16},Complex{Core.BFloat16},Complex{Float32},Complex{Float64}
}
else
const ReactantComplexFloat = Union{Complex{Float16},Complex{Float32},Complex{Float64}}
end

const ReactantInt = Union{Int8,UInt8,Int16,UInt16,Int32,UInt32,Int64,UInt64,Int128,UInt128}

const ReactantComplexInt = Union{
Complex{Int8},
Complex{UInt8},
Complex{Int16},
Complex{UInt16},
Complex{Int32},
Complex{UInt32},
Complex{Int64},
Complex{UInt64},
Complex{Int128},
Complex{UInt128},
}

const ReactantFloatInt = Union{
Base.uniontypes(ReactantInt)...,Base.uniontypes(ReactantFloat)...
}

const ReactantPrimitive = Union{
Bool,
Base.uniontypes(ReactantFloatInt)...,
Base.uniontypes(ReactantComplexInt)...,
Base.uniontypes(ReactantComplexFloat)...,
}
include("PrimitiveTypes.jl")

abstract type RNumber{T<:ReactantPrimitive} <: Number end

Expand Down
7 changes: 3 additions & 4 deletions src/TracedRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ using ..Reactant:
Reactant,
TracedRArray,
TracedRNumber,
ReactantPrimitive,
WrappedTracedRArray,
AnyTracedRArray,
AnyTracedRVector,
Expand Down Expand Up @@ -480,14 +479,14 @@ end

function Base.similar(
::Broadcasted{AbstractReactantArrayStyle{N}}, ::Type{T}, dims
) where {T<:ReactantPrimitive,N}
) where {T<:Reactant.ReactantPrimitive,N}
@assert N isa Int
return TracedRArray{T,length(dims)}((), nothing, map(length, dims))
end

function Base.similar(
::Broadcasted{AbstractReactantArrayStyle{N}}, ::Type{TracedRNumber{T}}, dims
) where {T<:ReactantPrimitive,N}
) where {T<:Reactant.ReactantPrimitive,N}
@assert N isa Int
return TracedRArray{T,length(dims)}((), nothing, map(length, dims))
end
Expand All @@ -507,7 +506,7 @@ end
# we need to override the outer copy method to make sure we never fall back to scalar
# iteration (see, e.g., CUDA.jl#145)
function Broadcast.copy(bc::Broadcasted{<:AbstractReactantArrayStyle})
fn = if bc.f isa Type && bc.f <: ReactantPrimitive
fn = if bc.f isa Type && bc.f <: Reactant.ReactantPrimitive
TracedUtils.TypeCast{bc.f}()
else
bc.f
Expand Down
25 changes: 10 additions & 15 deletions src/TracedRNumber.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,7 @@
module TracedRNumberOverrides

using ..Reactant:
Reactant,
TracedRNumber,
TracedRArray,
ReactantPrimitive,
TracedUtils,
Ops,
MLIR,
unwrapped_eltype
Reactant, TracedRNumber, TracedRArray, TracedUtils, Ops, MLIR, unwrapped_eltype
using ReactantCore

ReactantCore.is_traced(::TracedRNumber) = true
Expand All @@ -19,7 +12,7 @@ Base.zero(::TracedRNumber{T}) where {T} = TracedUtils.promote_to(TracedRNumber{T
Base.one(::TracedRNumber{T}) where {T} = TracedUtils.promote_to(TracedRNumber{T}, one(T))
Base.collect(x::TracedRNumber{T}) where {T} = TracedRArray{T,0}((), x.mlir_data, ())

function Base.eps(::Type{TracedRNumber{T}}) where {T}
function Base.eps(::Type{<:TracedRNumber{T}}) where {T}
return TracedUtils.promote_to(TracedRNumber{T}, eps(T))
end

Expand All @@ -43,24 +36,26 @@ end

Base.only(A::TracedRNumber{T}) where {T} = A

function Base.promote_rule(::Type{TracedRNumber{T}}, ::Type{TracedRNumber{S}}) where {T,S}
function Base.promote_rule(
::Type{<:TracedRNumber{T}}, ::Type{<:TracedRNumber{S}}
) where {T,S}
return TracedRNumber{Base.promote_type(T, S)}
end

# Bool has special promotion rules in Base
function Base.promote_rule(::Type{Bool}, ::Type{TracedRNumber{T}}) where {T}
function Base.promote_rule(::Type{Bool}, ::Type{<:TracedRNumber{T}}) where {T}
return TracedRNumber{T}
end

function Base.promote_rule(::Type{TracedRNumber{T}}, ::Type{Bool}) where {T}
function Base.promote_rule(::Type{<:TracedRNumber{T}}, ::Type{Bool}) where {T}
return TracedRNumber{T}
end

function Base.promote_rule(::Type{T}, ::Type{TracedRNumber{S}}) where {T,S}
function Base.promote_rule(::Type{T}, ::Type{<:TracedRNumber{S}}) where {T,S}
return TracedRNumber{Base.promote_type(T, S)}
end

function Base.promote_rule(::Type{TracedRNumber{T}}, ::Type{S}) where {T,S}
function Base.promote_rule(::Type{<:TracedRNumber{T}}, ::Type{S}) where {T,S}
return TracedRNumber{Base.promote_type(T, S)}
end

Expand All @@ -74,7 +69,7 @@ function TracedRNumber{T}(x::Number) where {T}
return TracedUtils.promote_to(TracedRNumber{unwrapped_eltype(T)}, x)
end

function TracedUtils.promote_to(::Type{TracedRNumber{T}}, rhs) where {T}
function TracedUtils.promote_to(::Type{<:TracedRNumber{T}}, rhs) where {T}
if rhs isa TracedRNumber
rhs isa TracedRNumber{T} && return rhs
return Ops.convert(TracedRNumber{T}, rhs)
Expand Down
8 changes: 3 additions & 5 deletions src/TracedUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ using ..Reactant:
AnyTracedRArray,
MissingTracedValue,
OrderedIdDict,
ReactantPrimitive,
Ops
using ReactantCore: MissingTracedValue

Expand Down Expand Up @@ -283,16 +282,15 @@ function __take_region(compiled_fn)
return region
end

elem_apply(::Type{T}, x::TracedRArray{T}) where {T<:ReactantPrimitive} = x
elem_apply(::Type{T}, x::TracedRArray{T}) where {T} = x

struct TypeCast{T<:ReactantPrimitive} <: Function end
struct TypeCast{T <: Reactant.ReactantPrimitive} <: Function end
avik-pal marked this conversation as resolved.
Show resolved Hide resolved

function (::TypeCast{T})(x::TracedRNumber{T2}) where {T,T2}
return TracedUtils.promote_to(TracedRNumber{T}, x)
end

function elem_apply(::Type{T}, x::TracedRArray) where {T<:ReactantPrimitive}
# Special Path to prevent going down a despecialized path
function elem_apply(::Type{T}, x::TracedRArray) where {T <: Reactant.ReactantPrimitive}
avik-pal marked this conversation as resolved.
Show resolved Hide resolved
return elem_apply(TypeCast{T}(), x)
end

Expand Down
27 changes: 20 additions & 7 deletions src/Tracing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,7 @@ Base.@nospecializeinfer function traced_type_inner(
if wrapped_carray && i == 1 && SST isa Type && SST <: ReactantPrimitive
TrT = traced_type_inner(ConcreteRNumber{SST}, seen, mode, track_numbers)
push!(subParms, TrT)
elseif wrapped_tracedarray &&
i == 1 &&
SST isa Type &&
SST <: TracedRNumber{<:ReactantPrimitive}
elseif wrapped_tracedarray && i == 1 && SST isa Type && SST <: TracedRNumber
TrT = traced_type_inner(unwrapped_eltype(SST), seen, mode, track_numbers)
push!(subParms, TrT)
else
Expand Down Expand Up @@ -395,7 +392,7 @@ Base.@nospecializeinfer function traced_type_inner(
)
T = eltype(A)
N = ndims(A)
if mode == ArrayToConcrete && T <: ReactantPrimitive
if mode == ArrayToConcrete && T <: Reactant.ReactantPrimitive
return ConcreteRArray{T,N}
else
return Array{traced_type_inner(T, seen, mode, track_numbers),N}
Expand Down Expand Up @@ -534,7 +531,7 @@ Base.@assume_effects :total @inline function traced_type(
cache = Dict{Type,Type}()
traced_type_cache[cache_key] = cache
end
return res1 = traced_type_inner(T, cache, mode, track_numbers)
return traced_type_inner(T, cache, mode, track_numbers)
end

abstract type TracedTypeException <: Exception end
Expand Down Expand Up @@ -904,7 +901,7 @@ function make_tracer(
if mode != NoStopTracedTrack && haskey(seen, prev)
return seen[prev]
end
if mode == ArrayToConcrete && eltype(RT) <: ReactantPrimitive
if mode == ArrayToConcrete && eltype(RT) <: Reactant.ReactantPrimitive
return seen[prev] = ConcreteRArray(prev)
end
TT = traced_type(eltype(RT), Val(mode), track_numbers)
Expand Down Expand Up @@ -999,6 +996,14 @@ end
)
return ConcreteRArray(x)
end
@inline function to_rarray_internal(
@nospecialize(x::Array{T}), @nospecialize(track_numbers::Type)
) where {T<:Number}
if reactant_primitive(T) !== nothing
return ConcreteRArray(to_reactant_primitive.(x))
end
return @invoke to_rarray_internal(x::Any, track_numbers::Type)
end

@inline to_rarray_internal(
@nospecialize(x::ConcreteRNumber), @nospecialize(track_numbers::Type)
Expand All @@ -1009,3 +1014,11 @@ end
typeof(x) <: track_numbers && return ConcreteRNumber(x)
return x
end
@inline function to_rarray_internal(
@nospecialize(x::Number), @nospecialize(track_numbers::Type)
)
if reactant_primitive(typeof(x)) !== nothing
return ConcreteRArray(to_reactant_primitive(x))
end
return @invoke to_rarray_internal(x::Any, track_numbers::Type)
end
7 changes: 7 additions & 0 deletions src/XLA.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module XLA

import ..Reactant
import ...MLIR

const XLA_REACTANT_GPU_MEM_FRACTION = Ref{Float64}(0.75)
Expand Down Expand Up @@ -356,6 +357,12 @@ end
@inline primitive_type(::Type{Float16}) = 10
@inline primitive_type(::Type{Float32}) = 11

@inline primitive_type(::Type{<:Reactant.F8E5M2}) = 19
@inline primitive_type(::Type{<:Reactant.F8E4M3FN}) = 20
@inline primitive_type(::Type{<:Reactant.F8E4M3B11FNUZ}) = 23
@inline primitive_type(::Type{<:Reactant.F8E5M2FNUZ}) = 24
@inline primitive_type(::Type{<:Reactant.F8E4M3FNUZ}) = 25

@static if isdefined(Core, :BFloat16)
@inline primitive_type(::Type{Core.BFloat16}) = 16
end
Expand Down
1 change: 1 addition & 0 deletions src/mlir/IR/IR.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module IR

using ..Reactant
using ..API

# do not export `Type`, as it is already defined in Core
Expand Down
Loading
Loading