[hotfix] fix parameter shape checking #6124

Open · wants to merge 1 commit into main (changes from all commits)
5 changes: 0 additions & 5 deletions colossalai/checkpoint_io/utils.py
```diff
@@ -111,11 +111,6 @@ def search_tp_partition_dim(current_shape: torch.Size, original_shape: torch.Size
         if length > current_shape[dim]:
             partition_dim = dim
             break
-    if partition_dim is not None:
-        assert (
-            original_shape[partition_dim] == tp_size * current_shape[partition_dim]
-        ), f"The parameter isn't evenly distributed among tensor parallel group: \
-            shape before sharding {original_shape}, shape after sharding {current_shape}"
 
     return partition_dim
```
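For context, here is a sketch of `search_tp_partition_dim` as it reads after this change, reconstructed from the hunk above. The signature is truncated in the diff header, so the exact parameter list and return annotation are assumptions:

```python
from typing import Optional

import torch


def search_tp_partition_dim(
    current_shape: torch.Size, original_shape: torch.Size, tp_size: int
) -> Optional[int]:
    # Walk the dims of the unsharded shape; the first dim that shrank is
    # the one the tensor parallel group split the parameter along.
    partition_dim = None
    for dim, length in enumerate(original_shape):
        if length > current_shape[dim]:
            partition_dim = dim
            break
    # The divisibility assertion that used to live here is removed by this
    # PR, which leaves tp_size unused in this sketch.
    return partition_dim
```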
Comment on the deleted assertion:

original_shape = [128256, 4096]
current_shape = [16064, 4096]
tp_size = 8

16064 * 8 = 128512 != 128256

When I delete lines 114–118, the size of model.embed_tokens.weight in the saved model is [16064, 4096].
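A minimal, self-contained repro of that failure in plain Python, using the shapes quoted above; `partition_dim = 0` is assumed here, since the vocab dimension is the one that shrank:

```python
# Shapes quoted in this comment; the check mirrors the assertion deleted above.
original_shape = [128256, 4096]  # embedding shape before padding/sharding
current_shape = [16064, 4096]    # per-rank shard (after padding + TP split)
tp_size = 8
partition_dim = 0                # the vocab dim is the sharded one

try:
    assert original_shape[partition_dim] == tp_size * current_shape[partition_dim], (
        f"The parameter isn't evenly distributed among tensor parallel group: "
        f"shape before sharding {original_shape}, shape after sharding {current_shape}"
    )
except AssertionError as e:
    print(e)  # fires, because 8 * 16064 = 128512 != 128256
```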

), f"The parameter isn't evenly distributed among tensor parallel group: \

Reply:

current_shape should be [16032, 4096] (i.e. 128256 / 8, the shard size without padding).

Reply from the contributor author:
original_shape refers to the shape before padding, while tp_size * current_shape refers to the shape after padding. Therefore, the assert will throw an error.
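A minimal sketch of that padding arithmetic; the padding granularity of 512 is an assumption chosen to reproduce the numbers quoted in this thread, and the real value depends on the sharding configuration:

```python
import math

tp_size = 8
original_vocab = 128256  # vocab size before padding (original_shape[0])

# Assumed: the vocab dim is padded up to a multiple so every rank gets an
# equal shard; 512 (= 64 * tp_size) reproduces the numbers in this thread.
multiple = 512
padded_vocab = math.ceil(original_vocab / multiple) * multiple  # 128512
per_rank = padded_vocab // tp_size                              # 16064

# The deleted check compared the unpadded total against tp_size times the
# padded per-rank shard, which can never match once padding has occurred:
print(per_rank * tp_size, "!=", original_vocab)  # 128512 != 128256
```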


