Skip to content

Commit

Permalink
feat: introduce is_reactant_primitive to define additional primitives
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jan 22, 2025
1 parent 1c42a58 commit 89ef696
Show file tree
Hide file tree
Showing 7 changed files with 66 additions and 42 deletions.
4 changes: 2 additions & 2 deletions ext/ReactantCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -771,7 +771,7 @@ Base.@nospecializeinfer function Reactant.traced_type_inner(
)
T = eltype(A)
N = ndims(A)
if mode == Reactant.ArrayToConcrete && T <: Reactant.ReactantPrimitive
if mode == Reactant.ArrayToConcrete && Reactant.is_reactant_primitive(T)
return Reactant.ConcreteRArray{T,N}
else
TT = Reactant.traced_type_inner(T, seen, mode, track_numbers)
Expand All @@ -795,7 +795,7 @@ function Reactant.make_tracer(
if haskey(seen, prev)
return seen[prev]
end
if mode == Reactant.ArrayToConcrete && eltype(RT) <: Reactant.ReactantPrimitive
if mode == Reactant.ArrayToConcrete && Reactant.is_reactant_primitive(eltype(RT))
return seen[prev] = Reactant.ConcreteRArray(Array(prev))
end
TT = Reactant.traced_type(eltype(RT), Val(mode), track_numbers)
Expand Down
6 changes: 3 additions & 3 deletions src/Overlay.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ for randfun in (:rand, :randn, :randexp)
@reactant_overlay @noinline function Random.$(randfun)(
rng::AbstractRNG, ::Type{T}, dims::Dims
) where {T}
if T <: ReactantPrimitive
if is_reactant_primitive(T)
return TracedRandom.$(overload_randfun)(rng, T, dims)
end
@warn "Reactant doesn't support sampling of $(T) with the current \
Expand All @@ -70,7 +70,7 @@ for randfun in (:rand, :randn, :randexp)
@reactant_overlay @noinline function Random.$(randfun)(
rng::AbstractRNG, ::Type{T}, dim1::Integer, dims::Integer...
) where {T}
if T <: ReactantPrimitive
if is_reactant_primitive(T)
return TracedRandom.$(overload_randfun)(rng, T, dim1, dims...)
end
@warn "Reactant doesn't support sampling of $(T) with the current \
Expand All @@ -82,7 +82,7 @@ for randfun in (:rand, :randn, :randexp)
@reactant_overlay @noinline function Random.$(randfun)(
rng::AbstractRNG, ::Type{T}=Float64
) where {T}
if T <: ReactantPrimitive
if is_reactant_primitive(T)
return TracedRandom.$(overload_randfun)(rng, T)
end
@warn "Reactant doesn't support sampling of $(T) with the current \
Expand Down
22 changes: 21 additions & 1 deletion src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,15 @@ const ReactantPrimitive = Union{
Base.uniontypes(ReactantComplexFloat)...,
}

abstract type RNumber{T<:ReactantPrimitive} <: Number end
"""
is_reactant_primitive(::Type{T})
Returns `true` if `T` is a primitive type supported by Reactant.
"""
is_reactant_primitive(::Type{<:ReactantPrimitive}) = true
is_reactant_primitive(::Type) = false

abstract type RNumber{T} <: Number end

abstract type RArray{T,N} <: AbstractArray{T,N} end

Expand Down Expand Up @@ -97,6 +105,7 @@ mutable struct TracedRNumber{T} <: RNumber{T}
function TracedRNumber{T}(
paths::Tuple, mlir_data::Union{Nothing,MLIR.IR.Value}
) where {T}
@assert is_reactant_primitive(T) "$(T) is not a primitive type supported by Reactant."
if !isnothing(mlir_data)
@assert size(MLIR.IR.type(mlir_data)) == ()
end
Expand All @@ -114,6 +123,7 @@ mutable struct TracedRArray{T,N} <: RArray{TracedRNumber{T},N}
function TracedRArray{T,N}(
paths::Tuple, mlir_data::Union{Nothing,MLIR.IR.Value}, shape
) where {T,N}
@assert is_reactant_primitive(T) "$(T) is not a primitive type supported by Reactant."
shape = Tuple(shape)
if !isnothing(mlir_data)
@assert size(MLIR.IR.type(mlir_data)) == shape "Expected: $(shape), got: $(size(MLIR.IR.type(mlir_data)))"
Expand Down Expand Up @@ -157,13 +167,23 @@ Adapt.parent_type(::Type{XLAArray{T,N}}) where {T,N} = XLAArray{T,N}

mutable struct ConcreteRNumber{T} <: RNumber{T}
data::XLA.AsyncBuffer

function ConcreteRNumber{T}(data::XLA.AsyncBuffer) where {T}
@assert is_reactant_primitive(T) "$(T) is not a primitive type supported by Reactant."
return new{T}(data)
end
end

@leaf ConcreteRNumber

mutable struct ConcreteRArray{T,N} <: RArray{T,N}
data::XLA.AsyncBuffer
shape::NTuple{N,Int}

function ConcreteRArray{T,N}(data::XLA.AsyncBuffer, shape::NTuple{N,Int}) where {T,N}
@assert is_reactant_primitive(T) "$(T) is not a primitive type supported by Reactant."
return new{T,N}(data, shape)
end
end

@leaf ConcreteRArray
Expand Down
12 changes: 7 additions & 5 deletions src/TracedRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ using ..Reactant:
Reactant,
TracedRArray,
TracedRNumber,
ReactantPrimitive,
WrappedTracedRArray,
AnyTracedRArray,
AnyTracedRVector,
Expand All @@ -16,7 +15,8 @@ using ..Reactant:
ancestor,
allowscalar,
aos_to_soa,
unwrapped_eltype
unwrapped_eltype,
is_reactant_primitive
using ..TracedUtils: TracedUtils, get_mlir_data, set_mlir_data!, materialize_traced_array

using ReactantCore: ReactantCore
Expand Down Expand Up @@ -480,14 +480,16 @@ end

function Base.similar(
::Broadcasted{AbstractReactantArrayStyle{N}}, ::Type{T}, dims
) where {T<:ReactantPrimitive,N}
) where {T,N}
@assert is_reactant_primitive(T) "$(T) is not a primitive type supported by Reactant."
@assert N isa Int
return TracedRArray{T,length(dims)}((), nothing, map(length, dims))
end

