Skip to content

Commit

Permalink
feat: implement argmin and argmax
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jan 14, 2025
1 parent 5b23ead commit a618257
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 4 deletions.
2 changes: 2 additions & 0 deletions src/ConcreteRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ function Base.rtoldefault(::Type{ConcreteRNumber{T}}) where {T}
return ConcreteRNumber(Base.rtoldefault(T))
end

Base.strides(x::ConcreteRArray) = Base.size_to_strides(1, size(x)...)

# Ensure the device and client are the same as the input
function Base.float(x::ConcreteRNumber{T}) where {T}
client = XLA.client(x.data)
Expand Down
6 changes: 5 additions & 1 deletion src/Ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1024,9 +1024,13 @@ end
values = mlir_type(TracedRArray{T,N}, rsize)
indices = mlir_type(TracedRArray{Int32,N}, rsize)
op = chlo.top_k(x.mlir_data; values, indices, k, location)
indices = add(
TracedRArray{Int32,N}((), MLIR.IR.result(op, 2), rsize),
constant(fill(Int32(1), Tuple(rsize))),
) # return the 1-indexed index
return (;
values=TracedRArray{T,N}((), MLIR.IR.result(op, 1), rsize),
indices=TracedRArray{Int32,N}((), MLIR.IR.result(op, 2), rsize),
indices,
)
end

Expand Down
49 changes: 46 additions & 3 deletions src/TracedRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ using GPUArraysCore: GPUArraysCore, @allowscalar

ReactantCore.is_traced(::TracedRArray) = true

Base.strides(x::TracedRArray) = Base.size_to_strides(1, size(x)...)

function Base.convert(::Type{TracedRArray{T,N}}, x::AbstractArray) where {T,N}
@assert ndims(x) == N
if x isa TracedRArray
Expand Down Expand Up @@ -510,7 +512,10 @@ function _copyto!(dest::AnyTracedRArray, bc::Broadcasted)

args = (TracedUtils.broadcast_to_size(Base.materialize(a), size(bc)) for a in bc.args)

res = TracedUtils.elem_apply(bc.f, args...)
res = TracedUtils.promote_to(
TracedRArray{unwrapped_eltype(dest),ndims(dest)},
TracedUtils.elem_apply(bc.f, args...),
)
TracedUtils.set_mlir_data!(dest, res.mlir_data)
return dest
end
Expand Down Expand Up @@ -743,15 +748,53 @@ function Base.partialsort!(
!rev && (k = length(x) - k + 1)
(; values, indices) = Ops.top_k(materialize_traced_array(by_x), k)
by === identity && return @allowscalar values[k]
return @allowscalar x[indices[k] + 1]
return @allowscalar x[indices[k]]
else
klist = collect(Int64, k)
!rev && (klist = length(x) .- klist .+ 1)
maxk = maximum(klist)
(; values, indices) = Ops.top_k(materialize_traced_array(by_x), maxk)
by === identity && return values[klist]
return x[indices[klist] .+ 1]
return x[indices[klist]]
end
end

function Base.argmin(x::AnyTracedRArray; kwargs...)
return argmax(Ops.negate(materialize_traced_array(x)); kwargs...)
end

function Base.argmax(x::AnyTracedRVector)
(; indices) = Ops.top_k(materialize_traced_array(x), 1)
return @allowscalar indices[1]
end

# To avoid scalar indexing and constructing an array of tuples, we return the linear index
# instead of the cartesian index
function Base.argmax(x::AnyTracedRArray{T,N}; dims::Integer) where {T,N}
strds = strides(x)

if dims != N # chlo.top_k performs the operation along the last dimension
pdims = collect(Int64, 1:N)
pdims[dims] = N
pdims[N] = dims
pdims = Tuple(pdims)
x = permutedims(x, pdims)
end
(; indices) = Ops.top_k(materialize_traced_array(x), 1)
indices = Ops.convert(TracedRArray{Int64,N}, indices)
dims != N && (indices = permutedims(indices, invperm(pdims)))

# Compute linear indices
iotas = [Ops.iota(Int64, [size(indices)...]; iota_dimension=i) for i in 1:N]
iotas[dims] = Ops.subtract(indices, Ops.constant(fill(Int64(1), size(indices))))
linear_indices = Ops.constant(fill(Int64(1), size(indices)))
for d in 1:N
linear_indices = Ops.add(
linear_indices,
Ops.multiply(iotas[d], Ops.constant(fill(Int64(strds[d]), size(iotas[d])))),
)
end
return linear_indices
end

end

0 comments on commit a618257

Please sign in to comment.