Skip to content

Commit

Permalink
feat: keep lazy indexing
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jan 14, 2025
1 parent 2172ff2 commit 17ea1d6
Showing 1 changed file with 23 additions and 19 deletions.
42 changes: 23 additions & 19 deletions src/TracedRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -777,11 +777,28 @@ end
function Base.partialsortperm(
x::AnyTracedRVector, k::Union{Integer,OrdinalRange}; kwargs...
)
return partialsortperm!(similar(x, Int), x, k; kwargs...)
idxs = overloaded_partialsortperm(x, k; kwargs...)
k isa Integer && return @allowscalar idxs[k]
return view(idxs, k)
end

function Base.partialsortperm!(
ix::AnyTracedRVector{Int},
x::AnyTracedRVector,
k::Union{Integer,OrdinalRange}; kwargs...
)
idxs = overloaded_partialsortperm(x, k; kwargs...)

if k isa Integer
@allowscalar setindex!(ix, idxs[k], k)
return idxs
else
setindex!(ix, idxs[k], k)
return view(ix, k)
end
end

function overloaded_partialsortperm(
x::AnyTracedRVector,
k::Union{Integer,OrdinalRange};
by=identity,
Expand All @@ -791,24 +808,11 @@ function Base.partialsortperm!(
# TODO: general `lt` support
@assert lt === isless "Only `isless` is supported for now in `partialsortperm!`"

by_x = by.(x)
# XXX: If `maxk` is beyond a threshold should we emit a sort directly?
if k isa Integer
!rev && (k = length(x) - k + 1)
(; values, indices) = Ops.top_k(materialize_traced_array(by_x), k)
indices = Ops.convert(TracedRArray{Int64,1}, indices)
idx = @allowscalar indices[k]
@allowscalar setindex!(ix, idx, k)
return idx
else
klist = collect(Int64, k)
!rev && (klist = length(x) .- klist .+ 1)
maxk = maximum(klist)
(; values, indices) = Ops.top_k(materialize_traced_array(by_x), maxk)
indices = Ops.convert(TracedRArray{Int64,1}, indices)
setindex!(ix, indices[klist], klist)
return indices[klist]
end
# XXX: If `maxk` is beyond a threshold should we emit a sort directly? Or do a neg
!rev && (k = length(x) .- k .+ 1)
!(k isa Integer) && (k = maximum(k))
(; indices) = Ops.top_k(materialize_traced_array(by.(x)), k)
return Ops.convert(TracedRArray{Int64,1}, indices)
end

function Base.argmin(x::AnyTracedRArray; kwargs...)
Expand Down

0 comments on commit 17ea1d6

Please sign in to comment.