Skip to content

Commit

Permalink
Merge pull request #17 from Seleucia/master
Browse files Browse the repository at this point in the history
Fix #9 - the indices used to evaluate error are taken from ground truth.
  • Loading branch information
una-dinosauria authored Oct 20, 2017
2 parents 1b39a73 + 510170d commit c9a2774
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions src/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,14 +217,15 @@ def train():
# (next 3 entries) are also not considered in the error, so the_key
# are set to zero.
# See https://github.com/asheshjain399/RNNexp/issues/6#issuecomment-249404882
eulerchannels_pred[:,0:6] = 0
gt_i=np.copy(srnn_gts_euler[action][i])
gt_i[:,0:6] = 0

# Now compute the l2 error. The following is numpy port of the error
# function provided by Ashesh Jain (in matlab), available at
# https://github.com/asheshjain399/RNNexp/blob/srnn/structural_rnn/CRFProblems/H3.6m/dataParser/Utils/motionGenerationError.m#L40-L54
idx_to_use = np.where( np.std( eulerchannels_pred, 0 ) > 1e-4 )[0]

euc_error = np.power( srnn_gts_euler[action][i][:,idx_to_use] - eulerchannels_pred[:,idx_to_use], 2)
idx_to_use = np.where( np.std( gt_i, 0 ) > 1e-4 )[0]
euc_error = np.power( gt_i[:,idx_to_use] - eulerchannels_pred[:,idx_to_use], 2)
euc_error = np.sum(euc_error, 1)
euc_error = np.sqrt( euc_error )
mean_errors[i,:] = euc_error
Expand Down

0 comments on commit c9a2774

Please sign in to comment.