Q-learn works?
roznawsk committed Aug 20, 2024
1 parent 3e74d96 commit 6af0e34
Showing 7 changed files with 77 additions and 50 deletions.
16 changes: 8 additions & 8 deletions infant_abm/agents/infant/actions.py
@@ -16,13 +16,13 @@ class InteractWithToy(Action):
number = 3


# class EvaluateToy(Action):
# def __init__(self, duration=0, metadata=None):
# super().__init__(metadata)
# self.duration = duration
class EvaluateToy(Action):
def __init__(self, duration=0, metadata=None):
super().__init__(metadata)
self.duration = duration


# class EvaluateThrow(Action):
# def __init__(self, duration=0, metadata=None):
# super().__init__(metadata)
# self.duration = duration
class EvaluateThrow(Action):
def __init__(self, duration=0, metadata=None):
super().__init__(metadata)
self.duration = duration
61 changes: 34 additions & 27 deletions infant_abm/agents/infant/q_learning_agent.py
@@ -4,52 +4,71 @@

# [infant_looked_at_toy, parent_looked_at_toy, mutual_gaze]
STATE_SPACE = np.array([2, 2, 2])
STATE_SPACE_SIZE = np.mul.reduce(STATE_SPACE)
STATE_SPACE_SIZE = np.multiply.reduce(STATE_SPACE)
GOAL_STATE = np.array([1, 1, 1])


class QLearningAgent:
def __init__(self, model, actions, alpha=0.1, gamma=0.9, epsilon=0.1):
self.model = model
self.q_table = np.zeros((STATE_SPACE_SIZE, len(actions)))
self.q_table = np.random.rand(STATE_SPACE_SIZE, len(actions))

print(f"init\n\n{self.q_table}")

self.alpha = alpha
self.gamma = gamma
self.epsilon = epsilon
self.actions = {action: num for num, action in enumerate(actions)}
self.number_actions = {num: action for action, num in self.actions.items()}

def choose_action(self, infant):
state = self._get_state(infant)
def choose_action(self):
state = self.get_state()

if np.random.rand() < self.epsilon:
return np.random.choice(self.actions.keys()) # Explore
return np.random.choice(list(self.actions.keys())) # Explore
else:
# print(
# f"argmax {np.argmax(self.q_table[state])}\n num_actions {self.number_actions}"
# )
return self.number_actions[np.argmax(self.q_table[state])] # Exploit

def update_q_table(self, state, action, reward, next_state):
state = self._to_number_state(state)
next_state = QLearningAgent._to_number_state(next_state)

action = self.actions[action]
best_next_action = np.max(self.q_table[next_state])
self.q_table[state, action] += self.alpha * (

# print(
# f"state {state}, action {action}, reward {reward}, best_next {best_next_action}, next_state {next_state}"
# )
# print(f"{self.q_table[state, action]}\n")

plus = self.alpha * (
reward + self.gamma * best_next_action - self.q_table[state, action]
)

# print(plus)

self.q_table[state, action] += plus

def get_state(self):
# [infant_looked_at_toy, parent_looked_at_toy, mutual_gaze]

return (
int(self._infant_looked_at_toy()),
int(self._parent_looked_at_toy_after_infant()),
int(),
raw_state = np.array(
[
int(self._infant_looked_at_toy()),
int(self._parent_looked_at_toy_after_infant()),
int(self._mutual_gaze()),
]
)

multiplier = np.array([4, 2, 1])
return np.sum(raw_state * multiplier)

def _infant_looked_at_toy(self):
return any([isinstance(obj, Toy) for obj in self.model.infant.gaze_directions])

def _parent_looked_at_toy_after_infant(self):
for i, obj in enumerate(self.model.infant.gaze_directions):
if obj in self.model.parent.gaze_directions[i:]:
if isinstance(obj, Toy) and obj in self.model.parent.gaze_directions[i:]:
return True

return False
@@ -61,19 +80,7 @@ def _mutual_gaze(self):
)

def reward(self, state):
if state == GOAL_STATE:
if np.all(state == GOAL_STATE):
return 1
else:
return 0

@staticmethod
def _to_number_state(state):
(previous_action, infant_gaze, parent_gaze) = state
return parent_gaze + infant_gaze * 6 + previous_action * 6 * 6

@staticmethod
def _get_agent_gaze_direction(agent):
if agent.gaze_direction is None:
return 0
else:
return min(agent.gaze_direction.unique_id, 5)
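
For reference, the new get_state packs the three binary observations into a single Q-table row index using place values [4, 2, 1], so the mutual-engagement state [1, 1, 1] encodes to index 7, and update_q_table applies the standard tabular rule Q(s, a) += alpha * (reward + gamma * max_a' Q(s', a') - Q(s, a)). Below is a minimal standalone sketch of those two pieces plus the epsilon-greedy choice; class and variable names are illustrative and the action count is arbitrary rather than the repo's gaze-target list.

import numpy as np

# Standalone sketch, not the repo class: pack three binary observations into a
# row index and apply the tabular Q-learning update shown in the diff above.
STATE_SPACE = np.array([2, 2, 2])                    # [infant_looked, parent_looked, mutual_gaze]
STATE_SPACE_SIZE = np.multiply.reduce(STATE_SPACE)   # 2 * 2 * 2 = 8
WEIGHTS = np.array([4, 2, 1])                        # base-2 place values

def encode_state(infant_looked, parent_looked, mutual_gaze):
    raw = np.array([int(infant_looked), int(parent_looked), int(mutual_gaze)])
    return int(np.sum(raw * WEIGHTS))                # (1, 1, 1) -> 4 + 2 + 1 = 7

class TinyQLearner:
    def __init__(self, n_actions, alpha=0.1, gamma=0.9, epsilon=0.1):
        self.q_table = np.random.rand(STATE_SPACE_SIZE, n_actions)
        self.alpha, self.gamma, self.epsilon = alpha, gamma, epsilon

    def choose_action(self, state):
        if np.random.rand() < self.epsilon:
            return np.random.randint(self.q_table.shape[1])   # explore
        return int(np.argmax(self.q_table[state]))            # exploit

    def update(self, state, action, reward, next_state):
        best_next = np.max(self.q_table[next_state])
        self.q_table[state, action] += self.alpha * (
            reward + self.gamma * best_next - self.q_table[state, action]
        )

learner = TinyQLearner(n_actions=4)
learner.update(encode_state(1, 0, 0), action=2, reward=1.0, next_state=encode_state(1, 1, 1))

One detail worth double-checking with this encoding: once get_state returns the packed integer, reward still compares it against the raw GOAL_STATE array, so the goal test likely wants the encoded index (7 under these weights) instead.
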
33 changes: 24 additions & 9 deletions infant_abm/agents/infant/seq_vision_infant.py
@@ -15,7 +15,7 @@ class SeqVisionInfant(InfantBase):
# TOY_EVALUATION_DURATION = 3
# THROW_EVALUATION_DURATION = 20

# PERSISTENCE_BOOST_DURATION = 20
PERSISTENCE_BOOST_DURATION = 20

# COORDINATION_BOOST_VALUE = 0.2
# PERSISTENCE_BOOST_VALUE = 0.2
@@ -35,22 +35,36 @@ def __init__(self, unique_id, model, pos, params: Params):
self.current_persistence_boost_duration = 0
self.q_learning_state = None

def get_actions(self):
return [actions.LookForToy, actions.Crawl, actions.InteractWithToy]
def get_q_actions(self):
return [None, self.model.parent] + self.model.get_toys()

def before_step(self):
self.q_learning_state = self.model.q_learning_agent.get_state(self)
def _before_step(self):
self.q_learning_state = self.model.q_learning_agent.get_state()

self.gaze_directions.append(self.model.q_learning_agent.choose_action(self))
new_action = self.model.q_learning_agent.choose_action()

self.gaze_directions.append(new_action)
self.gaze_directions.pop(0)

def after_step(self):
next_state = self.model.q_learning_agent.get_state(self)
next_state = self.model.q_learning_agent.get_state()
reward = self.model.q_learning_agent.reward(next_state)
self.model.q_learning_agent.update_q_table(
self.q_learning_state, self.gaze_direction, reward, next_state
self.q_learning_state, self.gaze_directions[-1], reward, next_state
)

# print(f"{self.gaze_directions[-2:]}, {self.model.parent.gaze_directions[-2:]}")

if np.random.rand() < 0.005:
# print(next_state)
# print(self.model.q_learning_agent.q_table)
print(
{
state: np.argmax(self.model.q_learning_agent.q_table[state])
for state in range(8)
}
)

def _step_look_for_toy(self, _action):
self.current_persistence_boost_duration = 0
self.params.persistence.reset()
@@ -117,7 +131,8 @@ def _step_interact_with_toy(self, _action):
def _step_crawl(self, _action):
if self._target_in_range():
self._start_evaluating_throw()
return actions.EvaluateThrow()
# return actions.EvaluateThrow()
return actions.InteractWithToy()

if self._gets_distracted():
self.target = None
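
The two hooks above fix the order of one learning step: _before_step snapshots the encoded state and appends a freshly chosen gaze target, the scheduled agents then act, and after_step re-reads the state, scores it, and updates the table with the most recent gaze action. A schematic of that ordering with stand-in interfaces (agent is a Q-learner like the sketch earlier; world.step() stands in for the rest of the Mesa schedule):

def learning_tick(agent, world, gaze_directions):
    # _before_step: snapshot the state and pick the infant's next gaze target.
    state = agent.get_state()
    gaze_directions.append(agent.choose_action())
    gaze_directions.pop(0)

    # The rest of the schedule runs: the parent reacts, gaze histories shift.
    world.step()

    # after_step: observe the outcome, score it, and update the table using
    # the gaze target that was just acted on.
    next_state = agent.get_state()
    reward = agent.reward(next_state)
    agent.update_q_table(state, gaze_directions[-1], reward, next_state)
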
2 changes: 1 addition & 1 deletion infant_abm/agents/parent/vision_only_parent.py
@@ -53,7 +53,7 @@ def _random_gaze_direction(self):
case 1:
return self.model.infant
case 2:
if target is not None and 0.5 > np.random.rand:
if target is not None and 0.5 > np.random.rand():
return target
else:
toys = self.model.get_toys()
2 changes: 2 additions & 0 deletions infant_abm/agents/parent_base.py
@@ -37,6 +37,8 @@ def __init__(self, unique_id, model, pos):
def step(self):
self.satisfaction = 0

self._before_step()

match self.next_action:
case Action.WAIT:
pass
7 changes: 4 additions & 3 deletions infant_abm/model.py
@@ -83,13 +83,12 @@ def __init__(
self.make_agents(infant_params)

self.q_learning_agent = QLearningAgent(
model=self, actions=self.infant.get_actions()
model=self, actions=self.infant.get_q_actions()
)
self.infant.q_learning_state = self.q_learning_agent.get_state()

self.datacollector = mesa.DataCollector(
model_reporters={
"parent-visible": lambda m: int(getattr(m.infant, "parent_visible", 0)),
"infant-visible": lambda m: int(m.parent.infant_visible) / 2,
"heading": lambda m: m.infant.params.persistence.e2,
"throwing": lambda m: m.infant.params.coordination.e2,
"goal_dist": lambda m: m.get_middle_dist(),
@@ -115,6 +114,8 @@ def step(self):

self.datacollector.collect(self)

self.infant.after_step()

def get_middle_dist(self) -> float:
middle_point = (self.parent.pos + self.infant.pos) / 2

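
Note that with get_q_actions() the Q-table's columns are gaze targets rather than motor actions: None, the parent, and each toy, translated to and from column indices by the agent's actions / number_actions dicts. A small illustration of that mapping with placeholder objects standing in for the Mesa agent instances:

# Placeholder gaze targets standing in for [None, self.model.parent] + toys.
targets = [None, "parent", "toy_0", "toy_1"]

action_to_col = {t: i for i, t in enumerate(targets)}     # mirrors self.actions
col_to_action = {i: t for t, i in action_to_col.items()}  # mirrors self.number_actions

# Exploitation maps the best column for a state back to a gaze target:
#     col_to_action[np.argmax(q_table[state])]
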
6 changes: 4 additions & 2 deletions run_simulation.py
@@ -109,11 +109,13 @@ def run_comparative_boost_simulation():
if __name__ == "__main__":
output_dir = "./results/model2.0/q-learn"

params = [{"infant_params": InfantParams.from_array(0.5, 0.5, 0.5)}]
params = [
{"infant_params": InfantParams.from_array([0.5, 0.5, 0.5]), "config": Config()}
]

run_basic_simulation(
output_dir=output_dir,
parameter_sets=params,
iterations=10000,
iterations=5000,
repeats=1,
)
