From 2172ff25c8b17bc51ee3e852fb384f7f9fd858a5 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 14 Jan 2025 15:01:17 -0500 Subject: [PATCH] fix: general support for other kwargs --- src/Ops.jl | 5 +--- src/TracedRArray.jl | 68 +++++++++++++++++++++++++++++++++++++++------ 2 files changed, 61 insertions(+), 12 deletions(-) diff --git a/src/Ops.jl b/src/Ops.jl index 287350179..06108b15f 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -1028,10 +1028,7 @@ end TracedRArray{Int32,N}((), MLIR.IR.result(op, 2), rsize), constant(fill(Int32(1), Tuple(rsize))), ) # return the 1-indexed index - return (; - values=TracedRArray{T,N}((), MLIR.IR.result(op, 1), rsize), - indices, - ) + return (; values=TracedRArray{T,N}((), MLIR.IR.result(op, 1), rsize), indices) end @noinline function iota( diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 478cbde43..cd224c228 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -694,7 +694,9 @@ function overloaded_stack(dims::Union{Integer,Colon}, xs) end # sort -Base.sort(x::AnyTracedRArray; kwargs...) = sort!(copy(x); kwargs...) +function Base.sort(x::AnyTracedRArray; alg=missing, order=missing, kwargs...) + return sort!(copy(x); alg, order, kwargs...) +end function Base.sort!( x::AnyTracedRArray; @@ -702,15 +704,21 @@ function Base.sort!( lt=isless, by=identity, rev::Bool=false, - kwargs..., # TODO: implement `order` and `alg` kwargs + alg=missing, + order=missing, ) + @assert alg === missing "Reactant doesn't support `alg` kwarg for `sort!`" + @assert order === missing "Reactant doesn't support `order` kwarg for `sort!`" + comparator = rev ? (a, b) -> !lt(by(a), by(b)) : (a, b) -> lt(by(a), by(b)) res = Ops.sort(materialize_traced_array(x); dimension=dims, comparator) set_mlir_data!(x, get_mlir_data(res)) return x end -Base.sortperm(x::AnyTracedRArray; kwargs...) = sortperm!(similar(x, Int), x; kwargs...) +function Base.sortperm(x::AnyTracedRArray; alg=missing, order=missing, kwargs...) + return sortperm!(similar(x, Int), x; alg, order, kwargs...) +end function Base.sortperm!( ix::AnyTracedRArray{Int,N}, @@ -719,8 +727,12 @@ function Base.sortperm!( lt=isless, by=identity, rev::Bool=false, - kwargs..., # TODO: implement `order` and `alg` kwargs + alg=missing, + order=missing, ) where {N} + @assert alg === missing "Reactant doesn't support `alg` kwarg for `sortperm!`" + @assert order === missing "Reactant doesn't support `order` kwarg for `sortperm!`" + comparator = rev ? (a, b, i1, i2) -> !lt(by(a), by(b)) : (a, b, i1, i2) -> lt(by(a), by(b)) idxs = Ops.constant(collect(LinearIndices(x))) @@ -743,19 +755,59 @@ function Base.partialsort!( # TODO: general `lt` support @assert lt === isless "Only `isless` is supported for now in `partialsort!`" + # XXX: If `maxk` is beyond a threshold should we emit a sort directly? + by_x = by.(x) + if k isa Integer + !rev && (k = length(x) - k + 1) + (; values, indices) = Ops.top_k(materialize_traced_array(by_x), k) + res = by === identity ? @allowscalar(values[k]) : @allowscalar(x[indices[k]]) + @allowscalar setindex!(ix, res, k) + return res + else + klist = collect(Int64, k) + !rev && (klist = length(x) .- klist .+ 1) + maxk = maximum(klist) + (; values, indices) = Ops.top_k(materialize_traced_array(by_x), maxk) + res = by === identity ? values[klist] : x[indices[klist]] + setindex!(ix, res, klist) + return res + end +end + +function Base.partialsortperm( + x::AnyTracedRVector, k::Union{Integer,OrdinalRange}; kwargs... +) + return partialsortperm!(similar(x, Int), x, k; kwargs...) +end + +function Base.partialsortperm!( + ix::AnyTracedRVector{Int}, + x::AnyTracedRVector, + k::Union{Integer,OrdinalRange}; + by=identity, + rev::Bool=false, + lt=isless, +) + # TODO: general `lt` support + @assert lt === isless "Only `isless` is supported for now in `partialsortperm!`" + by_x = by.(x) + # XXX: If `maxk` is beyond a threshold should we emit a sort directly? if k isa Integer !rev && (k = length(x) - k + 1) (; values, indices) = Ops.top_k(materialize_traced_array(by_x), k) - by === identity && return @allowscalar values[k] - return @allowscalar x[indices[k]] + indices = Ops.convert(TracedRArray{Int64,1}, indices) + idx = @allowscalar indices[k] + @allowscalar setindex!(ix, idx, k) + return idx else klist = collect(Int64, k) !rev && (klist = length(x) .- klist .+ 1) maxk = maximum(klist) (; values, indices) = Ops.top_k(materialize_traced_array(by_x), maxk) - by === identity && return values[klist] - return x[indices[klist]] + indices = Ops.convert(TracedRArray{Int64,1}, indices) + setindex!(ix, indices[klist], klist) + return indices[klist] end end