better autobatching style
redpony committed Nov 7, 2017
1 parent c57f9da commit fc5ffc2
Showing 1 changed file with 19 additions and 6 deletions.

examples/mnist/mnist-autobatch.py
@@ -9,6 +9,7 @@
# To run this, download the four files from http://yann.lecun.com/exdb/mnist/
# and gunzip them into a single path. Pass this path to the program with the
# --path option. You will also want to run with --dynet_autobatch 1.
# To turn on GPU training, run with --dynet-gpus 1.

parser = argparse.ArgumentParser()
parser.add_argument("--path", default=".",
@@ -18,6 +19,8 @@
parser.add_argument("--conv", dest="conv", action="store_true")
parser.add_argument("--dynet_autobatch", default=0,
help="Set to 1 to turn on autobatching.")
parser.add_argument("--dynet-gpus", default=0,
help="Set to 1 to train on GPU.")

HIDDEN_DIM = 1024
DROPOUT_RATE = 0.4
@@ -118,13 +121,18 @@ def __call__(self, x, dropout=False):
loss = dy.pickneglogsoftmax(logits, lbl)
losses.append(loss)
mbloss = dy.esum(losses) / mbsize
mbloss.backward()
sgd.update()

# eloss is an exponentially smoothed loss.
if eloss is None:
eloss = mbloss.scalar_value()
else:
eloss = mbloss.scalar_value() * alpha + eloss * (1.0 - alpha)
mbloss.backward()
sgd.update()

# Do dev evaluation here:
if (i > 0) and (i % dev_report == 0):
confusion = [[0 for _ in xrange(10)] for _ in xrange(10)]
correct = 0
dev_start = time.time()
for s in range(0, len(testing), args.minibatch_size):
@@ -137,21 +145,26 @@ def __call__(self, x, dropout=False):
x = dy.inputVector(img)
logits = classify(x)
scores.append((lbl, logits))
# we want to evaluate all the logits in a batch; this is a hack
# to do that.
dummy = dy.esum([logits for _, logits in scores])
dummy.forward()

# This evaluates all the logits in a batch if autobatching is on.
dy.forward([logits for _, logits in scores])

# now we can retrieve the batch-computed logits cheaply
for lbl, logits in scores:
prediction = np.argmax(logits.npvalue())
if lbl == prediction:
correct += 1
confusion[prediction][lbl] += 1
dev_end = time.time()
acc = float(correct) / len(testing)
dev_time += dev_end - dev_start
print("Held out accuracy {} ({} instances/sec)".format(
acc, len(testing) / (dev_end - dev_start)))
print('   ' + ''.join(('T'+str(x)).ljust(6) for x in xrange(10)))
for p, row in enumerate(confusion):
s = 'P' + str(p) + ' '
s += ''.join(str(col).ljust(6) for col in row)
print(s)

if (i > 0) and (i % report == 0):
print("moving avg loss: {}".format(eloss))
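
As a usage example for the run instructions in the header comment (the data directory name ./mnist-data is only an illustration; point --path wherever the gunzipped MNIST files live):

    python examples/mnist/mnist-autobatch.py --path ./mnist-data --dynet_autobatch 1 --dynet-gpus 1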

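For context, a minimal sketch (not part of this commit) of one autobatched training step in the style of the example's inner loop; classify, trainer, alpha, and the (image, label) minibatch are assumed to match the script.

    import dynet as dy

    def train_step(classify, trainer, minibatch, eloss, alpha):
        dy.renew_cg()
        losses = []
        for img, lbl in minibatch:
            # dropout=True during training is an assumption here.
            logits = classify(dy.inputVector(img), dropout=True)
            losses.append(dy.pickneglogsoftmax(logits, lbl))
        # Summing the per-instance losses gives the autobatcher one graph to
        # merge; scalar_value() triggers the (batched) forward pass.
        mbloss = dy.esum(losses) / len(minibatch)
        value = mbloss.scalar_value()
        # Exponentially smoothed loss, as in the script.
        eloss = value if eloss is None else value * alpha + eloss * (1.0 - alpha)
        mbloss.backward()
        trainer.update()
        return eloss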
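
And a minimal sketch of the evaluation pattern the diff switches to: build one logit expression per test image, then evaluate them all with a single dy.forward call so the autobatcher can batch the per-image graphs, with no dummy esum expression needed; classify is again assumed to be the script's classifier object.

    import numpy as np
    import dynet as dy

    def predict_batch(classify, images):
        dy.renew_cg()
        # Unevaluated logit expressions, one per image.
        logits = [classify(dy.inputVector(img)) for img in images]
        # One forward call over the whole list; with --dynet_autobatch 1
        # the per-image graphs are executed as a batch.
        dy.forward(logits)
        # Values are now cached, so npvalue() is cheap.
        return [int(np.argmax(l.npvalue())) for l in logits]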