Skip to content

Commit

Permalink
test: sort and partial sort functions
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jan 15, 2025
1 parent c98e2c3 commit 0313c14
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 24 deletions.
58 changes: 38 additions & 20 deletions src/TracedRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -714,13 +714,18 @@ end

function Base.sort!(
x::AnyTracedRArray;
dims::Integer,
dims::Union{Integer,Nothing}=nothing,
lt=isless,
by=identity,
rev::Bool=false,
alg=missing,
order=missing,
)
if dims === nothing
@assert ndims(x) == 1
dims = 1
end

@assert alg === missing "Reactant doesn't support `alg` kwarg for `sort!`"
@assert order === missing "Reactant doesn't support `order` kwarg for `sort!`"

Expand All @@ -740,13 +745,18 @@ end
function Base.sortperm!(
ix::AnyTracedRArray{Int,N},
x::AnyTracedRArray{<:Any,N};
dims::Integer,
dims::Union{Integer,Nothing}=nothing,
lt=isless,
by=identity,
rev::Bool=false,
alg=missing,
order=missing,
) where {N}
if dims === nothing
@assert ndims(x) == 1
dims = 1
end

@assert alg === missing "Reactant doesn't support `alg` kwarg for `sortperm!`"
@assert order === missing "Reactant doesn't support `order` kwarg for `sortperm!`"

Expand All @@ -761,6 +771,7 @@ end
function Base.partialsort(x::AnyTracedRVector, k::Union{Integer,OrdinalRange}; kwargs...)
values, _ = overloaded_partialsort(x, k; kwargs...)
k = k .- minimum(k) .+ 1
k isa Integer && return @allowscalar(values[k])
return view(values, k)
end

