Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support lowering custom fp types #596

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions src/PrimitiveTypes.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# These only exist for the purpose of lowering. Since `ReactantPrimitive` is a fixed set of
# types, users can use these to convert their types to the primitive types supported by
# Reactant.
struct F8E5M2{T} <: AbstractFloat
val::T
end

struct F8E4M3FN{T} <: AbstractFloat
val::T
end

struct F8E4M3B11FNUZ{T} <: AbstractFloat
val::T
end

struct F8E5M2FNUZ{T} <: AbstractFloat
val::T
end

struct F8E4M3FNUZ{T} <: AbstractFloat
val::T
end

# TODO: Quantized types

const ReactantFloat8 = Union{F8E5M2,F8E4M3FN,F8E4M3B11FNUZ,F8E5M2FNUZ,F8E4M3FNUZ}

@static if isdefined(Core, :BFloat16)
const ReactantFloat = Union{
Float16,Core.BFloat16,Float32,Float64,Base.uniontypes(ReactantFloat8)...
}
else
const ReactantFloat = Union{Float16,Float32,Float64,Base.uniontypes(ReactantFloat8)...}
end

const ReactantComplexFloat = Union{[Complex{T} for T in Base.uniontypes(ReactantFloat)]...}

const ReactantInt = Union{Int8,UInt8,Int16,UInt16,Int32,UInt32,Int64,UInt64,Int128,UInt128}

const ReactantComplexInt = Union{
Complex{Int8},
Complex{UInt8},
Complex{Int16},
Complex{UInt16},
Complex{Int32},
Complex{UInt32},
Complex{Int64},
Complex{UInt64},
Complex{Int128},
Complex{UInt128},
}

const ReactantFloatInt = Union{
Base.uniontypes(ReactantInt)...,Base.uniontypes(ReactantFloat)...
}

const ReactantPrimitive = Union{
Bool,
Base.uniontypes(ReactantFloatInt)...,
Base.uniontypes(ReactantComplexInt)...,
Base.uniontypes(ReactantComplexFloat)...,
}

"""
to_reactant_primitive(val)

Constructs a Reactant primitive from the given value. Returns the Reactant primitive and a
function that can be used to convert the value back to the original type.
"""
to_reactant_primitive(::T) where {T} = nothing, nothing

for T in Base.uniontypes(ReactantPrimitive)
@eval to_reactant_primitive(val::$T) = val, identity
end
40 changes: 1 addition & 39 deletions src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,45 +18,7 @@ using Enzyme

struct ReactantABI <: Enzyme.EnzymeCore.ABI end

@static if isdefined(Core, :BFloat16)
const ReactantFloat = Union{Float16,Core.BFloat16,Float32,Float64}
else
const ReactantFloat = Union{Float16,Float32,Float64}
end

@static if isdefined(Core, :BFloat16)
const ReactantComplexFloat = Union{
Complex{Float16},Complex{Core.BFloat16},Complex{Float32},Complex{Float64}
}
else
const ReactantComplexFloat = Union{Complex{Float16},Complex{Float32},Complex{Float64}}
end

const ReactantInt = Union{Int8,UInt8,Int16,UInt16,Int32,UInt32,Int64,UInt64,Int128,UInt128}

const ReactantComplexInt = Union{
Complex{Int8},
Complex{UInt8},
Complex{Int16},
Complex{UInt16},
Complex{Int32},
Complex{UInt32},
Complex{Int64},
Complex{UInt64},
Complex{Int128},
Complex{UInt128},
}

const ReactantFloatInt = Union{
Base.uniontypes(ReactantInt)...,Base.uniontypes(ReactantFloat)...
}

const ReactantPrimitive = Union{
Bool,
Base.uniontypes(ReactantFloatInt)...,
Base.uniontypes(ReactantComplexInt)...,
Base.uniontypes(ReactantComplexFloat)...,
}
include("PrimitiveTypes.jl")

abstract type RNumber{T<:ReactantPrimitive} <: Number end

Expand Down
7 changes: 3 additions & 4 deletions src/TracedRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ using ..Reactant:
Reactant,
TracedRArray,
TracedRNumber,
ReactantPrimitive,
WrappedTracedRArray,
AnyTracedRArray,
AnyTracedRVector,
Expand Down Expand Up @@ -480,14 +479,14 @@ end

function Base.similar(
::Broadcasted{AbstractReactantArrayStyle{N}}, ::Type{T}, dims
) where {T<:ReactantPrimitive,N}
) where {T<:Reactant.ReactantPrimitive,N}
@assert N isa Int
return TracedRArray{T,length(dims)}((), nothing, map(length, dims))
end

