Skip to content

Commit

Permalink
Runner script to sparsify reward models for PointMaze
Browse files Browse the repository at this point in the history
  • Loading branch information
AdamGleave committed Feb 3, 2020
1 parent 89dcc75 commit aab2933
Showing 1 changed file with 65 additions and 0 deletions.
65 changes: 65 additions & 0 deletions runners/sparsify_point_maze.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
#!/usr/bin/env bash
# Copyright 2020 Adam Gleave
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Script to sparsify pretrained reward models generated by `transfer_point_maze.sh`

# Resolve the directory that contains this script so common.sh is found no
# matter which working directory the script is launched from.
DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)"

# Fail early instead of continuing without the shared helpers: the rest of
# this script uses `fast`, `OUTPUT_ROOT` and `call_script`, none of which are
# defined in this file (presumably common.sh provides them -- verify there).
if [[ -z "${DIR}" || ! -r "${DIR}/common.sh" ]]; then
  echo "ERROR: cannot read common.sh next to ${BASH_SOURCE[0]}" >&2
  exit 1
fi

# The readability check above is used rather than `. file || exit` because the
# exit status of `.` is that of the last command in the sourced file, which
# may be non-zero without indicating a failure.
# shellcheck source=/dev/null
. "${DIR}/common.sh"

# Value passed as `env_name=` to the model_comparison runs further down.
ENV_TRAIN="imitation/PointMazeLeftVel-v0"
# First field of the colon-separated mixture policy path built further down.
TRANSITION_P="0.05"

# `fast` is not assigned in this file; it is expected to be set already
# (presumably by common.sh or the caller's environment -- confirm).
case "${fast}" in
  true)
    # Debugging configuration: adds the `fast` named config to the run and
    # reads/writes the *_fast output directories.
    COMPARISON_TIMESTEPS="fast"
    PM_OUTPUT="${OUTPUT_ROOT}/transfer_point_maze_fast"
    SPARSE_OUTPUT="${OUTPUT_ROOT}/sparse_point_maze_fast"
    ;;
  *)
    # Full configuration: no extra named config is passed.
    COMPARISON_TIMESTEPS=""
    # NOTE(review): EVAL_TIMESTEPS is not referenced anywhere else in this
    # script; kept in case a sourced helper reads it.
    EVAL_TIMESTEPS="100000"
    PM_OUTPUT="${OUTPUT_ROOT}/transfer_point_maze"
    SPARSE_OUTPUT="${OUTPUT_ROOT}/sparse_point_maze"
    ;;
esac

# Colon-separated policy spec used for the "mixture" dataset; it combines
# TRANSITION_P, a random-policy entry and the PPO2 expert checkpoint.
# NOTE(review): the exact meaning of each field is defined by whatever parses
# dataset_factory_kwargs.policy_path -- confirm there.
MIXED_POLICY_PATH="${TRANSITION_P}:random:dummy:ppo2:${PM_OUTPUT}/expert/train/policies/final"

# For each dataset variant, run model_comparison with target reward
# evaluating_rewards/Zero-v0 against every source reward model.
#
# GNU parallel input sources:
#   * `--header :` makes the first value of each source its column name, so
#     {seed}, {source_reward_type}, ... are replaced per job.
#   * `:::+` links a source to the previous one element by element, so
#     source_reward_type / source_reward_path / source_reward_suffix advance
#     together (6 combinations).
#   * the final `:::` crosses those combinations with the 3 seeds.
for name in comparison_expert comparison_mixture comparison_random; do
  # Dataset-specific overrides are kept in an array so each flag stays a
  # single argument without relying on unquoted word-splitting.
  case "${name}" in
    comparison_expert)
      extra_flags=(
        "dataset_factory_kwargs.policy_type=ppo2"
        "dataset_factory_kwargs.policy_path=${PM_OUTPUT}/expert/train/policies/final"
      )
      ;;
    comparison_mixture)
      extra_flags=(
        "dataset_factory_kwargs.policy_type=mixture"
        "dataset_factory_kwargs.policy_path=${MIXED_POLICY_PATH}"
      )
      ;;
    comparison_random)
      # No overrides: the run uses model_comparison's default dataset.
      extra_flags=()
      ;;
    *)
      # Defensive guard in case the list above and this dispatch get out of sync.
      echo "BUG: unknown name ${name}" >&2
      exit 1
      ;;
  esac

  # call_script is left unquoted on purpose, exactly as before, so its output
  # is word-split into the command prefix. COMPARISON_TIMESTEPS is only passed
  # when non-empty, so no empty argument is handed to parallel.
  # shellcheck disable=SC2046
  parallel --header : --results "${SPARSE_OUTPUT}/parallel/${name}" \
    $(call_script "model_comparison" "with") \
    "env_name=${ENV_TRAIN}" "${extra_flags[@]}" \
    ellp_loss no_rescale target_reward_type=evaluating_rewards/Zero-v0 \
    "seed={seed}" "source_reward_type={source_reward_type}" \
    "source_reward_path=${PM_OUTPUT}/reward/{source_reward_path}/{source_reward_suffix}" \
    ${COMPARISON_TIMESTEPS:+"${COMPARISON_TIMESTEPS}"} \
    "log_dir=${SPARSE_OUTPUT}/${name}/{source_reward_path}/{seed}" \
    ::: source_reward_type evaluating_rewards/PointMazeGroundTruthWithCtrl-v0 \
    evaluating_rewards/PointMazeGroundTruthNoCtrl-v0 \
    evaluating_rewards/RewardModel-v0 evaluating_rewards/RewardModel-v0 \
    imitation/RewardNet_unshaped-v0 imitation/RewardNet_unshaped-v0 \
    :::+ source_reward_path withctrl noctrl preferences regress irl_state_only irl_state_action \
    :::+ source_reward_suffix dummy dummy model model checkpoints/final/discrim/reward_net \
    checkpoints/final/discrim/reward_net \
    ::: seed 0 1 2
done

0 comments on commit aab2933

Please sign in to comment.