From 421e51fffe4195cdc111aabe49ed5c2985337db5 Mon Sep 17 00:00:00 2001 From: Seleucia Date: Tue, 17 Oct 2017 00:07:39 +0200 Subject: [PATCH 1/2] Bug fix Selected idx will come from the gt check --- src/translate.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/translate.py b/src/translate.py index 5bd1e91..58d850f 100644 --- a/src/translate.py +++ b/src/translate.py @@ -217,13 +217,14 @@ def train(): # (next 3 entries) are also not considered in the error, so the_key # are set to zero. # See https://github.com/asheshjain399/RNNexp/issues/6#issuecomment-249404882 - eulerchannels_pred[:,0:6] = 0 + gt_i=np.copy(srnn_gts_euler[action][i]) + gt_i[:,0:6] = 0 # Now compute the l2 error. The following is numpy port of the error # function provided by Ashesh Jain (in matlab), available at # https://github.com/asheshjain399/RNNexp/blob/srnn/structural_rnn/CRFProblems/H3.6m/dataParser/Utils/motionGenerationError.m#L40-L54 - idx_to_use = np.where( np.std( eulerchannels_pred, 0 ) > 1e-4 )[0] - + idx_to_use = np.where( np.std( gt_i, 0 ) > 1e-4 )[0] + euc_error = np.power( srnn_gts_euler[action][i][:,idx_to_use] - eulerchannels_pred[:,idx_to_use], 2) euc_error = np.sum(euc_error, 1) euc_error = np.sqrt( euc_error ) From 510170df1c00358d640b50036fbe3aee04fa5bcb Mon Sep 17 00:00:00 2001 From: Seleucia Date: Tue, 17 Oct 2017 00:09:41 +0200 Subject: [PATCH 2/2] Update translate.py --- src/translate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/translate.py b/src/translate.py index 58d850f..3e3e9ad 100644 --- a/src/translate.py +++ b/src/translate.py @@ -225,7 +225,7 @@ def train(): # https://github.com/asheshjain399/RNNexp/blob/srnn/structural_rnn/CRFProblems/H3.6m/dataParser/Utils/motionGenerationError.m#L40-L54 idx_to_use = np.where( np.std( gt_i, 0 ) > 1e-4 )[0] - euc_error = np.power( srnn_gts_euler[action][i][:,idx_to_use] - eulerchannels_pred[:,idx_to_use], 2) + euc_error = np.power( gt_i[:,idx_to_use] - eulerchannels_pred[:,idx_to_use], 2) euc_error = np.sum(euc_error, 1) euc_error = np.sqrt( euc_error ) mean_errors[i,:] = euc_error