Skip to content

Commit

Permalink
style: 💄 lint
Browse files Browse the repository at this point in the history
  • Loading branch information
zezhishao committed Sep 26, 2024
1 parent 19be9b7 commit 5efbdea
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 20 deletions.
23 changes: 12 additions & 11 deletions experiments/evaluate.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
import os
import sys
import time
sys.path.append(os.path.abspath(__file__ + '/../..'))
from argparse import ArgumentParser

import basicts

sys.path.append(os.path.abspath(__file__ + '/../..'))
os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

import basicts

def parse_args():
parser = ArgumentParser(description="Evaluate time series forecasting model in BasicTS framework!")
parser.add_argument("-cfg", "--config", default="baselines/STID/PEMS08_LTSF.py", help="training config") # enter your config file path
parser.add_argument("-ckpt", "--checkpoint", default="checkpoints/STID/PEMS08_100_336_336/97d131cadc14bd2b9ffa892d59d55129/STID_best_val_MAE.pt") # enter your own checkpoint file path
parser.add_argument("-g", "--gpus", default="5")
parser.add_argument("-d", "--device_type", default="gpu")
parser.add_argument("-b", "--batch_size", default=None) # use the batch size in the config file

parser = ArgumentParser(description='Evaluate time series forecasting model in BasicTS framework!')
# enter your config file path
parser.add_argument('-cfg', '--config', default='baselines/STID/PEMS08_LTSF.py', help='training config')
# enter your own checkpoint file path
parser.add_argument('-ckpt', '--checkpoint', default='checkpoints/STID/PEMS08_100_336_336/97d131cadc14bd2b9ffa892d59d55129/STID_best_val_MAE.pt')
parser.add_argument('-g', '--gpus', default='5')
parser.add_argument('-d', '--device_type', default='gpu')
parser.add_argument('-b', '--batch_size', default=None) # use the batch size in the config file

return parser.parse_args()

if __name__ == '__main__':
Expand Down
17 changes: 8 additions & 9 deletions experiments/train.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,24 @@
# Run a baseline model in BasicTS framework.


import os
import sys
from argparse import ArgumentParser

# TODO: remove it when basicts can be installed by pip
sys.path.append(os.path.abspath(__file__ + "/../.."))
sys.path.append(os.path.abspath(__file__ + '/../..'))
os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

import torch

import basicts

torch.set_num_threads(4) # aviod high cpu avg usage
os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

def parse_args():
parser = ArgumentParser(description="Run time series forecasting model in BasicTS framework!")
parser.add_argument("-c", "--cfg", default="baselines/STID/PEMS04.py", help="training config")
parser.add_argument("-g", "--gpus", default="0", help="visible gpus")
parser = ArgumentParser(description='Run time series forecasting model in BasicTS framework!')
parser.add_argument('-c', '--cfg', default='baselines/STID/PEMS04.py', help='training config')
parser.add_argument('-g', '--gpus', default='0', help='visible gpus')
return parser.parse_args()

if __name__ == "__main__":
if __name__ == '__main__':
args = parse_args()

basicts.launch_training(args.cfg, args.gpus, node_rank=0)

0 comments on commit 5efbdea

Please sign in to comment.