Allow passing in lbl_process_group directly (#1298)
dakinggg authored Jun 21, 2024
1 parent 129e3e1 commit 2196d07
Showing 2 changed files with 32 additions and 2 deletions.
4 changes: 2 additions & 2 deletions llmfoundry/models/utils/config_moe_args.py
@@ -177,9 +177,9 @@ def config_megablocks_moe_args(
                 lbl_process_group = create_set_process_group(lbl_process_group)
             else:
                 lbl_process_group = None
-        elif lbl_process_group is not None:
+        elif not isinstance(lbl_process_group, distributed.ProcessGroup):
             raise ValueError(
-                f'Unknown {lbl_process_group=}. Options are: none | expert_group | global_group | <GROUP_SIZE>.',
+                f'Unknown {lbl_process_group=}. Options are: none | a process group | ``expert_group`` | ``global_group`` | <GROUP_SIZE>.',
             )
         ffn_config['lbl_process_group'] = lbl_process_group

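For context, a minimal usage sketch (not part of the commit) of the call pattern this change enables: building a torch.distributed.ProcessGroup and passing it as lbl_process_group directly, instead of one of the string options. Before this commit any non-string, non-None value hit the ValueError branch; the isinstance check now admits ProcessGroup instances while still rejecting everything else. The group construction is illustrative and assumes an initialized distributed environment with megablocks installed; the remaining config values mirror the test added below.

# Minimal sketch (not from this commit): assumes torch.distributed is already
# initialized and megablocks is installed; other values mirror the new test.
import torch.distributed as dist

from llmfoundry.models.utils.config_moe_args import (
    config_megablocks_moe_args,
    get_megablocks_device_mesh,
)

# Any existing ProcessGroup now passes validation directly; here we build an
# illustrative group spanning all ranks.
lbl_pg = dist.new_group(ranks=list(range(dist.get_world_size())))

config_megablocks_moe_args(
    ffn_config={
        'moe_world_size': 1,
        'lbl_process_group': lbl_pg,  # a ProcessGroup instance, not a string
        'ffn_type': 'mb_moe',
        'fc_type': 'torch',
    },
    d_model=128,
    expansion_ratio=4,
    n_layers=2,
    get_device_mesh=get_megablocks_device_mesh,
)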
30 changes: 30 additions & 0 deletions tests/models/utils/test_config_moe_args.py
@@ -0,0 +1,30 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from typing import Any

import pytest

from llmfoundry.models.utils.config_moe_args import (
    config_megablocks_moe_args,
    get_megablocks_device_mesh,
)


@pytest.mark.gpu
def test_config_megablocks_moe_args_error():
    ffn_config_base: dict[str, Any] = {
        'moe_world_size': 1,
        'lbl_process_group': 'not_real',
        'ffn_type': 'mb_moe',
        'fc_type': 'torch',
    }

    with pytest.raises(ValueError):
        config_megablocks_moe_args(
            ffn_config=ffn_config_base,
            d_model=128,
            expansion_ratio=4,
            n_layers=2,
            get_device_mesh=get_megablocks_device_mesh,
        )
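The test pins down only the rejection path: a bogus string for lbl_process_group must still raise ValueError after the isinstance change. Since it carries the gpu marker, it would typically be selected with something like `pytest -m gpu tests/models/utils/test_config_moe_args.py` (assuming the repository's usual pytest marker configuration).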