function Base.similar(
::Broadcasted{AbstractReactantArrayStyle{N}}, ::Type{TracedRNumber{T}}, dims
) where {T<:ReactantPrimitive,N}
) where {T<:Reactant.ReactantPrimitive,N}
@assert N isa Int
return TracedRArray{T,length(dims)}((), nothing, map(length, dims))
end
Expand All @@ -507,7 +506,7 @@ end
# we need to override the outer copy method to make sure we never fall back to scalar
# iteration (see, e.g., CUDA.jl#145)
function Broadcast.copy(bc::Broadcasted{<:AbstractReactantArrayStyle})
fn = if bc.f isa Type && bc.f <: ReactantPrimitive
fn = if bc.f isa Type && bc.f <: Reactant.ReactantPrimitive
TracedUtils.TypeCast{bc.f}()
else
bc.f
Expand Down
9 changes: 1 addition & 8 deletions src/TracedRNumber.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,7 @@
module TracedRNumberOverrides

using ..Reactant:
Reactant,
TracedRNumber,
TracedRArray,
ReactantPrimitive,
TracedUtils,
Ops,
MLIR,
unwrapped_eltype
Reactant, TracedRNumber, TracedRArray, TracedUtils, Ops, MLIR, unwrapped_eltype
using ReactantCore

ReactantCore.is_traced(::TracedRNumber) = true
Expand Down
8 changes: 3 additions & 5 deletions src/TracedUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ using ..Reactant:
AnyTracedRArray,
MissingTracedValue,
OrderedIdDict,
ReactantPrimitive,
Ops
using ReactantCore: MissingTracedValue

Expand Down Expand Up @@ -283,16 +282,15 @@ function __take_region(compiled_fn)
return region
end

elem_apply(::Type{T}, x::TracedRArray{T}) where {T<:ReactantPrimitive} = x
elem_apply(::Type{T}, x::TracedRArray{T}) where {T} = x

struct TypeCast{T<:ReactantPrimitive} <: Function end
struct TypeCast{T <: Reactant.ReactantPrimitive} <: Function end
avik-pal marked this conversation as resolved.
Show resolved Hide resolved

function (::TypeCast{T})(x::TracedRNumber{T2}) where {T,T2}
return TracedUtils.promote_to(TracedRNumber{T}, x)
end

function elem_apply(::Type{T}, x::TracedRArray) where {T<:ReactantPrimitive}
# Special Path to prevent going down a despecialized path
function elem_apply(::Type{T}, x::TracedRArray) where {T <: Reactant.ReactantPrimitive}
avik-pal marked this conversation as resolved.
Show resolved Hide resolved
return elem_apply(TypeCast{T}(), x)
end

Expand Down
9 changes: 3 additions & 6 deletions src/Tracing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,7 @@ Base.@nospecializeinfer function traced_type_inner(
if wrapped_carray && i == 1 && SST isa Type && SST <: ReactantPrimitive
TrT = traced_type_inner(ConcreteRNumber{SST}, seen, mode, track_numbers)
push!(subParms, TrT)
elseif wrapped_tracedarray &&
i == 1 &&
SST isa Type &&
SST <: TracedRNumber{<:ReactantPrimitive}
elseif wrapped_tracedarray && i == 1 && SST isa Type && SST <: TracedRNumber
TrT = traced_type_inner(unwrapped_eltype(SST), seen, mode, track_numbers)
push!(subParms, TrT)
else
Expand Down Expand Up @@ -395,7 +392,7 @@ Base.@nospecializeinfer function traced_type_inner(
)
T = eltype(A)
N = ndims(A)
if mode == ArrayToConcrete && T <: ReactantPrimitive
if mode == ArrayToConcrete && T <: Reactant.ReactantPrimitive
return ConcreteRArray{T,N}
else
return Array{traced_type_inner(T, seen, mode, track_numbers),N}
Expand Down Expand Up @@ -904,7 +901,7 @@ function make_tracer(
if mode != NoStopTracedTrack && haskey(seen, prev)
return seen[prev]
end
if mode == ArrayToConcrete && eltype(RT) <: ReactantPrimitive
if mode == ArrayToConcrete && eltype(RT) <: Reactant.ReactantPrimitive
return seen[prev] = ConcreteRArray(prev)
end
TT = traced_type(eltype(RT), Val(mode), track_numbers)
Expand Down
Loading