diff --git a/src/translate.py b/src/translate.py index 5bd1e91..3e3e9ad 100644 --- a/src/translate.py +++ b/src/translate.py @@ -217,14 +217,15 @@ def train(): # (next 3 entries) are also not considered in the error, so the_key # are set to zero. # See https://github.com/asheshjain399/RNNexp/issues/6#issuecomment-249404882 - eulerchannels_pred[:,0:6] = 0 + gt_i=np.copy(srnn_gts_euler[action][i]) + gt_i[:,0:6] = 0 # Now compute the l2 error. The following is numpy port of the error # function provided by Ashesh Jain (in matlab), available at # https://github.com/asheshjain399/RNNexp/blob/srnn/structural_rnn/CRFProblems/H3.6m/dataParser/Utils/motionGenerationError.m#L40-L54 - idx_to_use = np.where( np.std( eulerchannels_pred, 0 ) > 1e-4 )[0] - - euc_error = np.power( srnn_gts_euler[action][i][:,idx_to_use] - eulerchannels_pred[:,idx_to_use], 2) + idx_to_use = np.where( np.std( gt_i, 0 ) > 1e-4 )[0] + + euc_error = np.power( gt_i[:,idx_to_use] - eulerchannels_pred[:,idx_to_use], 2) euc_error = np.sum(euc_error, 1) euc_error = np.sqrt( euc_error ) mean_errors[i,:] = euc_error