Expand All @@ -769,7 +780,31 @@ function Base.partialsort!(x::AnyTracedRVector, k::Union{Integer,OrdinalRange};
kget = k .- minimum(k) .+ 1
val = @allowscalar(values[kget])
@allowscalar setindex!(x, val, k)
return val
k isa Integer && return val
return view(x, k)
end

function Base.partialsortperm(
x::AnyTracedRVector, k::Union{Integer,OrdinalRange}; kwargs...
)
idxs = overloaded_partialsort(x, k; kwargs...)[2]
k = k .- minimum(k) .+ 1
k isa Integer && return @allowscalar(idxs[k])
return view(idxs, k)
end

function Base.partialsortperm!(
ix::AnyTracedRVector{Int},
x::AnyTracedRVector,
k::Union{Integer,OrdinalRange};
kwargs...,
)
_, idxs = overloaded_partialsort(x, k; kwargs...)
kget = k .- minimum(k) .+ 1
val = @allowscalar(idxs[kget])
@allowscalar setindex!(ix, val, k)
k isa Integer && return val
return view(ix, k)
end

function overloaded_partialsort(
Expand Down Expand Up @@ -800,23 +835,6 @@ function overloaded_partialsort(
return values, indices
end

function Base.partialsortperm(
x::AnyTracedRVector, k::Union{Integer,OrdinalRange}; kwargs...
)
return view(overloaded_partialsort(x, k; kwargs...)[2], k)
end

function Base.partialsortperm!(
ix::AnyTracedRVector{Int},
x::AnyTracedRVector,
k::Union{Integer,OrdinalRange};
kwargs...,
)
_, idxs = overloaded_partialsort(x, k; kwargs...)
@allowscalar setindex!(ix, idxs[k], k)
return view(ix, k)
end

# arg* functions
function Base.argmin(f::F, x::AnyTracedRArray) where {F}
idx = scalar_index_to_cartesian(argmin(f.(x)), size(x)) .+ 1
Expand Down
106 changes: 102 additions & 4 deletions test/sorting.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,110 @@
using Reactant, Test

@testset "sort" begin end
@testset "sort & sortperm" begin
x = randn(10)
x_ra = Reactant.to_rarray(x)

@testset "sortperm" begin end
srt_rev(x) = sort(x; rev=true)
srtperm_rev(x) = sortperm(x; rev=true)
srt_by(x) = sort(x; by=abs2)
srtperm_by(x) = sortperm(x; by=abs2)
srt_lt(x) = sort(x; lt=(a, b) -> a > b)
srtperm_lt(x) = sortperm(x; lt=(a, b) -> a > b)

@test @jit(sort(x_ra)) == sort(x)
@test @jit(srt_rev(x_ra)) == srt_rev(x)
@test @jit(srt_lt(x_ra)) == srt_lt(x)
@test @jit(srt_by(x_ra)) == srt_by(x)
@test @jit(sortperm(x_ra)) == sortperm(x)
@test @jit(srtperm_rev(x_ra)) == srtperm_rev(x)
@test @jit(srtperm_lt(x_ra)) == srtperm_lt(x)
@test @jit(srtperm_by(x_ra)) == srtperm_by(x)

x = rand(10)
x_ra = Reactant.to_rarray(x)
@jit sort!(x_ra)
@test x_ra == sort(x)

@testset "partialsort" begin end
x = rand(10)
x_ra = Reactant.to_rarray(x)
ix = similar(x_ra, Int)
@jit sortperm!(ix, x_ra)
@test ix == sortperm(x)

x = rand(10, 4, 3)
x_ra = Reactant.to_rarray(x)

@testset "partialsortperm" begin end
srt(x, d) = sort(x; dims=d)
srt_rev(x, d) = sort(x; dims=d, rev=true)
srt_by(x, d) = sort(x; dims=d, by=abs2)
srt_lt(x, d) = sort(x; dims=d, lt=(a, b) -> a > b)
srtperm(x, d) = sortperm(x; dims=d)
srtperm_rev(x, d) = sortperm(x; dims=d, rev=true)
srtperm_by(x, d) = sortperm(x; dims=d, by=abs2)
srtperm_lt(x, d) = sortperm(x; dims=d, lt=(a, b) -> a > b)

@testset for d in 1:ndims(x)
@test @jit(srt(x_ra, d)) == srt(x, d)
@test @jit(srtperm(x_ra, d)) == srtperm(x, d)
@test @jit(srt_rev(x_ra, d)) == srt_rev(x, d)
@test @jit(srtperm_rev(x_ra, d)) == srtperm_rev(x, d)
@test @jit(srt_by(x_ra, d)) == srt_by(x, d)
@test @jit(srtperm_by(x_ra, d)) == srtperm_by(x, d)
@test @jit(srt_lt(x_ra, d)) == srt_lt(x, d)
@test @jit(srtperm_lt(x_ra, d)) == srtperm_lt(x, d)
end
end

@testset "partialsort & partialsortperm" begin
x = randn(10)
x_ra = Reactant.to_rarray(x)

@test @jit(partialsort(x_ra, 1:5)) == partialsort(x, 1:5)
@test @jit(partialsortperm(x_ra, 1:5)) == partialsortperm(x, 1:5)
@test @jit(partialsort(x_ra, 4)) == partialsort(x, 4)
@test @jit(partialsortperm(x_ra, 4)) == partialsortperm(x, 4)

psrt_rev(x, k) = partialsort(x, k; rev=true)
psrtperm_rev(x, k) = partialsortperm(x, k; rev=true)
psrt_by(x, k) = partialsort(x, k; by=abs2)
psrtperm_by(x, k) = partialsortperm(x, k; by=abs2)
psrt_lt(x, k) = partialsort(x, k; lt=(a, b) -> a > b)
psrtperm_lt(x, k) = partialsortperm(x, k; lt=(a, b) -> a > b)

@test @jit(psrt_rev(x_ra, 1:5)) == psrt_rev(x, 1:5)
@test @jit(psrtperm_rev(x_ra, 1:5)) == psrtperm_rev(x, 1:5)
@test @jit(psrt_by(x_ra, 1:5)) == psrt_by(x, 1:5)
@test @jit(psrtperm_by(x_ra, 1:5)) == psrtperm_by(x, 1:5)
@test @jit(psrt_lt(x_ra, 1:5)) == psrt_lt(x, 1:5)
@test @jit(psrtperm_lt(x_ra, 1:5)) == psrtperm_lt(x, 1:5)

x = randn(10)
x_ra = Reactant.to_rarray(x)
@jit partialsort!(x_ra, 1:5)
partialsort!(x, 1:5)
@test x_ra[1:5] == x[1:5]

x = randn(10)
x_ra = Reactant.to_rarray(x)
@jit partialsort!(x_ra, 3)
partialsort!(x, 3)
@test @allowscalar(x_ra[3]) == x[3]

x = randn(10)
x_ra = Reactant.to_rarray(x)

ix = similar(x_ra, Int)
ix_ra = Reactant.to_rarray(ix)
@jit partialsortperm!(ix_ra, x_ra, 1:5)
partialsortperm!(ix, x, 1:5)
@test ix_ra[1:5] == ix[1:5]

ix = similar(x_ra, Int)
ix_ra = Reactant.to_rarray(ix)
@jit partialsortperm!(ix_ra, x_ra, 3)
partialsortperm!(ix, x, 3)
@test @allowscalar(ix_ra[3]) == ix[3]
end

@testset "argmin / argmax" begin
x = rand(2, 3)
Expand Down

0 comments on commit 0313c14

Please sign in to comment.