Skip to content

Commit

Permalink
[fix-daily] Use composer get_model_state_dict instead of torch's
Browse files Browse the repository at this point in the history
  • Loading branch information
eracah authored Jul 25, 2024
1 parent c5542a3 commit 55f0b7d
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions tests/checkpoint/test_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pytest
import torch

from composer.checkpoint import download_monolithic_checkpoint
from composer.checkpoint import download_monolithic_checkpoint, get_model_state_dict
from composer.utils import dist
from tests.checkpoint.helpers import init_model
from tests.common.markers import world_size
Expand All @@ -26,8 +26,7 @@ def test_download_monolithic_checkpoint(world_size: int, rank_zero_only: bool):
use_fsdp = True
fsdp_model, _ = init_model(use_fsdp=use_fsdp)

from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict
state = get_model_state_dict(fsdp_model, options=StateDictOptions(full_state_dict=True))
state = get_model_state_dict(fsdp_model, sharded_state_dict=False)

checkpoint_filename = 'state_dict'
save_filename = os.path.join(tmp_dir.name, checkpoint_filename)
Expand Down

0 comments on commit 55f0b7d

Please sign in to comment.