Skip to content

Commit

Permalink
refactor: move to a different api
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jan 22, 2025
1 parent 89ef696 commit 1a873d0
Show file tree
Hide file tree
Showing 7 changed files with 102 additions and 103 deletions.
4 changes: 2 additions & 2 deletions ext/ReactantCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -771,7 +771,7 @@ Base.@nospecializeinfer function Reactant.traced_type_inner(
)
T = eltype(A)
N = ndims(A)
if mode == Reactant.ArrayToConcrete && Reactant.is_reactant_primitive(T)
if mode == Reactant.ArrayToConcrete && T <: Reactant.ReactantPrimitive
return Reactant.ConcreteRArray{T,N}
else
TT = Reactant.traced_type_inner(T, seen, mode, track_numbers)
Expand All @@ -795,7 +795,7 @@ function Reactant.make_tracer(
if haskey(seen, prev)
return seen[prev]
end
if mode == Reactant.ArrayToConcrete && Reactant.is_reactant_primitive(eltype(RT))
if mode == Reactant.ArrayToConcrete && eltype(RT) <: Reactant.ReactantPrimitive
return seen[prev] = Reactant.ConcreteRArray(Array(prev))
end
TT = Reactant.traced_type(eltype(RT), Val(mode), track_numbers)
Expand Down
6 changes: 3 additions & 3 deletions src/Overlay.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ for randfun in (:rand, :randn, :randexp)
@reactant_overlay @noinline function Random.$(randfun)(
rng::AbstractRNG, ::Type{T}, dims::Dims
) where {T}
if is_reactant_primitive(T)
if T <: ReactantPrimitive
return TracedRandom.$(overload_randfun)(rng, T, dims)
end
@warn "Reactant doesn't support sampling of $(T) with the current \
Expand All @@ -70,7 +70,7 @@ for randfun in (:rand, :randn, :randexp)
@reactant_overlay @noinline function Random.$(randfun)(
rng::AbstractRNG, ::Type{T}, dim1::Integer, dims::Integer...
) where {T}
if is_reactant_primitive(T)
if T <: ReactantPrimitive
return TracedRandom.$(overload_randfun)(rng, T, dim1, dims...)
end
@warn "Reactant doesn't support sampling of $(T) with the current \
Expand All @@ -82,7 +82,7 @@ for randfun in (:rand, :randn, :randexp)
@reactant_overlay @noinline function Random.$(randfun)(
rng::AbstractRNG, ::Type{T}=Float64
) where {T}
if is_reactant_primitive(T)
if T <: ReactantPrimitive
return TracedRandom.$(overload_randfun)(rng, T)
end
@warn "Reactant doesn't support sampling of $(T) with the current \
Expand Down
74 changes: 74 additions & 0 deletions src/PrimitiveTypes.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# These only exist for the purpose of lowering. Since `ReactantPrimitive` is a fixed set of
# types, users can use these to convert their types to the primitive types supported by
# Reactant.
struct F8E5M2{T} <: AbstractFloat
val::T
end

struct F8E4M3FN{T} <: AbstractFloat
val::T
end

struct F8E4M3B11FNUZ{T} <: AbstractFloat
val::T
end

struct F8E5M2FNUZ{T} <: AbstractFloat
val::T
end

struct F8E4M3FNUZ{T} <: AbstractFloat
val::T
end

# TODO: Quantized types

const ReactantFloat8 = Union{F8E5M2,F8E4M3FN,F8E4M3B11FNUZ,F8E5M2FNUZ,F8E4M3FNUZ}

@static if isdefined(Core, :BFloat16)
const ReactantFloat = Union{
Float16,Core.BFloat16,Float32,Float64,Base.uniontypes(ReactantFloat8)...
}
else
const ReactantFloat = Union{Float16,Float32,Float64,Base.uniontypes(ReactantFloat8)...}
end

const ReactantComplexFloat = Union{[Complex{T} for T in Base.uniontypes(ReactantFloat)]...}

const ReactantInt = Union{Int8,UInt8,Int16,UInt16,Int32,UInt32,Int64,UInt64,Int128,UInt128}

const ReactantComplexInt = Union{
Complex{Int8},
Complex{UInt8},
Complex{Int16},
Complex{UInt16},
Complex{Int32},
Complex{UInt32},
Complex{Int64},
Complex{UInt64},
Complex{Int128},
Complex{UInt128},
}

const ReactantFloatInt = Union{
Base.uniontypes(ReactantInt)...,Base.uniontypes(ReactantFloat)...
}

const ReactantPrimitive = Union{
Bool,
Base.uniontypes(ReactantFloatInt)...,
Base.uniontypes(ReactantComplexInt)...,
Base.uniontypes(ReactantComplexFloat)...,
}

"""
to_reactant_primitive(val)
Constructs a Reactant primitive from the given value. Returns the Reactant primitive and a
function that can be used to convert the value back to the original type.
"""
to_reactant_primitive(::T) where {T} = nothing, nothing

for T in Base.uniontypes(ReactantPrimitive)
@eval to_reactant_primitive(val::$T) = val, identity
end
62 changes: 2 additions & 60 deletions src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,55 +18,9 @@ using Enzyme

struct ReactantABI <: Enzyme.EnzymeCore.ABI end

@static if isdefined(Core, :BFloat16)
const ReactantFloat = Union{Float16,Core.BFloat16,Float32,Float64}
else
const ReactantFloat = Union{Float16,Float32,Float64}
end

@static if isdefined(Core, :BFloat16)
const ReactantComplexFloat = Union{
Complex{Float16},Complex{Core.BFloat16},Complex{Float32},Complex{Float64}
}
else
const ReactantComplexFloat = Union{Complex{Float16},Complex{Float32},Complex{Float64}}
end

const ReactantInt = Union{Int8,UInt8,Int16,UInt16,Int32,UInt32,Int64,UInt64,Int128,UInt128}

const ReactantComplexInt = Union{
Complex{Int8},
Complex{UInt8},
Complex{Int16},
Complex{UInt16},
Complex{Int32},
Complex{UInt32},
Complex{Int64},
Complex{UInt64},
Complex{Int128},
Complex{UInt128},
}

const ReactantFloatInt = Union{
Base.uniontypes(ReactantInt)...,Base.uniontypes(ReactantFloat)...
}
include("PrimitiveTypes.jl")

const ReactantPrimitive = Union{
Bool,
Base.uniontypes(ReactantFloatInt)...,
Base.uniontypes(ReactantComplexInt)...,
Base.uniontypes(ReactantComplexFloat)...,
}

"""
is_reactant_primitive(::Type{T})
Returns `true` if `T` is a primitive type supported by Reactant.
"""
is_reactant_primitive(::Type{<:ReactantPrimitive}) = true
is_reactant_primitive(::Type) = false

abstract type RNumber{T} <: Number end
abstract type RNumber{T<:ReactantPrimitive} <: Number end

abstract type RArray{T,N} <: AbstractArray{T,N} end

Expand Down Expand Up @@ -105,7 +59,6 @@ mutable struct TracedRNumber{T} <: RNumber{T}
function TracedRNumber{T}(
paths::Tuple, mlir_data::Union{Nothing,MLIR.IR.Value}
) where {T}
@assert is_reactant_primitive(T) "$(T) is not a primitive type supported by Reactant."
if !isnothing(mlir_data)
@assert size(MLIR.IR.type(mlir_data)) == ()
end
Expand All @@ -123,7 +76,6 @@ mutable struct TracedRArray{T,N} <: RArray{TracedRNumber{T},N}
function TracedRArray{T,N}(
paths::Tuple, mlir_data::Union{Nothing,MLIR.IR.Value}, shape
) where {T,N}
@assert is_reactant_primitive(T) "$(T) is not a primitive type supported by Reactant."
shape = Tuple(shape)
if !isnothing(mlir_data)
@assert size(MLIR.IR.type(mlir_data)) == shape "Expected: $(shape), got: $(size(MLIR.IR.type(mlir_data)))"
Expand Down Expand Up @@ -167,23 +119,13 @@ Adapt.parent_type(::Type{XLAArray{T,N}}) where {T,N} = XLAArray{T,N}

mutable struct ConcreteRNumber{T} <: RNumber{T}
data::XLA.AsyncBuffer

function ConcreteRNumber{T}(data::XLA.AsyncBuffer) where {T}
@assert is_reactant_primitive(T) "$(T) is not a primitive type supported by Reactant."
return new{T}(data)
end
end

@leaf ConcreteRNumber

mutable struct ConcreteRArray{T,N} <: RArray{T,N}
data::XLA.AsyncBuffer
shape::NTuple{N,Int}

function ConcreteRArray{T,N}(data::XLA.AsyncBuffer, shape::NTuple{N,Int}) where {T,N}
@assert is_reactant_primitive(T) "$(T) is not a primitive type supported by Reactant."
return new{T,N}(data, shape)
end
end

@leaf ConcreteRArray
Expand Down
11 changes: 4 additions & 7 deletions src/TracedRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@ using ..Reactant:
ancestor,
allowscalar,
aos_to_soa,
unwrapped_eltype,
is_reactant_primitive
unwrapped_eltype
using ..TracedUtils: TracedUtils, get_mlir_data, set_mlir_data!, materialize_traced_array

using ReactantCore: ReactantCore
Expand Down Expand Up @@ -480,16 +479,14 @@ end

function Base.similar(
::Broadcasted{AbstractReactantArrayStyle{N}}, ::Type{T}, dims
) where {T,N}
@assert is_reactant_primitive(T) "$(T) is not a primitive type supported by Reactant."
) where {T<:Reactant.ReactantPrimitive,N}
@assert N isa Int
return TracedRArray{T,length(dims)}((), nothing, map(length, dims))
end