function Base.similar(
::Broadcasted{AbstractReactantArrayStyle{N}}, ::Type{TracedRNumber{T}}, dims
) where {T<:ReactantPrimitive,N}
) where {T,N}
@assert is_reactant_primitive(T) "$(T) is not a primitive type supported by Reactant."
@assert N isa Int
return TracedRArray{T,length(dims)}((), nothing, map(length, dims))
end
Expand All @@ -507,7 +509,7 @@ end
# we need to override the outer copy method to make sure we never fall back to scalar
# iteration (see, e.g., CUDA.jl#145)
function Broadcast.copy(bc::Broadcasted{<:AbstractReactantArrayStyle})
fn = if bc.f isa Type && bc.f <: ReactantPrimitive
fn = if bc.f isa Type && is_reactant_primitive(bc.f)
TracedUtils.TypeCast{bc.f}()
else
bc.f
Expand Down
9 changes: 1 addition & 8 deletions src/TracedRNumber.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,7 @@
module TracedRNumberOverrides

using ..Reactant:
Reactant,
TracedRNumber,
TracedRArray,
ReactantPrimitive,
TracedUtils,
Ops,
MLIR,
unwrapped_eltype
Reactant, TracedRNumber, TracedRArray, TracedUtils, Ops, MLIR, unwrapped_eltype
using ReactantCore

ReactantCore.is_traced(::TracedRNumber) = true
Expand Down
16 changes: 11 additions & 5 deletions src/TracedUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ using ..Reactant:
AnyTracedRArray,
MissingTracedValue,
OrderedIdDict,
ReactantPrimitive,
Ops
Ops,
is_reactant_primitive
using ReactantCore: MissingTracedValue

materialize_traced_array(x::TracedRArray) = x
Expand Down Expand Up @@ -283,15 +283,21 @@ function __take_region(compiled_fn)
return region
end

elem_apply(::Type{T}, x::TracedRArray{T}) where {T<:ReactantPrimitive} = x
elem_apply(::Type{T}, x::TracedRArray{T}) where {T} = x

struct TypeCast{T<:ReactantPrimitive} <: Function end
struct TypeCast{T} <: Function
function TypeCast{T}() where {T}
@assert is_reactant_primitive(T) "$(T) is not a primitive type supported by Reactant."
return new{T}()
end
end

