diff --git a/src/stdlibs/LinearAlgebra.jl b/src/stdlibs/LinearAlgebra.jl index 29a9a2874..d9bf97eb8 100644 --- a/src/stdlibs/LinearAlgebra.jl +++ b/src/stdlibs/LinearAlgebra.jl @@ -397,4 +397,24 @@ function LinearAlgebra._kron!(C::AnyTracedRMatrix, A::AnyTracedRMatrix, B::AnyTr return C end +function LinearAlgebra.axpy!(α::Number, x::TracedRArray{T}, y::TracedRArray{T}) where {T} + if length(x) != length(y) + throw(DimensionMismatch(lazy"x has length $(length(x)), but y has length $(length(y))")) + end + ax = Ops.multiply(x, TracedUtils.broadcast_to_size(T(α), size(x))) + + set_mlir_data!(y, get_mlir_data(Ops.add(y, ax))) +end + +function LinearAlgebra.axpby!(α::Number, x::TracedRArray{T}, β::Number, y::TracedRArray{T}) where {T} + if length(x) != length(y) + throw(DimensionMismatch(lazy"x has length $(length(x)), but y has length $(length(y))")) + end + ax = Ops.multiply(x, TracedUtils.broadcast_to_size(T(α), size(x))) + by = Ops.multiply(y, TracedUtils.broadcast_to_size(T(β), size(y))) + + set_mlir_data!(y, get_mlir_data(Ops.add(ax, by))) +end + + end diff --git a/test/integration/linear_algebra.jl b/test/integration/linear_algebra.jl index cd804d150..114b35252 100644 --- a/test/integration/linear_algebra.jl +++ b/test/integration/linear_algebra.jl @@ -183,3 +183,87 @@ end end end end + +@testset "axpy!" begin + α = 3 + x = rand(Int64, 4) + x_ra = Reactant.to_rarray(x) + y = rand(Int64, 4) + y_ra = Reactant.to_rarray(y) + + @jit axpy!(α, x_ra, y_ra) + @test y_ra ≈ axpy!(α, x, y) + + α = 2 + x = rand(4) + x_ra = Reactant.to_rarray(x) + y = rand(4) + y_ra = Reactant.to_rarray(y) + + @jit axpy!(α, x_ra, y_ra) + @test y_ra ≈ axpy!(α, x, y) + + α = 4.12 + X = rand(3, 5) + Y = rand(3, 5) + X_ra = Reactant.to_rarray(X) + Y_ra = Reactant.to_rarray(Y) + + @jit axpy!(α, X_ra, Y_ra) + @test Y_ra ≈ axpy!(α, X, Y) + + α = 3.2 + 1im + x = rand(Complex{Float32}, 4) + x_ra = Reactant.to_rarray(x) + y = rand(Complex{Float32}, 4) + y_ra = Reactant.to_rarray(y) + + @jit axpy!(α, x_ra, y_ra) + @test y_ra ≈ axpy!(α, x, y) + +end + +@testset "axpby!" begin + α = 3 + β = 2 + x = rand(Int64, 4) + x_ra = Reactant.to_rarray(x) + y = rand(Int64, 4) + y_ra = Reactant.to_rarray(y) + + @jit axpby!(α, x_ra, β, y_ra) + @test y_ra ≈ axpby!(α, x, β, y) + + α = 2 + β = 3 + x = rand(4) + x_ra = Reactant.to_rarray(x) + y = rand(4) + y_ra = Reactant.to_rarray(y) + + @jit axpby!(α, x_ra, β, y_ra) + @test y_ra ≈ axpby!(α, x, β, y) + + α = 4.12 + X = rand(3, 5) + Y = rand(3, 5) + X_ra = Reactant.to_rarray(X) + Y_ra = Reactant.to_rarray(Y) + + @jit axpby!(α, X_ra, β, Y_ra) + @test Y_ra ≈ axpby!(α, X, β, Y) + + α = 3.2 + 1im + β = 2.1 - 4.2im + x = rand(Complex{Float32}, 4) + x_ra = Reactant.to_rarray(x) + y = rand(Complex{Float32}, 4) + y_ra = Reactant.to_rarray(y) + + @jit axpby!(α, x_ra, β, y_ra) + @test y_ra ≈ axpby!(α, x, β, y) + +end + + +