Skip to content

Commit

Permalink
fix: general support for other kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jan 14, 2025
1 parent a618257 commit 2172ff2
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 12 deletions.
5 changes: 1 addition & 4 deletions src/Ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1028,10 +1028,7 @@ end
TracedRArray{Int32,N}((), MLIR.IR.result(op, 2), rsize),
constant(fill(Int32(1), Tuple(rsize))),
) # return the 1-indexed index
return (;
values=TracedRArray{T,N}((), MLIR.IR.result(op, 1), rsize),
indices,
)
return (; values=TracedRArray{T,N}((), MLIR.IR.result(op, 1), rsize), indices)
end

@noinline function iota(
Expand Down
68 changes: 60 additions & 8 deletions src/TracedRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -694,23 +694,31 @@ function overloaded_stack(dims::Union{Integer,Colon}, xs)
end

# sort
Base.sort(x::AnyTracedRArray; kwargs...) = sort!(copy(x); kwargs...)
function Base.sort(x::AnyTracedRArray; alg=missing, order=missing, kwargs...)
return sort!(copy(x); alg, order, kwargs...)
end

function Base.sort!(
x::AnyTracedRArray;
dims::Integer,
lt=isless,
by=identity,
rev::Bool=false,
kwargs..., # TODO: implement `order` and `alg` kwargs
alg=missing,
order=missing,
)
@assert alg === missing "Reactant doesn't support `alg` kwarg for `sort!`"
@assert order === missing "Reactant doesn't support `order` kwarg for `sort!`"

comparator = rev ? (a, b) -> !lt(by(a), by(b)) : (a, b) -> lt(by(a), by(b))
res = Ops.sort(materialize_traced_array(x); dimension=dims, comparator)
set_mlir_data!(x, get_mlir_data(res))
return x
end

Base.sortperm(x::AnyTracedRArray; kwargs...) = sortperm!(similar(x, Int), x; kwargs...)
function Base.sortperm(x::AnyTracedRArray; alg=missing, order=missing, kwargs...)
return sortperm!(similar(x, Int), x; alg, order, kwargs...)
end

function Base.sortperm!(
ix::AnyTracedRArray{Int,N},
Expand All @@ -719,8 +727,12 @@ function Base.sortperm!(
lt=isless,
by=identity,
rev::Bool=false,
kwargs..., # TODO: implement `order` and `alg` kwargs
alg=missing,
order=missing,
) where {N}
@assert alg === missing "Reactant doesn't support `alg` kwarg for `sortperm!`"
@assert order === missing "Reactant doesn't support `order` kwarg for `sortperm!`"

comparator =
rev ? (a, b, i1, i2) -> !lt(by(a), by(b)) : (a, b, i1, i2) -> lt(by(a), by(b))
idxs = Ops.constant(collect(LinearIndices(x)))
Expand All @@ -743,19 +755,59 @@ function Base.partialsort!(
# TODO: general `lt` support
@assert lt === isless "Only `isless` is supported for now in `partialsort!`"

# XXX: If `maxk` is beyond a threshold should we emit a sort directly?
by_x = by.(x)
if k isa Integer
!rev && (k = length(x) - k + 1)
(; values, indices) = Ops.top_k(materialize_traced_array(by_x), k)
res = by === identity ? @allowscalar(values[k]) : @allowscalar(x[indices[k]])
@allowscalar setindex!(ix, res, k)
return res
else
klist = collect(Int64, k)
!rev && (klist = length(x) .- klist .+ 1)
maxk = maximum(klist)
(; values, indices) = Ops.top_k(materialize_traced_array(by_x), maxk)
res = by === identity ? values[klist] : x[indices[klist]]
setindex!(ix, res, klist)
return res
end
end

function Base.partialsortperm(
x::AnyTracedRVector, k::Union{Integer,OrdinalRange}; kwargs...
)
return partialsortperm!(similar(x, Int), x, k; kwargs...)
end

function Base.partialsortperm!(
ix::AnyTracedRVector{Int},
x::AnyTracedRVector,
k::Union{Integer,OrdinalRange};
by=identity,
rev::Bool=false,
lt=isless,
)
# TODO: general `lt` support
@assert lt === isless "Only `isless` is supported for now in `partialsortperm!`"

by_x = by.(x)
# XXX: If `maxk` is beyond a threshold should we emit a sort directly?
if k isa Integer
!rev && (k = length(x) - k + 1)
(; values, indices) = Ops.top_k(materialize_traced_array(by_x), k)
by === identity && return @allowscalar values[k]
return @allowscalar x[indices[k]]
indices = Ops.convert(TracedRArray{Int64,1}, indices)
idx = @allowscalar indices[k]
@allowscalar setindex!(ix, idx, k)
return idx
else
klist = collect(Int64, k)
!rev && (klist = length(x) .- klist .+ 1)
maxk = maximum(klist)
(; values, indices) = Ops.top_k(materialize_traced_array(by_x), maxk)
by === identity && return values[klist]
return x[indices[klist]]
indices = Ops.convert(TracedRArray{Int64,1}, indices)
setindex!(ix, indices[klist], klist)
return indices[klist]
end
end

Expand Down

0 comments on commit 2172ff2

Please sign in to comment.