ROCM-Aware MPI requires AMDGPU.synchronize() #2591

Open

Alexander-Barth opened this issue Feb 18, 2025 · 1 comment

Alexander-Barth commented Feb 18, 2025
When using Distributed Data Parallel (DDP) with two AMD GPUs communicating via ROCM-aware MPI, AMDGPU.synchronize() is necessary at several steps; otherwise the state of the optimizer becomes inconsistent or the averaged gradients are wrong.
This is a follow-up to this discussion:

https://discourse.julialang.org/t/distributed-data-parallel-training-with-2-gpus-fails-with-flux-jl-on-amd-gpus/125993/6
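
In short, the two places where an explicit synchronization is currently needed (excerpted from the full reproducer below; without these calls the run goes wrong as shown further down) are:

# after setting up the optimizer state, before synchronizing it across ranks
opt_state = Optimisers.setup(opt, model)
AMDGPU.synchronize() # necessary
opt_state = DistributedUtils.synchronize!!(backend, opt_state; root=0)

# after computing the gradients, before the distributed optimizer averages them
val, grads = Flux.withgradient(m -> loss(x_batch, m(x_batch)), model)
AMDGPU.synchronize() # necessary
opt_state, model = Optimisers.update(opt_state, model, grads[1])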

The serial code (using SERIAL=true) works as expected:

serial = get(ENV,"SERIAL","false") == "true"

import AMDGPU
using Flux
using Optimisers
using Zygote
using Statistics
using Random
if !serial
    import MPI
end

Random.seed!(42)

function pprintln(backend,args...)
    MPI.Barrier(backend.comm)
    print("rank ",DistributedUtils.local_rank(backend),": ")
    println(args...)
end
pprintln(::Nothing,args...) = println(args...)

AMDGPU.allowscalar(false)

@show Flux.MPI_ROCM_AWARE

if !serial
    const backend_type = MPIBackend
    DistributedUtils.initialize(backend_type)
    backend = DistributedUtils.get_distributed_backend(backend_type)
else
    backend = nothing
end

T = Float32
device = gpu

x = randn(T,256,256,32,16*2) |> device

channels = 2 .^ vcat(5:7,6:-1:5)

model = Chain(
    [Conv((3,3),channels[i] => channels[i+1],pad=SamePad(),selu) for i in 1:length(channels)-1]...
)

losses = T[]
model = model |> device

loss(x,y) = mean((x-y).^2)

opt_s = Optimisers.Adam(1f-4)
#opt_s = Optimisers.Descent(0.01f0) # ok

if !serial
    data = DistributedUtils.DistributedDataContainer(backend, x)
    model = DistributedUtils.synchronize!!(backend, DistributedUtils.FluxDistributedModel(model); root=0)
    opt = DistributedUtils.DistributedOptimizer(backend, opt_s)
else
    data = x
    opt = opt_s
end

opt_state = Optimisers.setup(opt, model)

AMDGPU.synchronize() # necessary

if !serial
    opt_state = DistributedUtils.synchronize!!(backend, opt_state; root=0)
end

dl = Flux.DataLoader(data,batchsize=16)


for i = 1:1000
    global model, opt_state
    for (j,x_batch) in enumerate(dl)
        val, grads = Flux.withgradient(model) do m
            loss(x_batch,m(x_batch))
        end

        AMDGPU.synchronize() # necessary

        push!(losses, val)
        opt_state, model = Optimisers.update(opt_state, model, grads[1])
#        pprintln(backend,"update ",i," ",model.layers[1].weight[1:1])
    end
end

pprintln(backend,"losses ",losses)

The output of this program without the AMDGPU.synchronize() calls (and with the pprintln call in the inner loop uncommented) is:

Flux.MPI_ROCM_AWARE = true
Flux.MPI_ROCM_AWARE = true
rank 1: update 1 Float32[0.005779413]
rank 0: update 1 Float32[0.005779413]
rank 1: update 2 Float32[NaN]
rank 0: update 2 Float32[0.0056868508]
rank 1: update 3 Float32[NaN]
rank 0: update 3 Float32[0.005596291]
rank 1: update 4 Float32[NaN]
rank 0: update 4 Float32[0.0055066617]
rank 1: losses Float32[2.0405662, NaN, NaN, NaN]
rank 0: losses Float32[2.040605, 2.0056882, 1.9737886, 1.9429569]

My environment:

julia 1.11.2

⌃ [21141c5a] AMDGPU v1.2.2
  [0a1fb500] BlockDiagonals v0.1.42
  [052768ef] CUDA v5.6.1
⌃ [13f3f980] CairoMakie v0.12.18
⌃ [b0b7db55] ComponentArrays v0.15.22
  [efc8151c] DIVAnd v2.7.12
  [cf87cc76] DataAssim v0.4.1
  [8bb1440f] DelimitedFiles v1.9.1
  [4e2335b7] FlowMatching v0.1.0 `..`
  [587475ba] Flux v0.16.3 `~/.julia/dev/Flux`
  [db073c08] GeoMakie v0.7.10
  [033835bb] JLD2 v0.5.11
  [f1d291b0] MLUtils v0.4.7
  [da04e1cc] MPI v0.20.22
  [3da0fdf6] MPIPreferences v0.1.11
  [85f8d34a] NCDatasets v0.14.6
  [3bd65402] Optimisers v0.4.4
  [21216c6a] Preferences v1.4.3
  [10745b16] Statistics v1.11.1
⌃ [e88e6eb3] Zygote v0.7.3 or Zygote v0.7.4
  [02a925ec] cuDNN v1.4.1
  [ade2ca70] Dates v1.11.0
  [de0858da] Printf v1.11.0
  [8dfed614] Test v1.11.0

Using just MPI and AMDGPU, we can see that without AMDGPU.synchronize() the sent message is wrong in this example:

using MPI
using AMDGPU
MPI.Init()
comm = MPI.COMM_WORLD
rank = MPI.Comm_rank(comm)
# select device
comm_l = MPI.Comm_split_type(comm, MPI.COMM_TYPE_SHARED, rank)
rank_l = MPI.Comm_rank(comm_l)
device = AMDGPU.device_id!(rank_l+1)
gpu_id = AMDGPU.device_id(AMDGPU.device())
# ring neighbours: send to dst, receive from src
size = MPI.Comm_size(comm)
dst  = mod(rank+1, size)
src  = mod(rank-1, size)
println("rank=$rank rank_loc=$rank_l (gpu_id=$gpu_id - $device), size=$size, dst=$dst, src=$src")
N = 4
send_mesg = ROCArray{Float64}(undef, N)
recv_mesg = ROCArray{Float64}(undef, N)
fill!(send_mesg, Float64(rank))
send_mesg .+= 1

AMDGPU.synchronize() # necessary

MPI.Sendrecv!(send_mesg, dst, 0, recv_mesg, src, 0, comm)

if rank == 0
    println("got ",Array(recv_mesg))
    println("correct: ",all(Array(recv_mesg) .== (src+1)))
end

Rank zero gets the correct message in only 2 out of 20 tries. With AMDGPU.synchronize(), all received messages are correct.
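
A minimal in-process variant of this check (a sketch on my part; the 2-out-of-20 figure above came from re-running the whole script, and a loop inside one process is not guaranteed to behave identically) that repeats the exchange and counts how often rank 0 receives the expected values:

ntrials = 20
correct = 0
for trial in 1:ntrials
    # refill the send buffer on the GPU; the kernel may still be in flight
    fill!(send_mesg, Float64(rank))
    send_mesg .+= 1
    # AMDGPU.synchronize()   # uncommenting this makes every exchange correct
    MPI.Sendrecv!(send_mesg, dst, 0, recv_mesg, src, 0, comm)
    if rank == 0
        global correct += all(Array(recv_mesg) .== (src + 1))
    end
end
rank == 0 && println("correct receives: $correct / $ntrials")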

Thanks to @pxl-th for suggesting that this is a synchronization issue.

pxl-th (Member) commented Feb 18, 2025

Just for context, @Alexander-Barth said that with Lux it works fine without manually synchronizing, so it looks like something is missing here in Flux.
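
For illustration, a hypothetical sketch (not Flux's or Lux's actual code) of the kind of step the distributed optimizer could perform internally, so that user code would not have to call AMDGPU.synchronize() by hand before the gradients are averaged:

import MPI
import AMDGPU

# Hypothetical helper: wait for the device-side gradient buffer to be complete
# before handing it to a ROCm-aware MPI all-reduce, then average across ranks.
function avg_gradient!(comm, g::AMDGPU.ROCArray)
    AMDGPU.synchronize()             # wait for pending kernels writing into g
    MPI.Allreduce!(g, MPI.SUM, comm) # in-place sum over all ranks
    g ./= MPI.Comm_size(comm)        # divide by the number of ranks
    return g
end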
