From 17ea1d6dfb293e397be02894c9b50d54e1817a09 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 14 Jan 2025 15:16:49 -0500 Subject: [PATCH] feat: keep lazy indexing --- src/TracedRArray.jl | 42 +++++++++++++++++++++++------------------- 1 file changed, 23 insertions(+), 19 deletions(-) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index cd224c228..326d6acc3 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -777,11 +777,28 @@ end function Base.partialsortperm( x::AnyTracedRVector, k::Union{Integer,OrdinalRange}; kwargs... ) - return partialsortperm!(similar(x, Int), x, k; kwargs...) + idxs = overloaded_partialsortperm(x, k; kwargs...) + k isa Integer && return @allowscalar idxs[k] + return view(idxs, k) end function Base.partialsortperm!( ix::AnyTracedRVector{Int}, + x::AnyTracedRVector, + k::Union{Integer,OrdinalRange}; kwargs... +) + idxs = overloaded_partialsortperm(x, k; kwargs...) + + if k isa Integer + @allowscalar setindex!(ix, idxs[k], k) + return idxs + else + setindex!(ix, idxs[k], k) + return view(ix, k) + end +end + +function overloaded_partialsortperm( x::AnyTracedRVector, k::Union{Integer,OrdinalRange}; by=identity, @@ -791,24 +808,11 @@ function Base.partialsortperm!( # TODO: general `lt` support @assert lt === isless "Only `isless` is supported for now in `partialsortperm!`" - by_x = by.(x) - # XXX: If `maxk` is beyond a threshold should we emit a sort directly? - if k isa Integer - !rev && (k = length(x) - k + 1) - (; values, indices) = Ops.top_k(materialize_traced_array(by_x), k) - indices = Ops.convert(TracedRArray{Int64,1}, indices) - idx = @allowscalar indices[k] - @allowscalar setindex!(ix, idx, k) - return idx - else - klist = collect(Int64, k) - !rev && (klist = length(x) .- klist .+ 1) - maxk = maximum(klist) - (; values, indices) = Ops.top_k(materialize_traced_array(by_x), maxk) - indices = Ops.convert(TracedRArray{Int64,1}, indices) - setindex!(ix, indices[klist], klist) - return indices[klist] - end + # XXX: If `maxk` is beyond a threshold should we emit a sort directly? Or do a neg + !rev && (k = length(x) .- k .+ 1) + !(k isa Integer) && (k = maximum(k)) + (; indices) = Ops.top_k(materialize_traced_array(by.(x)), k) + return Ops.convert(TracedRArray{Int64,1}, indices) end function Base.argmin(x::AnyTracedRArray; kwargs...)