diff --git a/compiler_opt/tools/combine_tfa_policies_lib_test.py b/compiler_opt/tools/combine_tfa_policies_lib_test.py index e0f46f2e..89a74eef 100644 --- a/compiler_opt/tools/combine_tfa_policies_lib_test.py +++ b/compiler_opt/tools/combine_tfa_policies_lib_test.py @@ -30,9 +30,7 @@ class AddOnePolicy(tf_agents.policies.TFPolicy): """Test policy which adds one to obs feature.""" def __init__(self): - obs_spec = { - 'obs': tensor_spec.TensorSpec(shape=(1,), dtype=tf.int64) - } + obs_spec = {'obs': tensor_spec.TensorSpec(shape=(1,), dtype=tf.int64)} time_step_spec = ts.time_step_spec(obs_spec) act_spec = tensor_spec.TensorSpec(shape=(1,), dtype=tf.int64) @@ -56,9 +54,7 @@ class SubtractOnePolicy(tf_agents.policies.TFPolicy): """Test policy which subtracts one to obs feature.""" def __init__(self): - obs_spec = { - 'obs': tensor_spec.TensorSpec(shape=(1,), dtype=tf.int64) - } + obs_spec = {'obs': tensor_spec.TensorSpec(shape=(1,), dtype=tf.int64)} time_step_spec = ts.time_step_spec(obs_spec) act_spec = tensor_spec.TensorSpec(shape=(1,), dtype=tf.int64)