Skip to content

Commit

Permalink
fix: missing scalar indexing check for setindex (#491)
Browse files Browse the repository at this point in the history
* fix: missing scalar indexing check for setindex

* fix: missing copyto!

* Update src/TracedRArray.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* fix: overload Enzyme.onehot to avoid scalar indexing

* fix: mark tests with allowscalar

* chore: formatting

---------

Co-authored-by: William Moses <[email protected]>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Jan 7, 2025
1 parent 12531c9 commit 46b8c14
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 3 deletions.
7 changes: 7 additions & 0 deletions src/Enzyme.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# TODO: move the overload_autodiff here as well

# The default `onehot` will lead to scalar indexing
function Enzyme.onehot(x::TracedRArray{T,N}) where {T,N}
x_arr = zeros(T, size(x))
return map(Base.Fix1(TracedUtils.promote_to, TracedRArray{T,N}), Enzyme.onehot(x_arr))
end
3 changes: 3 additions & 0 deletions src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,9 @@ end
include("stdlibs/LinearAlgebra.jl")
include("stdlibs/Random.jl")

# Other Integrations
include("Enzyme.jl")

const TracedType = Union{TracedRArray,TracedRNumber,MissingTracedValue}

include("ControlFlow.jl")
Expand Down
14 changes: 14 additions & 0 deletions src/TracedRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,17 @@ function Base.getindex(a::WrappedTracedRArray, indices...)
return getindex(ancestor(a), TracedUtils.get_ancestor_indices(a, indices...)...)
end

function maybe_assert_scalar_setindexing(
::TracedRArray{T,N}, ::Vararg{Union{Int,TracedRNumber{Int}},N}
) where {T,N}
return GPUArraysCore.assertscalar("setindex!(::TracedRArray, v, ::Vararg{Int, N})")
end

maybe_assert_scalar_setindexing(args...) = nothing

function Base.setindex!(a::TracedRArray{T,N}, v, indices::Vararg{Any,N}) where {T,N}
maybe_assert_scalar_setindexing(a, indices...)

indices = map(enumerate(indices)) do (idx, i)
i isa Colon && return 1:size(a, idx)
i isa CartesianIndex && return Tuple(i)
Expand Down Expand Up @@ -473,6 +483,10 @@ function Base.copyto!(dest::TracedRArray{T,N}, src::TracedRArray{T,N}) where {T,
return dest
end

function Base.copyto!(dest::TracedRArray{T,N}, src::TracedRArray{T2,N}) where {T,T2,N}
return copyto!(dest, Ops.convert(TracedRArray{T,N}, src))
end

function _copyto!(dest::AnyTracedRArray, bc::Broadcasted)
axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc))
isempty(dest) && return dest
Expand Down
6 changes: 3 additions & 3 deletions test/control_flow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ function condition10_condition_with_setindex(x)
@trace if sum(x) > 0
x[:, 1] = -1.0
else
x[1, 1] = 1.0
@allowscalar x[1, 1] = 1.0
end
return x
end
Expand Down Expand Up @@ -457,7 +457,7 @@ end

function for_with_step(x)
@trace for i in 10:3:22
x[i] = i * i
@allowscalar x[i] = i * i
end
return x
end
Expand Down Expand Up @@ -539,7 +539,7 @@ function cumsum!(x)
v = zero(eltype(x))
@trace for i in 1:length(x)
v += @allowscalar x[i]
x[i] = v
@allowscalar x[i] = v
end
return x
end
Expand Down

0 comments on commit 46b8c14

Please sign in to comment.