
pmean inside objax.Parallel causes multithreading deadlock for more than 2 GPUs #245

Open
a1302z opened this issue Aug 22, 2022 · 3 comments



a1302z commented Aug 22, 2022

Hi,
I've run into a problem where I'd like to ask for your expertise. I'm not entirely sure whether it is an Objax problem or rather a JAX problem under the hood, but since it is triggered by Objax calls I'll post it here.

Description

In particular, when combining objax.Parallel and objax.functional.parallel.pmean (as done in this tutorial), I run into problems with more than 2 GPUs (with 2 GPUs it works fine): the program ends up in a deadlock where nothing happens anymore. If I understand the tutorial correctly, the pmean is necessary to average the gradients across all cards.

Minimal reproducible example

import objax
import numpy as np
from objax.zoo.resnet_v2 import ResNet18
from jax import numpy as jnp, device_count
from tqdm import tqdm


if __name__ == "__main__":
    print(f"Num devices: {device_count()}")
    model = ResNet18(3, 1)
    opt = objax.optimizer.SGD(model.vars())

    @objax.Function.with_vars(model.vars())
    def loss(x, label):
        return objax.functional.loss.mean_squared_error(
            model(x, training=True), label
        ).mean()

    gv = objax.GradValues(loss, model.vars())

    train_vars = model.vars() + gv.vars() + opt.vars()

    @objax.Function.with_vars(train_vars)
    def train_op(image_batch, label_batch):
        grads, loss = gv(image_batch, label_batch)
        # grads = objax.functional.parallel.pmean(grads) # this line
        # loss = objax.functional.parallel.pmean(loss) # and this line
        loss = loss[0]
        opt(1e-3, grads)
        return loss, grads

    train_op = objax.Parallel(train_op, reduce=jnp.mean, vc=train_vars)

    with train_vars.replicate():
        for _ in tqdm(range(10), total=10):
            data = jnp.array(np.random.randn(512, 3, 224, 224))
            label = jnp.zeros((512, 1))
            loss, grads = train_op(data, label)

Whenever the two pmean lines are commented back in, the program gets stuck. However, if I understand it correctly, they are necessary to average the gradients over all cards.
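
For clarity, this is the train_op from the example above with both pmean lines commented in, i.e. the variant that deadlocks on more than 2 GPUs (only the two pmean calls differ from the code above):

    @objax.Function.with_vars(train_vars)
    def train_op(image_batch, label_batch):
        grads, loss = gv(image_batch, label_batch)
        grads = objax.functional.parallel.pmean(grads)  # average gradients across devices
        loss = objax.functional.parallel.pmean(loss)    # average loss across devices
        loss = loss[0]
        opt(1e-3, grads)
        return loss, grads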

Error traces

As with most deadlock bugs, you don't get an error stack trace. However, I have found two clues so far. One is that if this is uncommented, the following appears:

2022-08-22 14:55:46.462557: E external/org_tensorflow/tensorflow/compiler/xla/service/rendezvous.cc:31] This thread has been waiting for 10 seconds and may be stuck:
2022-08-22 14:55:48.543291: E external/org_tensorflow/tensorflow/compiler/xla/service/rendezvous.cc:36] Thread is unstuck! Warning above was a false-positive. Perhaps the timeout is too short.

The other is that if I manually interrupt it with Ctrl+C, I get this lengthy stack trace.

Setup

We use 4 NVIDIA A40 GPUs with CUDA 11.7 (driver 515.65.01), cuDNN 8.2.1.32, JAX 0.3.15, and Objax 1.6.0.


a1302z commented Aug 24, 2022

Quick update from my side: I've found a workaround that does not get stuck with more than two GPUs. However, it is extremely slow, much slower than a single GPU, probably because of the repeated replicate calls.

import objax
import numpy as np
from objax.zoo.resnet_v2 import ResNet18
from jax import numpy as jnp, device_count
from tqdm import tqdm
from functools import partial


if __name__ == "__main__":
    N_devices = device_count()
    print(f"Num devices: {N_devices}")
    model = ResNet18(3, 1)
    opt = objax.optimizer.SGD(model.vars())

    @objax.Function.with_vars(model.vars())
    def loss_fn(x, label):
        return objax.functional.loss.mean_squared_error(
            model(x, training=True), label
        ).mean()

    gv = objax.GradValues(loss_fn, model.vars())

    train_vars = model.vars() + gv.vars() + opt.vars()

    @objax.Function.with_vars(train_vars)
    def train_op(image_batch, label_batch):
        grads, loss = gv(image_batch, label_batch)
        # grads = objax.functional.parallel.pmean(grads)
        # loss = objax.functional.parallel.pmean(loss)
        loss = loss[0]
        return loss, grads

    train_op = objax.Parallel(
        train_op,
        reduce=partial(jnp.mean, axis=0),
        vc=train_vars,
    )

    @objax.Function.with_vars(train_vars)
    def train_op_op(grads):
        opt(1e-3, grads)

    train_op_op = objax.Jit(train_op_op, vc=train_vars)

    data = jnp.array(np.random.randn(64, 3, 224, 224))
    label = jnp.zeros((64, 1))
    for _ in tqdm(range(10), total=10):
        with train_vars.replicate():
            _, grads = train_op(data, label)
        train_op_op(grads)

@AlexeyKurakin
Member

I am not entirely sure what the issue could be.
I recently ran ImageNet training code with pmean on a few V100 GPUs on a single machine without a problem.

It sounds like some kind of bug in all-reduce / pmean.

I can only suggest either trying a different software/hardware configuration, or trying to reproduce this bug in pure JAX and reporting it to the JAX team.

In pure JAX it could look something like the following:

import jax
import numpy as np
from jax import numpy as jnp

# All local devices on this host (the original snippet built this list
# from the buffers of a sharded dummy array).
local_devices = jax.local_devices()
n_devices = len(local_devices)


def loss_fn(params, x, y):
    per_example_loss = jnp.square(
        jnp.squeeze(jnp.dot(x, params['w']) + params['b']) - y)
    return jnp.mean(per_example_loss)


weights = {
    'w': jnp.ones((3, 1), dtype=jnp.float32),
    'b': jnp.array(5., dtype=jnp.float32),
}

# Sample x and y:
# x has shape [ndevices, per_device_batch, 3]
# y has shape [ndevices, per_device_batch]
per_device_batch = 8
x = jnp.asarray(np.random.randn(n_devices, per_device_batch, 3), dtype=jnp.float32)
y = jnp.zeros((n_devices, per_device_batch), dtype=jnp.float32)

gv = jax.value_and_grad(loss_fn)


def train_op(params, x, y):
    _, g = gv(params, x, y)
    return jax.lax.pmean(g, axis_name='devices')


train_op_parallel = jax.pmap(train_op, axis_name='devices')

replicated_weights = jax.tree_util.tree_map(
    lambda w: jax.device_put_replicated(w, local_devices),
    weights)

train_op_parallel(replicated_weights, x, y)

This code roughly resembles the way Objax translates parallel training code into JAX.
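
As an even smaller sanity check (not from the thread above; just plain jax.pmap with jax.lax.pmean, no Objax involved), something like the sketch below should finish almost instantly on a healthy multi-GPU setup; if it also hangs, the problem is below Objax in the stack:

import jax
import jax.numpy as jnp

n = jax.local_device_count()
# Give each device one scalar and average it across all devices with pmean.
out = jax.pmap(lambda v: jax.lax.pmean(v, axis_name='i'), axis_name='i')(
    jnp.arange(n, dtype=jnp.float32))
print(out)  # every entry should equal (n - 1) / 2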


a1302z commented Sep 21, 2022

You are right; it does indeed seem to be a driver issue. We've now booked some Paperspace instances, and it works perfectly fine on those. If we find the exact reason, we'll post it here.
