Skip to content

Commit

Permalink
Remove test code
Browse files Browse the repository at this point in the history
  • Loading branch information
IsaevIlya committed Aug 13, 2024
1 parent 75d73ed commit 4829fa7
Showing 1 changed file with 0 additions and 40 deletions.
40 changes: 0 additions & 40 deletions s3torchbenchmarking/src/s3torchbenchmarking/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,43 +279,3 @@ def save_checkpoint_to_disk(model: nn.Module, uri: str, batch_idx: int):
save_time = end_time - start_time
print(f"Saving checkpoint to {path} took {save_time} seconds")
return save_time


def run():
model = LightningAdapter(
sample_transformer=Transforms.transform_image,
model=ViTForImageClassification.from_pretrained(
"google/vit-base-patch16-224-in21k", num_labels=1024
),
config=DictConfig(
content={
"checkpoint": {
"destination": "disk",
"uri": "/tmp/checkpoints/",
"save_one_in": 2,
# "destination": "s3",
# "uri": "s3://swift-benchmark-dataset/checkpoints/",
"region": "eu-west-2",
}
}
),
)
# model = ViT(1024, DictConfig({
# "destination": "disk",
# "save_one_in": 1,
# "uri": "checkpoints/",
# "region": "eu-west-2",
# }))
dataset = S3IterableDataset.from_prefix(
"s3://swift-benchmark-dataset/4_images/", region="eu-west-2"
)
dataset = torchdata.datapipes.iter.IterableWrapper(dataset)
dataset = dataset.map(model.load_sample)
dataset = dataset.sharding_filter()
dataloader = DataLoader(dataset=dataset, num_workers=8)
result = model.train(dataloader=dataloader, epochs=1)
print(f"print!s: {result=!s}")


if __name__ == "__main__":
run()

0 comments on commit 4829fa7

Please sign in to comment.