diff --git a/s3torchbenchmarking/src/s3torchbenchmarking/models.py b/s3torchbenchmarking/src/s3torchbenchmarking/models.py index 94e62a56..9520eada 100644 --- a/s3torchbenchmarking/src/s3torchbenchmarking/models.py +++ b/s3torchbenchmarking/src/s3torchbenchmarking/models.py @@ -279,43 +279,3 @@ def save_checkpoint_to_disk(model: nn.Module, uri: str, batch_idx: int): save_time = end_time - start_time print(f"Saving checkpoint to {path} took {save_time} seconds") return save_time - - -def run(): - model = LightningAdapter( - sample_transformer=Transforms.transform_image, - model=ViTForImageClassification.from_pretrained( - "google/vit-base-patch16-224-in21k", num_labels=1024 - ), - config=DictConfig( - content={ - "checkpoint": { - "destination": "disk", - "uri": "/tmp/checkpoints/", - "save_one_in": 2, - # "destination": "s3", - # "uri": "s3://swift-benchmark-dataset/checkpoints/", - "region": "eu-west-2", - } - } - ), - ) - # model = ViT(1024, DictConfig({ - # "destination": "disk", - # "save_one_in": 1, - # "uri": "checkpoints/", - # "region": "eu-west-2", - # })) - dataset = S3IterableDataset.from_prefix( - "s3://swift-benchmark-dataset/4_images/", region="eu-west-2" - ) - dataset = torchdata.datapipes.iter.IterableWrapper(dataset) - dataset = dataset.map(model.load_sample) - dataset = dataset.sharding_filter() - dataloader = DataLoader(dataset=dataset, num_workers=8) - result = model.train(dataloader=dataloader, epochs=1) - print(f"print!s: {result=!s}") - - -if __name__ == "__main__": - run()