function Base.similar(
::Broadcasted{AbstractReactantArrayStyle{N}}, ::Type{TracedRNumber{T}}, dims
) where {T,N}
@assert is_reactant_primitive(T) "$(T) is not a primitive type supported by Reactant."
) where {T<:Reactant.ReactantPrimitive,N}
@assert N isa Int
return TracedRArray{T,length(dims)}((), nothing, map(length, dims))
end
Expand All @@ -509,7 +506,7 @@ end
# we need to override the outer copy method to make sure we never fall back to scalar
# iteration (see, e.g., CUDA.jl#145)
function Broadcast.copy(bc::Broadcasted{<:AbstractReactantArrayStyle})
fn = if bc.f isa Type && is_reactant_primitive(bc.f)
fn = if bc.f isa Type && bc.f <: Reactant.ReactantPrimitive
TracedUtils.TypeCast{bc.f}()
else
bc.f
Expand Down
14 changes: 3 additions & 11 deletions src/TracedUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@ using ..Reactant:
AnyTracedRArray,
MissingTracedValue,
OrderedIdDict,
Ops,
is_reactant_primitive
Ops
using ReactantCore: MissingTracedValue

materialize_traced_array(x::TracedRArray) = x
Expand Down Expand Up @@ -285,20 +284,13 @@ end

