Skip to content

Commit

Permalink
ready for experiment
Browse files Browse the repository at this point in the history
  • Loading branch information
NoahBarrett98 committed Mar 8, 2021
1 parent 7b51213 commit ac050cc
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 18 deletions.
15 changes: 13 additions & 2 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,36 @@ In this first and thrid video, you can see the robot learning how to interact wi
Below there are four videos please fill out your responses in the survey below each video. Thanks for participating!

## Algorithm 1 (Learning)
link youtube vid here ...

<iframe width="420" height="315"
src="https://www.youtube.com/embed/tgbNymZ7vqY">
</iframe>


<iframe src="https://docs.google.com/forms/d/e/1FAIpQLSfcW9IFqY0mSJinQuvx8mZV1KMFScGUH_wVxpi6KV4SIMtUVw/viewform?embedded=true" width="640" height="890" frameborder="0" marginheight="0" marginwidth="0">Loading…</iframe>

## Algorithm 1

<iframe width="420" height="315"
src="https://www.youtube.com/embed/tgbNymZ7vqY">
</iframe>

<iframe src="https://docs.google.com/forms/d/e/1FAIpQLSfeYxbYgFo_Dyy5n0Ap_MxWxOJbfXomKOAY5SuRk9BhM2fafA/viewform?embedded=true" width="640" height="390" frameborder="0" marginheight="0" marginwidth="0">Loading…</iframe>

## Algorithm 2 (Learning)

<iframe width="640" height="1057"
<iframe width="420" height="315"
src="https://www.youtube.com/embed/tgbNymZ7vqY">
</iframe>

<iframe src="https://docs.google.com/forms/d/e/1FAIpQLSf1ESyx6VQF0gWw-EU1O__RvmL7Gak7DjO1yrS67ZFGFokp4w/viewform?embedded=true" width="640" height="1057" frameborder="0" marginheight="0" marginwidth="0">Loading…</iframe>

## Algorithm 2

<iframe width="420" height="315"
src="https://www.youtube.com/embed/tgbNymZ7vqY">
</iframe>


<iframe src="https://docs.google.com/forms/d/e/1FAIpQLSfR9pqg4Wj6w175omNOWPtDBD2C81l8kkPNz6zcQwWXKTzvJA/viewform?embedded=true" width="640" height="1057" frameborder="0" marginheight="0" marginwidth="0">Loading…</iframe>

6 changes: 4 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
IS_PI = True
# from src.DRL.run_session import train_and_test_bot
from src.DRL.train import train
from src.DRL.test import test
elif hostname == settings.PC:
IS_PI = False
# from src.MainNode.Utils.OverHead import OverHead
# from src.MainNode.Utils.OverHead import OverHead
# from src.MainNode.Utils.stream_oh_inference import run
else:
raise Exception("Must configure this device in settings.py")
Expand All @@ -30,8 +31,9 @@ def main_train_and_test():
else:
pibot = PiBot2()
steps = 200
train.train_session(pibot, steps)
train.train_session(pibot, steps, continue_prompt=True)
# publish_data(actions)
test.run_policy(pibot, steps, continue_prompt=True)



Expand Down
20 changes: 9 additions & 11 deletions src/DRL/test/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,22 @@
with open(POLICYF, "r") as f:
POLICIES = json.load(f)

def run_policy(num_steps,pibot, policy_dict=POLICIES) -> str:
def run_policy(pibot,num_steps, policy_dict=POLICIES, continue_prompt=False) -> None:
# If the environment don't follow the interface, an error will be thrown
env = PiBotEnv2(pibot)
# record actions
env._RECORD_ACTION = True
check_env(env, warn=True)
# The algorithms require a vectorized environment to run
env = DummyVecEnv([partial(PiBotEnv2, PiBot=pibot)])
results = {"DEVICE_NAME":HOSTNAME,
"POLICIES":[]}
d_env = DummyVecEnv([partial(PiBotEnv2, PiBot=pibot)])
for k, v in policy_dict.items():
results["POLICIES"].append(k)
if "PPO2" in k:
model = PPO2.load(v)
if "A2C" in k:
model = A2C.load(v)
obs = env.reset()
results[k] = []
obs = d_env.reset()
for i in range(num_steps):
action, _states = model.predict(obs)
results[k].append(list(list(int(i) for i in a) for a in action))
obs, rewards, done, info = env.step(action)
print(results)
return json.dumps(results)
env._record_actions(k+"_INFERENCE")
if continue_prompt:
input("hit enter to continue: ")
8 changes: 5 additions & 3 deletions src/DRL/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,21 @@
#"ACER": partial(ACER, policy=MlpPolicy, verbose=1),
# "ACKTR": partial(ACKTR, policy=MlpPolicy, verbose=1),
# "DQN": partial(DQN, policy=MlpPolicy, verbose=1),
# "PPO2": partial(PPO2, policy=MlpPolicy, verbose=1),
"PPO2": partial(PPO2, policy=MlpPolicy, verbose=1),
"A2C": partial(A2C, policy=MlpPolicy, verbose=1),


}

def train_session(pibot, steps, train_dict=TRAIN_DICT):
def train_session(pibot, steps, train_dict=TRAIN_DICT, continue_prompt=False):
"""
wrap train dict with training function
"""
for model, func in train_dict.items():
print(f"TRAINING: {model}")
train(steps, pibot, func, model)
if continue_prompt:
input("hit enter to continue: ")

def train(steps, pibot, model, model_name) -> str:
# If the environment don't follow the interface, an error will be thrown
Expand All @@ -46,6 +48,6 @@ def train(steps, pibot, model, model_name) -> str:
policies[model_name] = fpath
with open(POLICYF, "w") as f:
json.dump(policies, f)
env._record_actions(model_name)
env._record_actions(model_name+"_TRAINING")

return fpath

0 comments on commit ac050cc

Please sign in to comment.