
Atomic attempts #282

Closed
wants to merge 1 commit into from

Conversation

@leios (Contributor) commented Jan 11, 2022

This is a draft of atomic operation support for KernelAbstractions.

I plan to put everything we need in the atomics.jl file (and the corresponding CUDAKernels file); however, I cannot really test ROCm, so I might need to leave that to someone else.

Current roadmap (to be worked on throughout the week):

I am currently struggling with the final point because, for some reason, the macro I created (KernelAbstractions.@atomic) grabs only the first symbol of an expression rather than the full expression. If everyone is happy enough with the atomic primitives, I might decide to leave the macro to future work (tm).

This is a step towards finalizing #7 and #276; however, I am not sure if it fixes them completely without the @atomic macro.

@vchuravy (Member)

Two comments:

  1. I think the approach is the right one: define the atomic primitives in KA and then expand them in each backend. We will need to copy the macro implementation from CUDA.jl (similar to @print).

  2. I guess the big question is what to do on 1.6, but let me worry about that. For now, just add a version check and an error for 1.6, along the lines of the sketch below.
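
(A minimal sketch of such a guard; illustrative only, not the PR's actual code, and the atomic_add! stub is a placeholder name:)

    # Hypothetical guard: the CPU implementation relies on atomic
    # intrinsics that only exist on Julia >= 1.7, so fail loudly earlier.
    if VERSION < v"1.7"
        atomic_add!(ptr, val) =
            error("KernelAbstractions CPU atomics require Julia >= 1.7")
    end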

@leios (Contributor, Author) commented Jan 11, 2022

Alright, I'll get to it with the other atomic functions. My goal is to have them done by the end of the week.

src/atomics.jl (review thread; outdated, resolved)
@leios (Contributor, Author) commented Jan 11, 2022

OK, the last ones are atomic inc, ((old >= val) ? 0 : (old+1)), and atomic dec, (((old == 0) | (old > val)) ? val : (old-1)).

These just increment up to, or decrement down to, a specified value. I guess they are mainly for modular arithmetic (otherwise, why would they wrap the old value around to 0 or val?).
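
(For reference, a minimal sketch of these semantics as plain, non-atomic Julia functions:)

    # Plain Julia versions of the two wrapping update rules above.
    inc(old, val) = old >= val ? zero(old) : old + 1                  # wraps to 0 once old reaches val
    dec(old, val) = (old == zero(old)) | (old > val) ? val : old - 1  # wraps to val at 0 or above val

    # inc(3, 4) == 4;  inc(4, 4) == 0
    # dec(1024, 512) == 512;  dec(0, 512) == 512;  dec(512, 512) == 511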

Anyway. I think I made some good progress today. Will start on the tests tomorrow!

@leios (Contributor, Author) commented Jan 12, 2022

Right, so the atomic_dec tests differ between CUDA and KA.

CUDA:

@testset "atomic_dec" begin
    @testset for T in [Int32]
        a = CuArray(T[1024])

        function kernel(a, b)
            CUDA.atomic_dec!(pointer(a), b)
            return
        end

        @cuda threads=256 kernel(a, T(512))
        @test Array(a)[1] == 257
    end
end

KA:

    @testset "atomic dec tests" begin
        types = [Int32]

        for T in types
            A = ArrayT{T}([1024,1024])

            kernel = atomic_inc_kernel(backend(), 4)
            wait(kernel(A, T(512), ndrange=(256)))

            @test Array(A)[2] == 255
        end
    end

The output is 255 in KA instead of 257 in CUDA. (257 is what the wrap rule predicts: the first of the 256 operations wraps 1024, which is above 512, down to 512, and the remaining 255 operations decrement it to 257.) I did this with CUDA.atomic_dec! in KA, but couldn't quite figure out how to get 257 in KA.

I am an idiot. Fixed this.

@leios (Contributor, Author) commented Jan 12, 2022

All GPU tests pass, but the compare-and-swap on the CPU is still failing with nested task error: ConcurrencyViolationError("invalid atomic ordering")
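
(For context, not this PR's code: that error usually means an atomic operation was handed an ordering that Base does not accept for it. A minimal CPU compare-and-swap that stays within Base's supported orderings, via Threads.Atomic:)

    using Base.Threads

    # Threads.Atomic picks a valid (sequentially consistent) ordering
    # internally, so this never hits "invalid atomic ordering".
    a = Atomic{Int32}(Int32(0))
    old = atomic_cas!(a, Int32(0), Int32(1))  # returns the value before the swap
    @assert old == 0 && a[] == 1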

@leios (Contributor, Author) commented Jan 12, 2022

OK, all tests on the CPU and GPU now pass. They are mostly direct copies of the tests from CUDA.jl.

I need to:

  1. create docstrings
  2. remove the attempt at an @atomic macro (to be dealt with in a later PR)

@leios leios marked this pull request as ready for review January 12, 2022 19:39
@vchuravy (Member)

Rebase on master?

@leios (Contributor, Author) commented Jan 13, 2022

Pushed, and also changed the histogram test. I think the big problem is that atomics are broken on anything < 1.7.

Is there a way to exclude tests/examples, and also to prevent exporting the atomics, for < 1.7?

@leios (Contributor, Author) commented Jan 16, 2022

I added an error for running atomics on the CPU for Julia versions below 1.7.0 and got most of the tests to pass. For some reason, some of the CI is failing after the printing tests, but I don't think that has anything to do with this PR?

One caveat here is that I needed to remove some architecture-based precision tests, because the only way I could specify them in the test file was by relying on CUDA, which caused some CPU tests to fail. For example, certain cards cannot do the Float64 atomic add. These cases are tested in CUDA.jl, and we are calling the same functions, so I think it's probably fine.

@oschulz (Collaborator) commented May 18, 2022

Thanks for pointing me to this PR @vchuravy !

We would have an immediate application for this: we're implementing monotonic splines for normalizing flows on top of KernelAbstractions. We have a running prototype, but defining the ChainRulesCore.rrule for ML parameter optimization is proving tricky. The spline will be computed for a (very) large number of data points and for several splines in parallel, and the data points will hit different knots of different splines. So we'll need to add contributions to the gradients of the spline parameters (stored as matrices) in a parallel and fairly unpredictable fashion. Having efficient atomic adds would solve this.

CC @VasylHafych, @Micki-D
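
(A minimal sketch of that scatter-add pattern, written against the @atomic interface that KernelAbstractions later adopted (see the discussion below); grad, knots, and contrib are illustrative names, not our actual code:)

    using KernelAbstractions

    # Each data point atomically adds its contribution to the gradient of
    # the spline knot it lands on; many points may hit the same knot.
    @kernel function accumulate_grads!(grad, knots, contrib)
        i = @index(Global)
        j = knots[i]                     # knot index for data point i
        @atomic grad[j] += contrib[i]    # parallel scatter-add
    end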

@leios (Contributor, Author) commented May 18, 2022

For the record, I am currently using this branch to do atomic operations for one of my own projects, so feel free to work off of it in the short-term.

In the long term, a few other people are working on better atomic support in CUDA; namely, JuliaLLVM/LLVM.jl#308. The current plan is to create a separate package (UnsafeAtomicsLLVM.jl) that we can load directly into KernelAbstractions for atomic support.

This PR will hopefully reflect these changes when they happen.

@tkf (Contributor) commented May 18, 2022

With JuliaConcurrent/UnsafeAtomicsLLVM.jl#3 we can use atomics on the CPU and GPU through the same interface. Once UnsafeAtomicsLLVM.jl is released (it requires LLVM.jl 4.12), the ecosystem surrounding KernelAbstractions.jl would look something like this:

graph TD;
    UnsafeAtomics.jl --> Atomix.jl;
    UnsafeAtomics.jl --> UnsafeAtomicsLLVM.jl;
    Atomix.jl --> KernelAbstractions.jl;
    KernelAbstractions.jl --> CUDAKernels.jl;
    LLVM.jl --> UnsafeAtomicsLLVM.jl;
    LLVM.jl --> CUDA.jl;
    CUDA.jl --> CUDAKernels.jl;
    UnsafeAtomicsLLVM.jl --> CUDAKernels.jl;
    KernelAbstractions.jl --> user[User code]
    CUDAKernels.jl --> user[User code]

where

  • LLVM.jl provides "hardware agnostic" atomic intrinsics (mostly used for GPUs)
  • UnsafeAtomics.jl provides low-level pointer-based atomic API
  • UnsafeAtomicsLLVM.jl wraps LLVM.jl intrinsics in UnsafeAtomics.jl API
  • Atomix.jl provides user-facing API such as @atomic A[i] += 1 by extending the macros from Base and lowering the expressions to the API calls based on UnsafeAtomics.jl
  • CUDAKernels.jl loads LLVM.jl-based implementation of atomics by importing UnsafeAtomicsLLVM.jl
  • KernelAbstractions.jl re-exports Atomix.jl API
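
(To make the user-facing layer concrete, a minimal CPU-only sketch of the Atomix.jl API, independent of this PR:)

    using Atomix: @atomic

    # Atomix's @atomic works on plain arrays; each += below is lowered to an
    # UnsafeAtomics.jl call, so the concurrent increments don't race.
    counts = zeros(Int, 1)
    Threads.@threads for i in 1:10_000
        @atomic counts[1] += 1
    end
    @assert counts[1] == 10_000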

@oschulz (Collaborator) commented May 18, 2022

> For the record, I am currently using this branch to do atomic operations for one of my own projects, so feel free to work off of it in the short-term.

We'll give it a try!

> With JuliaConcurrent/UnsafeAtomicsLLVM.jl#3 we can use atomics on CPU and GPU [...]

That'll be awesome ...

Thanks @leios and @tkf!

@oschulz (Collaborator) commented Jun 7, 2022

> With JuliaConcurrent/UnsafeAtomicsLLVM.jl#3 we can use atomics on CPU and GPU [...]

@tkf, I saw that UnsafeAtomicsLLVM 0.1.0 is out now - can that be used (on Julia v1.8) in a KernelAbstractions kernel already?

@vchuravy (Member) commented Jun 8, 2022

@leios I think we can close this now?

@leios (Contributor, Author) commented Jun 8, 2022

Yeah, even if we rework it, I think it needs to be in a different PR.

@leios leios closed this Jun 8, 2022
@oschulz (Collaborator) commented Jun 8, 2022

> I think we can close this now?

Does that mean UnsafeAtomicsLLVM is "ready", basically?

@leios (Contributor, Author) commented Jun 8, 2022

Yeah, basically. There might still be some stuff to sort out, but #299 added the dependency on UnsafeAtomicsLLVM, and it is working on master (though we are missing a few tests and docs).

@oschulz (Collaborator) commented Jun 9, 2022

Neat, thanks @leios! Is there an example or so to get started?

@leios (Contributor, Author) commented Jun 9, 2022

Right now it's just the histogram example: https://github.com/JuliaGPU/KernelAbstractions.jl/blob/master/examples/histogram.jl; however, in #299 someone commented about using it for an @atomic max(...) as well.

We need better examples and docs...

@oschulz (Collaborator) commented Jun 9, 2022

Thanks! So it's basically just @atomic A[i] += x, right? Does it need to be on @localmem?

@leios (Contributor, Author) commented Jun 9, 2022

Ah, no. That was just a micro-optimization (and a proof that it works on shared memory). Any @atomic call should be fine within a kernel (as far as I am aware...).
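
(For instance, a minimal sketch of an atomic add on an ordinary global-memory array inside a KernelAbstractions kernel; no @localmem involved:)

    using KernelAbstractions

    # @atomic on a plain array argument in global memory; every work-item
    # atomically accumulates its contribution into the same slot.
    @kernel function sum_into!(acc, x)
        i = @index(Global)
        @atomic acc[1] += x[i]
    end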

@oschulz (Collaborator) commented Jun 9, 2022

Thanks a lot @leios! @VasylHafych and I will give it a go for our scattered gradient accumulation use case.