elem_apply(::Type{T}, x::TracedRArray{T}) where {T} = x

struct TypeCast{T} <: Function
function TypeCast{T}() where {T}
@assert is_reactant_primitive(T) "$(T) is not a primitive type supported by Reactant."
return new{T}()
end
end
struct TypeCast{T <: Reactant.ReactantPrimitive} <: Function end

function (::TypeCast{T})(x::TracedRNumber{T2}) where {T,T2}
return TracedUtils.promote_to(TracedRNumber{T}, x)
end

function elem_apply(::Type{T}, x::TracedRArray) where {T}
@assert is_reactant_primitive(T) "$(T) is not a primitive type supported by Reactant."
# Special Path to prevent going down a despecialized path
function elem_apply(::Type{T}, x::TracedRArray) where {T <: Reactant.ReactantPrimitive}
return elem_apply(TypeCast{T}(), x)
end

Expand Down
34 changes: 14 additions & 20 deletions src/Tracing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ Base.@nospecializeinfer function traced_type_inner(

subParms = []
for (i, SST) in enumerate(T.parameters)
if wrapped_carray && i == 1 && SST isa Type && is_reactant_primitive(SST)
if wrapped_carray && i == 1 && SST isa Type && SST <: ReactantPrimitive
TrT = traced_type_inner(ConcreteRNumber{SST}, seen, mode, track_numbers)
push!(subParms, TrT)
elseif wrapped_tracedarray && i == 1 && SST isa Type && SST <: TracedRNumber
Expand Down Expand Up @@ -161,17 +161,15 @@ for T in (
end

Base.@nospecializeinfer function traced_type_inner(
@nospecialize(T::Type{<:Number}),
@nospecialize(T::Type{<:ReactantPrimitive}),
seen,
@nospecialize(mode::TraceMode),
@nospecialize(track_numbers::Type)
)
if is_reactant_primitive(T)
if Mode == ArrayToConcrete && T <: track_numbers
return ConcreteRNumber{T}
elseif (mode == NoStopTracedTrack || mode == TracedTrack) && T <: track_numbers
return TracedRNumber{T}
end
if Mode == ArrayToConcrete && T <: track_numbers
return ConcreteRNumber{T}
elseif (mode == NoStopTracedTrack || mode == TracedTrack) && T <: track_numbers
return TracedRNumber{T}
end
return T
end
Expand Down Expand Up @@ -394,7 +392,7 @@ Base.@nospecializeinfer function traced_type_inner(
)
T = eltype(A)
N = ndims(A)
if mode == ArrayToConcrete && is_reactant_primitive(T)
if mode == ArrayToConcrete && T <: Reactant.ReactantPrimitive
return ConcreteRArray{T,N}
else
return Array{traced_type_inner(T, seen, mode, track_numbers),N}
Expand Down Expand Up @@ -903,7 +901,7 @@ function make_tracer(
if mode != NoStopTracedTrack && haskey(seen, prev)
return seen[prev]
end
if mode == ArrayToConcrete && is_reactant_primitive(eltype(RT))
if mode == ArrayToConcrete && eltype(RT) <: Reactant.ReactantPrimitive
return seen[prev] = ConcreteRArray(prev)
end
TT = traced_type(eltype(RT), Val(mode), track_numbers)
Expand Down Expand Up @@ -994,21 +992,17 @@ end
@nospecialize(x::ConcreteRArray), @nospecialize(track_numbers::Type)
) = x
@inline function to_rarray_internal(
@nospecialize(x::Array{T}), @nospecialize(track_numbers::Type)
) where {T}
is_reactant_primitive(T) && return ConcreteRArray(x)
return @invoke to_rarray_internal(x::Any, track_numbers::Type)
@nospecialize(x::Array{<:ReactantPrimitive}), @nospecialize(track_numbers::Type)
)
return ConcreteRArray(x)
end

@inline to_rarray_internal(
@nospecialize(x::ConcreteRNumber), @nospecialize(track_numbers::Type)
) = x
@inline function to_rarray_internal(
@nospecialize(x::Number), @nospecialize(track_numbers::Type)
@nospecialize(x::ReactantPrimitive), @nospecialize(track_numbers::Type)
)
if is_reactant_primitive(typeof(x))
typeof(x) <: track_numbers && return ConcreteRNumber(x)
return x
end
return @invoke to_rarray_internal(x::Any, track_numbers::Type)
typeof(x) <: track_numbers && return ConcreteRNumber(x)
return x
end

0 comments on commit 1a873d0

Please sign in to comment.