Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add missing mul! implementation #47

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
steps:
- label: "GPU integration with julia v1.6"
- label: "GPU integration with julia v1.10"
plugins:
- JuliaCI/julia#v1:
# Drop default "registries" directory, so it is not persisted from execution to execution
# Taken from https://github.com/JuliaLang/julia/blob/v1.7.2/.buildkite/pipelines/main/platforms/package_linux.yml#L11-L12
persist_depot_dirs: packages,artifacts,compiled
version: "1.6"
version: "1.10"
- JuliaCI/julia-test#v1: ~
agents:
queue: "juliagpu"
Expand Down
12 changes: 6 additions & 6 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
fail-fast: false
matrix:
version:
- '1.6'
- '1.10'
- '1'
- 'nightly'
os:
Expand Down Expand Up @@ -47,17 +47,17 @@ jobs:

- name: "Run test without coverage report"
uses: julia-actions/julia-runtest@v1
if: ${{ !contains(fromJson('["1", "1.6"]'), matrix.version) || matrix.os != 'ubuntu-latest' }}
if: ${{ !contains(fromJson('["1", "1.10"]'), matrix.version) || matrix.os != 'ubuntu-latest' }}
with:
coverage: false

- name: "Run test with coverage report"
uses: julia-actions/julia-runtest@v1
if: contains(fromJson('["1", "1.6"]'), matrix.version) && matrix.os == 'ubuntu-latest'
if: contains(fromJson('["1", "1.10"]'), matrix.version) && matrix.os == 'ubuntu-latest'
- uses: julia-actions/julia-processcoverage@v1
if: contains(fromJson('["1", "1.6"]'), matrix.version) && matrix.os == 'ubuntu-latest'
if: contains(fromJson('["1", "1.10"]'), matrix.version) && matrix.os == 'ubuntu-latest'
- uses: codecov/codecov-action@v3
if: contains(fromJson('["1", "1.6"]'), matrix.version) && matrix.os == 'ubuntu-latest'
if: contains(fromJson('["1", "1.10"]'), matrix.version) && matrix.os == 'ubuntu-latest'
with:
files: lcov.info

Expand All @@ -68,7 +68,7 @@ jobs:
- uses: actions/checkout@v3
- uses: julia-actions/setup-julia@v1
with:
version: '1.6'
version: '1.10'
- run: |
julia --project=docs -e '
using Pkg
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "OneHotArrays"
uuid = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
version = "0.2.6"
version = "0.2.7"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
9 changes: 9 additions & 0 deletions src/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,12 @@ for wrapper in [:Adjoint, :Transpose]
end
end
end

function LinearAlgebra.mul!(Y::AbstractVecOrMat, A::AbstractMatrix, B::OneHotLike)
_isonehot(B) || return invoke(mul!, Tuple{AbstractArray,AbstractMatrix,AbstractMatrix}, Y, A, B)
size(A,2) == size(B,1) || throw(DimensionMismatch("Matrix column must correspond with the OneHot Size $(size(A,2)) ≠ $(size(B,1))")
)
# matmul sometimes wraps in ReshapedArray, taking parent is a simple way to handle that case
copyto!(Y, view(A, :, onecold(parent(B))))
end

10 changes: 9 additions & 1 deletion test/gpu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,16 @@ end
if VERSION >= v"1.9" && CUDA.functional()
@test gradient(A -> sum(A * y), gA)[1] isa CuArray
else
@test_broken gradient(A -> sum(A * y), gA)[1] isa CuArray # fails with JLArray, bug in Zygote?
@test gradient(A -> sum(A * y), gA)[1] isa CuArray
end

# some specialized implementations call only mul! and not *, so we must ensure this works
@test LinearAlgebra.mul!(similar(gA, 3, 3), gA, y) ≈ gA*y

#TODO: the below fails due to method ambiguity and GPU scalar indexing
y = reshape(y, 3, 2)
gA = rand(2, 3) |> cu
@test_broken LinearAlgebra.mul!(similar(gA, 2, 2), gA, y) ≈ gA*y
end

@testset "onehotbatch(::CuArray, ::UnitRange)" begin
Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using OneHotArrays
using Test
using Test, LinearAlgebra
using Compat: stack

@testset "OneHotArray" begin
Expand Down
Loading