Skip to content

Commit

Permalink
Merge pull request #1160 from streamjoin/update-type-tensor
Browse files Browse the repository at this point in the history
Update the data types and tensor operations for training
  • Loading branch information
nudles authored Apr 25, 2024
2 parents ad4c0fa + 865ce0b commit d6f52ff
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions examples/cnn_ms/train_cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,11 +414,11 @@ def run(global_rank,
synflow_flag = True
### step 1: all one input
# Copy the patch data into input tensors
tx.copy_from_numpy(np.ones(x.shape))
tx.copy_from_numpy(np.ones(x.shape, dtype=np.float32))
ty.copy_from_numpy(y)
### step 2: all weights turned to positive (done)
### step 3: new loss (done)
pn_p_g_list, out, loss = model(tx, ty, synflow_flag, dist_option, spars)
pn_p_g_list, out, loss = model(tx, ty,dist_option, spars, synflow_flag)
### step 4: calculate the multiplication of weights
synflow_score = 0.0
for pn_p_g_item in pn_p_g_list:
Expand All @@ -430,13 +430,13 @@ def run(global_rank,
# Copy the patch data into input tensors
tx.copy_from_numpy(x)
ty.copy_from_numpy(y)
pn_p_g_list, out, loss = model(tx, ty, synflow_flag, dist_option, spars)
pn_p_g_list, out, loss = model(tx, ty, dist_option, spars, synflow_flag)
train_correct += accuracy(tensor.to_numpy(out), y)
train_loss += tensor.to_numpy(loss)[0]
# all params turned to positive
for pn_p_g_item in pn_p_g_list:
print ("absolute value parameter name: \n", pn_p_g_item[0])
pn_p_g_item[1].data = tensor.abs(pn_p_g_item[1].data)
pn_p_g_item[1] = tensor.abs(pn_p_g_item[1]) # tensor variables
else: # normal train steps
# Copy the patch data into input tensors
tx.copy_from_numpy(x)
Expand Down Expand Up @@ -491,7 +491,7 @@ def run(global_rank,
description='Training using the autograd and graph.')
parser.add_argument(
'model',
choices=['cnn', 'resnet', 'xceptionnet', 'mlp', 'alexnet'],
choices=['cnn', 'resnet', 'xceptionnet', 'mlp', 'msmlp', 'alexnet'],
default='cnn')
parser.add_argument('data',
choices=['mnist', 'cifar10', 'cifar100'],
Expand Down

0 comments on commit d6f52ff

Please sign in to comment.