Not support partialsortperm ? #485

Closed
x66ccff opened this issue Jan 6, 2025 · 7 comments · Fixed by #529
Labels
good first issue Good for newcomers

Comments

@x66ccff
Contributor

x66ccff commented Jan 6, 2025

Code

using Reactant
x = rand(Float32, 1000)
xr = Reactant.to_rarray(x)
partialsortperm_compiled = @compile partialsortperm(xr, 1:100)

Outputs

ERROR: Scalar indexing is disallowed.
Invocation of getindex(::TracedRArray, ::Vararg{Int, N}) resulted in scalar indexing of a GPU array.
This is typically caused by calling an iterating implementation of a method.
Such implementations *do not* execute on the GPU, but very slowly on the CPU,
and therefore should be avoided.

If you want to allow scalar iteration, use `allowscalar` or `@allowscalar`
to enable scalar iteration globally or for the operations in question.
Stacktrace:
  [1] error
    @ ./error.jl:35 [inlined]
  [2] error(none::String)
    @ Reactant ./<missing>:0
  [3] ErrorException
    @ ./boot.jl:323 [inlined]
  [4] error
    @ ./error.jl:35 [inlined]
  [5] call_with_reactant(::typeof(error), ::String)
    @ Reactant ~/.julia/packages/Reactant/7m11i/src/utils.jl:0
  [6] errorscalar
    @ ~/.julia/packages/GPUArraysCore/GMsgk/src/GPUArraysCore.jl:155 [inlined]
  [7] errorscalar(none::String)
    @ Reactant ./<missing>:0
  [8] scalardesc
    @ ~/.julia/packages/GPUArraysCore/GMsgk/src/GPUArraysCore.jl:138 [inlined]
  [9] errorscalar
    @ ~/.julia/packages/GPUArraysCore/GMsgk/src/GPUArraysCore.jl:154 [inlined]
 [10] call_with_reactant(::typeof(GPUArraysCore.errorscalar), ::String)
    @ Reactant ~/.julia/packages/Reactant/7m11i/src/utils.jl:0
 [11] _assertscalar
    @ ~/.julia/packages/GPUArraysCore/GMsgk/src/GPUArraysCore.jl:128 [inlined]
 [12] _assertscalar(none::String, none::GPUArraysCore.ScalarIndexing)
    @ Reactant ./<missing>:0
 [13] _assertscalar
    @ ~/.julia/packages/GPUArraysCore/GMsgk/src/GPUArraysCore.jl:127 [inlined]
 [14] call_with_reactant(::typeof(GPUArraysCore._assertscalar), ::String, ::GPUArraysCore.ScalarIndexing)
    @ Reactant ~/.julia/packages/Reactant/7m11i/src/utils.jl:0
 [15] assertscalar
    @ ~/.julia/packages/GPUArraysCore/GMsgk/src/GPUArraysCore.jl:116 [inlined]
 [16] assertscalar(none::String)
    @ Reactant ./<missing>:0
 [17] current_task
    @ ./task.jl:152 [inlined]
 [18] task_local_storage
    @ ./task.jl:280 [inlined]
 [19] assertscalar
    @ ~/.julia/packages/GPUArraysCore/GMsgk/src/GPUArraysCore.jl:101 [inlined]
 [20] call_with_reactant(::typeof(GPUArraysCore.assertscalar), ::String)
    @ Reactant ~/.julia/packages/Reactant/7m11i/src/utils.jl:0
 [21] getindex
    @ ~/.julia/packages/Reactant/7m11i/src/TracedRArray.jl:45 [inlined]
 [22] getindex(none::Reactant.TracedRArray{Float32, 1}, none::Tuple{Int64})
    @ Reactant ./<missing>:0
 [23] getindex
    @ ~/.julia/packages/Reactant/7m11i/src/TracedRArray.jl:45 [inlined]
 [24] call_with_reactant(::typeof(getindex), ::Reactant.TracedRArray{Float32, 1}, ::Int64)
    @ Reactant ~/.julia/packages/Reactant/7m11i/src/utils.jl:0
 [25] lt
    @ ./ordering.jl:124 [inlined]
 [26] partition!
    @ ./sort.jl:1079 [inlined]
 [27] partition!(none::Vector{…}, none::Int64, none::Int64, none::Int64, none::Base.Order.Perm{…}, none::Vector{…}, none::Bool, none::Vector{…}, none::Int64)
    @ Reactant ./<missing>:0
 [28] hash
    @ ./hashing.jl:88 [inlined]
 [29] hash
    @ ./hashing.jl:30 [inlined]
 [30] partition!
    @ ./sort.jl:1074 [inlined]
 [31] call_with_reactant(::typeof(Base.Sort.partition!), ::Vector{…}, ::Int64, ::Int64, ::Int64, ::Base.Order.Perm{…}, ::Vector{…}, ::Bool, ::Vector{…}, ::Int64)
    @ Reactant ~/.julia/packages/Reactant/7m11i/src/utils.jl:0
 [32] #_sort!#25
    @ ./sort.jl:1117 [inlined]
 [33] var"#_sort!#25"(none::Nothing, none::Nothing, none::Bool, none::Bool, none::typeof(Base.Sort._sort!), none::Vector{…}, none::Base.Sort.ScratchQuickSort{…}, none::Base.Order.Perm{…}, none::@NamedTuple{…})
    @ Reactant ./<missing>:0
 [34] get
    @ ./namedtuple.jl:388 [inlined]
 [35] _lo
    @ ./sort.jl:480 [inlined]
 [36] #_sort!#25
    @ ./sort.jl:1105 [inlined]
 [37] call_with_reactant(::Base.Sort.var"##_sort!#25", ::Nothing, ::Nothing, ::Bool, ::Bool, ::typeof(Base.Sort._sort!), ::Vector{…}, ::Base.Sort.ScratchQuickSort{…}, ::Base.Order.Perm{…}, ::@NamedTuple{…})
    @ Reactant ~/.julia/packages/Reactant/7m11i/src/utils.jl:0
 [38] _sort!
    @ ./sort.jl:1103 [inlined]
 [39] _sort!
    @ ./sort.jl:721 [inlined]
 [40] _sort!
    @ ./sort.jl:792 [inlined]
 [41] _sort!(none::Vector{…}, none::Base.Sort.Small{…}, none::Base.Order.Perm{…}, none::@NamedTuple{…})
    @ Reactant ./<missing>:0
 [42] get
    @ ./namedtuple.jl:388 [inlined]
 [43] _lo
    @ ./sort.jl:480 [inlined]
 [44] _sort!
    @ ./sort.jl:788 [inlined]
 [45] call_with_reactant(::typeof(Base.Sort._sort!), ::Vector{…}, ::Base.Sort.Small{…}, ::Base.Order.Perm{…}, ::@NamedTuple{…})
    @ Reactant ~/.julia/packages/Reactant/7m11i/src/utils.jl:0
 [46] _sort!
    @ ./sort.jl:737 [inlined]
 [47] _sort!
    @ ./sort.jl:676 [inlined]
 [48] _sort!
    @ ./sort.jl:554 [inlined]
 [49] #partialsortperm!#35
    @ ./sort.jl:1814 [inlined]
 [50] var"#partialsortperm!#35"(none::Function, none::Function, none::Nothing, none::Base.Order.ForwardOrdering, none::Bool, none::typeof(partialsortperm!), none::Vector{…}, none::Reactant.TracedRArray{…}, none::UnitRange{…})
    @ Reactant ./<missing>:0
 [51] size
    @ array.jl:194 [inlined]
 [52] axes
    @ abstractarray.jl:98 [inlined]
 [53] axes
    @ abstractarray.jl:77 [inlined]
 [54] call_with_reactant(::Base.Sort.var"##partialsortperm!#35", ::Function, ::Function, ::Nothing, ::Base.Order.ForwardOrdering, ::Bool, ::typeof(partialsortperm!), ::Vector{…}, ::Reactant.TracedRArray{…}, ::UnitRange{…})
    @ Reactant sort.jl:1805
 [55] partialsortperm!
    @ ./sort.jl:1798 [inlined]
 [56] partialsortperm
    @ ./sort.jl:1755 [inlined]
 [57] partialsortperm(none::Reactant.TracedRArray{Float32, 1}, none::UnitRange{Int64})
    @ Reactant ./<missing>:0
 [58] getproperty
    @ ./Base.jl:49 [inlined]
 [59] size
    @ ~/.julia/packages/Reactant/7m11i/src/TracedRArray.jl:224 [inlined]
 [60] axes
    @ ./abstractarray.jl:98 [inlined]
 [61] axes
    @ ./abstractarray.jl:77 [inlined]
 [62] partialsortperm
    @ ./sort.jl:1755 [inlined]
 [63] call_with_reactant(::typeof(partialsortperm), ::Reactant.TracedRArray{Float32, 1}, ::UnitRange{Int64})
    @ Reactant ~/.julia/packages/Reactant/7m11i/src/utils.jl:0
 [64] (::Reactant.TracedUtils.var"#8#18"{Bool, Bool, typeof(partialsortperm), Tuple{…}, Vector{…}, Tuple{…}})()
    @ Reactant.TracedUtils ~/.julia/packages/Reactant/7m11i/src/TracedUtils.jl:182
 [65] block!(f::Reactant.TracedUtils.var"#8#18"{…}, blk::Reactant.MLIR.IR.Block)
    @ Reactant.MLIR.IR ~/.julia/packages/Reactant/7m11i/src/mlir/IR/Block.jl:201
 [66] make_mlir_fn(f::Function, args::Tuple{…}, kwargs::Tuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, no_args_in_result::Bool, construct_function_without_args::Bool, do_transpose::Bool)
    @ Reactant.TracedUtils ~/.julia/packages/Reactant/7m11i/src/TracedUtils.jl:169
 [67] make_mlir_fn
    @ ~/.julia/packages/Reactant/7m11i/src/TracedUtils.jl:86 [inlined]
 [68] #10
    @ ~/.julia/packages/Reactant/7m11i/src/Compiler.jl:330 [inlined]
 [69] block!(f::Reactant.Compiler.var"#10#15"{typeof(partialsortperm), Tuple{…}}, blk::Reactant.MLIR.IR.Block)
    @ Reactant.MLIR.IR ~/.julia/packages/Reactant/7m11i/src/mlir/IR/Block.jl:201
 [70] #9
    @ ~/.julia/packages/Reactant/7m11i/src/Compiler.jl:329 [inlined]
 [71] mmodule!(f::Reactant.Compiler.var"#9#14"{…}, blk::Reactant.MLIR.IR.Module)
    @ Reactant.MLIR.IR ~/.julia/packages/Reactant/7m11i/src/mlir/IR/Module.jl:92
 [72] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{ConcreteRArray{…}, UnitRange{…}}; optimize::Bool)
    @ Reactant.Compiler ~/.julia/packages/Reactant/7m11i/src/Compiler.jl:326
 [73] compile_mlir!
    @ ~/.julia/packages/Reactant/7m11i/src/Compiler.jl:325 [inlined]
 [74] (::Reactant.Compiler.var"#32#34"{Bool, typeof(partialsortperm), Tuple{ConcreteRArray{Float32, 1}, UnitRange{Int64}}})()
    @ Reactant.Compiler ~/.julia/packages/Reactant/7m11i/src/Compiler.jl:820
 [75] context!(f::Reactant.Compiler.var"#32#34"{Bool, typeof(partialsortperm), Tuple{…}}, ctx::Reactant.MLIR.IR.Context)
    @ Reactant.MLIR.IR ~/.julia/packages/Reactant/7m11i/src/mlir/IR/Context.jl:76
 [76] compile_xla(f::Function, args::Tuple{ConcreteRArray{Float32, 1}, UnitRange{Int64}}; client::Nothing, optimize::Bool)
    @ Reactant.Compiler ~/.julia/packages/Reactant/7m11i/src/Compiler.jl:817
 [77] compile_xla
    @ ~/.julia/packages/Reactant/7m11i/src/Compiler.jl:812 [inlined]
 [78] compile(f::Function, args::Tuple{ConcreteRArray{…}, UnitRange{…}}; client::Nothing, optimize::Bool, sync::Bool)
    @ Reactant.Compiler ~/.julia/packages/Reactant/7m11i/src/Compiler.jl:844
Some type information was truncated. Use `show(err)` to see complete types.
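For comparison, the same call on a plain CPU `Vector` succeeds. The frames above (`Base.Sort.partition!` reaching `getindex` through `Base.Order.Perm`) suggest that the failure comes from Base's generic, element-by-element sorting code, which Reactant rejects as scalar indexing while tracing. A minimal sketch, using only Base Julia, of the result the snippet was expected to produce:

```julia
x = rand(Float32, 1000)

# Indices of the 100 smallest elements of x, ordered by increasing value.
# For a range argument, partialsortperm returns a view of the index vector.
p = partialsortperm(x, 1:100)

# The selected values are exactly the 100 smallest values, already sorted.
@assert x[p] == sort(x)[1:100]
@assert issorted(x[p])
```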
@wsmoses
Member

wsmoses commented Jan 6, 2025

@glou-nes, per your work on #374, do you want to give this a go?

@mofeing
Collaborator

mofeing commented Jan 6, 2025

I think in this case we might want to use `chlo.top_k`, which is already implemented as `Ops.top_k`.
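One detail worth keeping in mind when mapping one onto the other: `top_k` selects the `k` largest entries in descending order (the repository test quoted later in this thread expects values `[4, 3]` for `[1, 2, 3, 4]` and `k = 2`), whereas `partialsortperm(x, 1:k)` selects the smallest by default. A minimal CPU sketch of those semantics; `reference_top_k` is a hypothetical helper, not part of Reactant, and it returns 1-based indices:

```julia
# Hypothetical CPU reference for top_k semantics on a vector:
# the k largest values in descending order, plus their 1-based positions.
# Ties keep the lower index first because sortperm is stable.
function reference_top_k(x::AbstractVector, k::Integer)
    idx = sortperm(x; rev=true)[1:k]
    return (values=x[idx], indices=idx)
end

x = [1, 2, 3, 4]
r = reference_top_k(x, 2)
@assert r.values == [4, 3]
@assert r.indices == [4, 3]  # 1-based; the test quoted later expects the 0-based [3, 2]

# top_k lines up with partialsortperm using rev=true, not the default order.
@assert r.indices == partialsortperm(x, 1:2; rev=true)
```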

@x66ccff
Contributor Author

x66ccff commented Jan 12, 2025

You are right, just using `Ops.top_k` solves this.

At first, I didn't understand what you meant by Ops and chlo. I think it would be better to mention this in the documentation.

Reactant.jl/test/ops.jl

Lines 917 to 920 in ca98c17

@testset "top_k" begin
x = ConcreteRArray([1, 2, 3, 4])
@test (; values=[4, 3], indices=[3, 2]) == @jit Ops.top_k(x, 2)
end

julia> using Reactant

julia> x = ConcreteRArray([123,456,789,121])
4-element ConcreteRArray{Int64, 1}:
 123
 456
 789
 121

julia> f_top2_from4 = @compile Reactant.Ops.top_k(x, 2)
2025-01-12 21:27:53.904248: I external/xla/xla/service/llvm_ir/llvm_command_line_options.cc:50] XLA (re)initializing LLVM with options fingerprint: 5867536532833239782
Reactant.Compiler.Thunk{typeof(Reactant.Ops.top_k), Symbol("##top_k_reactant#264"), Tuple{ConcreteRArray{Int64, 1}, Int64}, false}(Reactant.Ops.top_k)

julia> f_top2_from4(x, 2)
(values = ConcreteRArray{Int64, 1}([789, 456]), indices = ConcreteRArray{Int32, 1}(Int32[2, 1]))


@x66ccff x66ccff closed this as completed Jan 12, 2025
@mofeing
Collaborator

mofeing commented Jan 12, 2025

> At first, I didn't understand what you meant by Ops and chlo. I think it would be better to mention this in the documentation.

Reactant is still in a very experimental phase. We are thinking of doing a publication first and then, once it starts to stabilize, writing the documentation. The main reason we don't have proper documentation yet is that we are still breaking a lot of things and refactoring.

CHLO, StableHLO, MHLO... are, roughly, the MLIR dialects used by XLA (they didn't start as MLIR dialects, but are moving there), i.e. the primitives in the IR that XLA accepts as input. CHLO is actually a high-level dialect that lowers to StableHLO. It has no stability guarantees, but they should always offer that functionality.

The MLIR.jl API requires a lot of setup code to generate the IR. The methods in the `Ops` module are wrappers around all that boilerplate that directly emit MLIR ops from `TracedRArray`s and such. So there should be a method in `Ops` for each *HLO op that directly emits it with Reactant traced types. The list is not complete yet and is tracked in #273.

In principle, you shouldn't call `Ops.top_k` directly; instead, specialized methods of `partialsort` and `partialsortperm` for `TracedRArray` should call `Ops.top_k`. Anyway, I'm glad that it worked well for you, but I'm reopening the issue because those specialized methods are not implemented yet.
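A specialized method like this has to bridge two differences: `top_k` selects the largest entries while `partialsortperm` defaults to the smallest, and (judging by the 0-based indices reported later in this thread) the returned indices may need shifting. The sketch below shows only that bridging logic on plain CPU arrays. `topk0` is a hypothetical stand-in for `Ops.top_k` that returns 0-based indices, and `partialsortperm_via_topk` is likewise a hypothetical name, not Reactant API. Negating the input to obtain the smallest entries is one possible approach; it assumes a signed element type and no `NaN`s.

```julia
# Hypothetical stand-in for Ops.top_k on a vector: the k largest values in
# descending order, with 0-based indices (as observed for Ops.top_k in this thread).
function topk0(x::AbstractVector, k::Integer)
    idx = sortperm(x; rev=true)[1:k]
    return (values=x[idx], indices=idx .- 1)
end

# Hypothetical sketch of partialsortperm(x, 1:k; rev) expressed through top_k.
# Assumes a signed element type without NaNs: the ascending case negates the
# input, and -NaN would still sort as the largest value.
function partialsortperm_via_topk(x::AbstractVector, k::Integer; rev::Bool=false)
    # top_k picks the largest entries; the k smallest of x are the k largest of -x
    r = rev ? topk0(x, k) : topk0(-x, k)
    # Shift from 0-based to Julia's 1-based indexing
    return r.indices .+ 1
end

x = [5, -2, 9, 0, 7, 3]
@assert partialsortperm_via_topk(x, 3) == partialsortperm(x, 1:3)
@assert partialsortperm_via_topk(x, 3; rev=true) == partialsortperm(x, 1:3; rev=true)
```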

@mofeing mofeing reopened this Jan 12, 2025
@mofeing mofeing added the good first issue Good for newcomers label Jan 12, 2025
@x66ccff
Contributor Author

x66ccff commented Jan 12, 2025

Thank you for your detailed explanation!

@x66ccff
Contributor Author

x66ccff commented Jan 13, 2025

Does it feel a bit non-standard? The indices returned by this start from 0, and I initially thought they started from 1.

julia> f_top2_from = @compile Reactant.Ops.top_k(x, 2)
julia> f_top2_from(ConcreteRArray([NaN,123,456,789,121]), 2)
(values = ConcreteRArray{Float64, 1}([NaN, 789.0]), indices = ConcreteRArray{Int32, 1}(Int32[0, 3]))

@mofeing
Collaborator

mofeing commented Jan 13, 2025

Hmm, Julia is 1-indexed, but MLIR is 0-indexed. It could be a bug in `Ops`.
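Until the index convention is settled, the indices can be shifted by one on the Julia side. A minimal sketch using the data from the session above, with the 0-based indices copied as a literal rather than computed by Reactant. It checks that the shifted indices match what Base's `partialsortperm` returns on the same data with `rev=true`:

```julia
v = [NaN, 123, 456, 789, 121]

# 0-based indices as printed by the compiled Ops.top_k call above
indices0 = Int32[0, 3]

# Convert to Julia's 1-based convention
indices1 = Int.(indices0) .+ 1

@assert indices1 == partialsortperm(v, 1:2; rev=true)
# isequal is needed because NaN == NaN is false
@assert isequal(v[indices1], [NaN, 789.0])
```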

@Pangoraw Pangoraw linked a pull request Jan 14, 2025 that will close this issue