function (::TypeCast{T})(x::TracedRNumber{T2}) where {T,T2}
return TracedUtils.promote_to(TracedRNumber{T}, x)
end

function elem_apply(::Type{T}, x::TracedRArray) where {T<:ReactantPrimitive}
function elem_apply(::Type{T}, x::TracedRArray) where {T}
@assert is_reactant_primitive(T) "$(T) is not a primitive type supported by Reactant."
# Special Path to prevent going down a despecialized path
return elem_apply(TypeCast{T}(), x)
end
Expand Down
39 changes: 21 additions & 18 deletions src/Tracing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,10 @@ Base.@nospecializeinfer function traced_type_inner(

subParms = []
for (i, SST) in enumerate(T.parameters)
if wrapped_carray && i == 1 && SST isa Type && SST <: ReactantPrimitive
if wrapped_carray && i == 1 && SST isa Type && is_reactant_primitive(SST)
TrT = traced_type_inner(ConcreteRNumber{SST}, seen, mode, track_numbers)
push!(subParms, TrT)
elseif wrapped_tracedarray &&
i == 1 &&
SST isa Type &&
SST <: TracedRNumber{<:ReactantPrimitive}
elseif wrapped_tracedarray && i == 1 && SST isa Type && SST <: TracedRNumber
TrT = traced_type_inner(unwrapped_eltype(SST), seen, mode, track_numbers)
push!(subParms, TrT)
else
Expand Down Expand Up @@ -164,15 +161,17 @@ for T in (
end

Base.@nospecializeinfer function traced_type_inner(
@nospecialize(T::Type{<:ReactantPrimitive}),
@nospecialize(T::Type{<:Number}),
seen,
@nospecialize(mode::TraceMode),
@nospecialize(track_numbers::Type)
)
if Mode == ArrayToConcrete && T <: track_numbers
return ConcreteRNumber{T}
elseif (mode == NoStopTracedTrack || mode == TracedTrack) && T <: track_numbers
return TracedRNumber{T}
if is_reactant_primitive(T)
if Mode == ArrayToConcrete && T <: track_numbers
return ConcreteRNumber{T}
elseif (mode == NoStopTracedTrack || mode == TracedTrack) && T <: track_numbers
return TracedRNumber{T}
end
end
return T
end
Expand Down Expand Up @@ -395,7 +394,7 @@ Base.@nospecializeinfer function traced_type_inner(
)
T = eltype(A)
N = ndims(A)
if mode == ArrayToConcrete && T <: ReactantPrimitive
if mode == ArrayToConcrete && is_reactant_primitive(T)
return ConcreteRArray{T,N}
else
return Array{traced_type_inner(T, seen, mode, track_numbers),N}
Expand Down Expand Up @@ -904,7 +903,7 @@ function make_tracer(
if mode != NoStopTracedTrack && haskey(seen, prev)
return seen[prev]
end
if mode == ArrayToConcrete && eltype(RT) <: ReactantPrimitive
if mode == ArrayToConcrete && is_reactant_primitive(eltype(RT))
return seen[prev] = ConcreteRArray(prev)
end
TT = traced_type(eltype(RT), Val(mode), track_numbers)
Expand Down Expand Up @@ -995,17 +994,21 @@ end
@nospecialize(x::ConcreteRArray), @nospecialize(track_numbers::Type)
) = x
@inline function to_rarray_internal(
@nospecialize(x::Array{<:ReactantPrimitive}), @nospecialize(track_numbers::Type)
)
return ConcreteRArray(x)
@nospecialize(x::Array{T}), @nospecialize(track_numbers::Type)
) where {T}
is_reactant_primitive(T) && return ConcreteRArray(x)
return @invoke to_rarray_internal(x::Any, track_numbers::Type)
end

@inline to_rarray_internal(
@nospecialize(x::ConcreteRNumber), @nospecialize(track_numbers::Type)
) = x
@inline function to_rarray_internal(
@nospecialize(x::ReactantPrimitive), @nospecialize(track_numbers::Type)
@nospecialize(x::Number), @nospecialize(track_numbers::Type)
)
typeof(x) <: track_numbers && return ConcreteRNumber(x)
return x
if is_reactant_primitive(typeof(x))
typeof(x) <: track_numbers && return ConcreteRNumber(x)
return x
end
return @invoke to_rarray_internal(x::Any, track_numbers::Type)
end

0 comments on commit 89ef696

Please sign in to comment.