Skip to content

Commit

Permalink
Fix train.py for saving sample npy (#201)
Browse files Browse the repository at this point in the history
  • Loading branch information
seldauyanik-maxim authored Dec 15, 2022
1 parent 4966ad1 commit ec7b362
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 2 deletions.
2 changes: 2 additions & 0 deletions scripts/evaluate_svhn_tinierssd.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
#!/bin/sh
python train.py --deterministic --print-freq 200 --model ai85tinierssd --use-bias --dataset SVHN_74 --device MAX78000 --obj-detection --obj-detection-params parameters/obj_detection_params_svhn.yaml --qat-policy policies/qat_policy_svhn.yaml --evaluate -8 --exp-load-weights-from ../ai8x-synthesis/trained/ai85-svhn-tinierssd-qat8-q.pth.tar "$@"
9 changes: 7 additions & 2 deletions train.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,10 @@ def main():
print('WARNING: Initial learning rate (--lr) not set, selecting 0.1.')
args.lr = 0.1

if args.generate_sample is not None and not args.act_mode_8bit:
print('WARNING: Cannot save sample in training mode, ignoring --save-sample option. '
'Use with --evaluate instead.')

msglogger = apputils.config_pylogger(os.path.join(script_dir, 'logging.conf'), args.name,
args.output_dir)

Expand Down Expand Up @@ -1020,6 +1024,7 @@ def save_tensor(t, f, regression=True):
end = time.time()
class_probs = []
class_preds = []
sample_saved = False # Track if --save-sample has been done for this validation step

# Get object detection params
obj_detection_params = parse_obj_detection_yaml.parse(args.obj_detection_params) \
Expand Down Expand Up @@ -1091,9 +1096,9 @@ def save_tensor(t, f, regression=True):
and model.__dict__['_modules'][key].wide):
output /= 256.

if args.generate_sample is not None:
if args.generate_sample is not None and args.act_mode_8bit and not sample_saved:
sample.generate(args.generate_sample, inputs, target, output, args.dataset, False)
return .0, .0, .0, .0
sample_saved = True

if args.csv_prefix is not None:
save_tensor(inputs, f_x)
Expand Down

0 comments on commit ec7b362

Please sign in to comment.