Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add customisable observation #22

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions python/pandemic_simulator/environment/done.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Dict, List, Type, Union, Any

import numpy as np
from pandemic_simulator.environment.interfaces.sim_state import PandemicSimState

__all__ = ['DoneFunctionType', 'DoneFunctionFactory', 'ORDone', 'DoneFunction',
'InfectionSummaryAboveThresholdDone', 'NoMoreInfectionsDone', 'NoPandemicDone']
Expand All @@ -20,6 +21,9 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
def calculate_done(self, obs: PandemicObservation, action: int) -> bool:
pass

def calculate_done_from_state(self, state: PandemicSimState, action: int) -> bool:
pass

def reset(self) -> None:
pass

Expand Down Expand Up @@ -73,6 +77,9 @@ def __init__(self, done_fns: List[DoneFunction], *args: Any, **kwargs: Any):
def calculate_done(self, obs: PandemicObservation, action: int) -> bool:
return any([df.calculate_done(obs, action) for df in self._done_fns])

def calculate_done_from_state(self, state: PandemicSimState, action: int) -> bool:
return any([df.calculate_done_from_state(state, action) for df in self._done_fns])

def reset(self) -> None:
for done_fn in self._done_fns:
done_fn.reset()
Expand All @@ -92,6 +99,9 @@ def __init__(self, summary_type: InfectionSummary, threshold: float, *args: Any,
def calculate_done(self, obs: PandemicObservation, action: int) -> bool:
return bool(np.any(obs.global_infection_summary[..., self._index] > self._threshold))

def calculate_done_from_state(self, state: PandemicSimState, action: int) -> bool:
return bool(np.any(list(state.global_infection_summary.values())[self._index] > self._threshold))


class NoMoreInfectionsDone(DoneFunction):
"""Returns True if the number of infected and critical becomes zero and all have recovered."""
Expand All @@ -114,6 +124,16 @@ def calculate_done(self, obs: PandemicObservation, action: int) -> bool:
self._cnt = 0
return False

def calculate_done_from_state(self, state: PandemicSimState, action: int) -> bool:
no_infection = (list(state.global_infection_summary.values())[self._infected_index]+ list(state.global_infection_summary.values())[self._critical_index] == 0)
if no_infection and self._cnt > 5:
return True
elif no_infection:
self._cnt += 1
else:
self._cnt = 0
return False

def reset(self) -> None:
self._cnt = 0

Expand All @@ -136,6 +156,10 @@ def calculate_done(self, obs: PandemicObservation, action: int) -> bool:
self._pandemic_exists = self._pandemic_exists or np.any(obs.infection_above_threshold)
return obs.time_day[-1].item() > self._num_days and not self._pandemic_exists

def calculate_done_from_state(self, state: PandemicSimState, action: int) -> bool:
self._pandemic_exists = self._pandemic_exists or np.any(state.infection_above_threshold)
return state.sim_time.day > self._num_days and not self._pandemic_exists


_register_done(DoneFunctionType.INFECTION_SUMMARY_ABOVE_THRESHOLD, InfectionSummaryAboveThresholdDone)
_register_done(DoneFunctionType.NO_MORE_INFECTIONS, NoMoreInfectionsDone)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from .location_states import NonEssentialBusinessLocationState
from .sim_state import PandemicSimState

__all__ = ['PandemicObservation']
__all__ = ['PandemicObservation','PanObsOptional']


@dataclass
Expand Down Expand Up @@ -79,3 +79,71 @@ def update_obs_with_sim_state(self, sim_state: PandemicSimState,
def infection_summary_labels(self) -> Sequence[str]:
"""Return the label for each index in global_infection(or testing)_summary observation entry"""
return [k.value for k in sorted_infection_summary]

@dataclass
class PanObsOptional(PandemicObservation):

use_gis: bool = False
use_gts: bool = True
use_stage: bool = True
use_infection_above_threshold: bool = False
#if crit flag accepted add it here
use_time_day: bool = False
use_unlocked_non_essential_business_locations: bool = False

@classmethod
def create_empty(cls: Type['PanObsOptional'],
history_size: int = 1,
num_non_essential_business: Optional[int] = None,
use_gis: bool = False,
use_gts: bool = True,
use_stage: bool = True,
use_infection_above_threshold: bool = False,
#if crit flag accepted add it here
use_time_day: bool = False,
use_unlocked_non_essential_business_locations: bool = False) -> 'PanObsOptional':

return PanObsOptional(global_infection_summary=np.zeros((history_size, 1, len(InfectionSummary))) if use_gis else None,
global_testing_summary=np.zeros((history_size, 1, len(InfectionSummary)))if use_gts else None,
stage=np.zeros((history_size, 1, 1))if use_stage else None,
infection_above_threshold=np.zeros((history_size, 1, 1))if use_infection_above_threshold else None,
time_day=np.zeros((history_size, 1, 1))if use_time_day else None,
unlocked_non_essential_business_locations=np.zeros((history_size, 1,
num_non_essential_business))
if use_unlocked_non_essential_business_locations else None)

def update_obs_with_sim_state(self, sim_state: PandemicSimState,
hist_index: int = 0,
business_location_ids: Optional[Sequence[LocationID]] = None) -> None:
"""
Update the PandemicObservation with the information from PandemicSimState.

:param sim_state: PandemicSimState instance
:param hist_index: history time index
:param business_location_ids: business location ids
"""
if self.use_unlocked_non_essential_business_locations:
if self.unlocked_non_essential_business_locations is not None and business_location_ids is not None:
unlocked_non_essential_business_locations = np.asarray([int(not cast(NonEssentialBusinessLocationState,
sim_state.id_to_location_state[
loc_id]).locked)
for loc_id in business_location_ids])
self.unlocked_non_essential_business_locations[hist_index, 0] = unlocked_non_essential_business_locations

if self.use_gis:
gis = np.asarray([sim_state.global_infection_summary[k] for k in sorted_infection_summary])[None, None, ...]
self.global_infection_summary[hist_index, 0] = gis

if self.use_gts:
gts = np.asarray([sim_state.global_testing_state.summary[k] for k in sorted_infection_summary])[None, None, ...]
self.global_testing_summary[hist_index, 0] = gts

if self.use_stage:
self.stage[hist_index, 0] = sim_state.regulation_stage

if self.use_infection_above_threshold:
self.infection_above_threshold[hist_index, 0] = int(sim_state.infection_above_threshold)

if self.use_time_day:
self.time_day[hist_index, 0] = int(sim_state.sim_time.day)

214 changes: 212 additions & 2 deletions python/pandemic_simulator/environment/pandemic_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,18 @@
from typing import List, Optional, Dict, Tuple, Mapping, Type, Sequence

import gym
from pandemic_simulator.environment.interfaces.pandemic_observation import PanObsOptional
from pandemic_simulator.environment.interfaces.sim_state import PandemicSimState

from .done import DoneFunction
from .done import DoneFunction, InfectionSummaryAboveThresholdDone, NoMoreInfectionsDone, NoPandemicDone, ORDone
from .interfaces import LocationID, PandemicObservation, NonEssentialBusinessLocationState, PandemicRegulation, \
InfectionSummary
from .pandemic_sim import PandemicSim
from .reward import RewardFunction, SumReward, RewardFunctionFactory, RewardFunctionType
from .simulator_config import PandemicSimConfig
from .simulator_opts import PandemicSimOpts

__all__ = ['PandemicGymEnv']
__all__ = ['PandemicGymEnv','PandemicGymEnv3Act','PandemicGymEnvCusObs']


class PandemicGymEnv(gym.Env):
Expand Down Expand Up @@ -169,3 +171,211 @@ def reset(self) -> PandemicObservation:

def render(self, mode: str = 'human') -> bool:
pass

class PandemicGymEnv3Act(gym.ActionWrapper):
def __init__(self, env: PandemicGymEnv):
super().__init__(env)
self.env = env

self.action_space = gym.spaces.Discrete(3, start=-1)

@classmethod
def from_config(self,
sim_config: PandemicSimConfig,
pandemic_regulations: Sequence[PandemicRegulation],
sim_opts: PandemicSimOpts = PandemicSimOpts(),
reward_fn: Optional[RewardFunction] = None,
done_fn: Optional[DoneFunction] = None,
) -> 'PandemicGymEnv3Act':
env = PandemicGymEnv.from_config(sim_config = sim_config,
pandemic_regulations=pandemic_regulations,
sim_opts = sim_opts,
reward_fn=reward_fn,
done_fn=done_fn,
)

return PandemicGymEnv3Act(env=env)

def step(self, action):
return self.env.step(int(self.action(action)))

def action(self, action):
assert self.action_space.contains(action), "%r (%s) invalid" % (action, type(action))
return min(4, max(0, self.env._last_observation.stage[-1, 0, 0] + action))

def reset(self):
self.env.reset()

class PandemicGymEnvCusObs(PandemicGymEnv):
"""A gym environment interface wrapper for the Pandemic Simulator."""

_pandemic_sim: PandemicSim
_stage_to_regulation: Mapping[int, PandemicRegulation]
_obs_history_size: int
_sim_steps_per_regulation: int
_non_essential_business_loc_ids: Optional[List[LocationID]]
_reward_fn: Optional[RewardFunction]
_done_fn: Optional[DoneFunction]

_last_observation: PanObsOptional
_last_state: PandemicSimState
_last_reward: float

def __init__(self,
pandemic_sim: PandemicSim,
pandemic_regulations: Sequence[PandemicRegulation],
reward_fn: Optional[RewardFunction] = None,
done_fn: Optional[DoneFunction] = None,
obs_history_size: int = 1,
sim_steps_per_regulation: int = 24,
non_essential_business_location_ids: Optional[List[LocationID]] = None,
):
"""
:param pandemic_sim: Pandemic simulator instance
:param pandemic_regulations: A sequence of pandemic regulations
:param reward_fn: reward function
:param done_fn: done function
:param obs_history_size: number of latest sim step states to include in the observation
:param sim_steps_per_regulation: number of sim_steps to run for each regulation
:param non_essential_business_location_ids: an ordered list of non-essential business location ids
"""
self._pandemic_sim = pandemic_sim
self._stage_to_regulation = {reg.stage: reg for reg in pandemic_regulations}
self._obs_history_size = obs_history_size
self._sim_steps_per_regulation = sim_steps_per_regulation

if non_essential_business_location_ids is not None:
for loc_id in non_essential_business_location_ids:
assert isinstance(self._pandemic_sim.state.id_to_location_state[loc_id],
NonEssentialBusinessLocationState)
self._non_essential_business_loc_ids = non_essential_business_location_ids

self._reward_fn = reward_fn
self._done_fn = ORDone([InfectionSummaryAboveThresholdDone(summary_type=InfectionSummary.INFECTED,threshold=4),NoMoreInfectionsDone(),NoPandemicDone(num_days=3)])

self.action_space = gym.spaces.Discrete(len(self._stage_to_regulation))

self._last_observation = PanObsOptional.create_empty()

@classmethod
def from_config(cls: Type['PandemicGymEnv'],
sim_config: PandemicSimConfig,
pandemic_regulations: Sequence[PandemicRegulation],
sim_opts: PandemicSimOpts = PandemicSimOpts(),
reward_fn: Optional[RewardFunction] = None,
done_fn: Optional[DoneFunction] = None,
obs_history_size: int = 1,
non_essential_business_location_ids: Optional[List[LocationID]] = None,
) -> 'PandemicGymEnvCusObs':
"""
Creates an instance using config

:param sim_config: Simulator config
:param pandemic_regulations: A sequence of pandemic regulations
:param sim_opts: Simulator opts
:param reward_fn: reward function
:param done_fn: done function
:param obs_history_size: number of latest sim step states to include in the observation
:param non_essential_business_location_ids: an ordered list of non-essential business location ids
"""
sim = PandemicSim.from_config(sim_config, sim_opts)

if sim_config.max_hospital_capacity == -1:
raise Exception("Nothing much to optimise if max hospital capacity is -1.")

reward_fn = reward_fn or SumReward(
reward_fns=[
RewardFunctionFactory.default(RewardFunctionType.INFECTION_SUMMARY_ABOVE_THRESHOLD,
summary_type=InfectionSummary.CRITICAL,
threshold=sim_config.max_hospital_capacity),
RewardFunctionFactory.default(RewardFunctionType.INFECTION_SUMMARY_ABOVE_THRESHOLD,
summary_type=InfectionSummary.CRITICAL,
threshold=3 * sim_config.max_hospital_capacity),
RewardFunctionFactory.default(RewardFunctionType.LOWER_STAGE,
num_stages=len(pandemic_regulations)),
RewardFunctionFactory.default(RewardFunctionType.SMOOTH_STAGE_CHANGES,
num_stages=len(pandemic_regulations)),
RewardFunctionFactory.default(RewardFunctionType.UNLOCKED_BUSINESS_LOCATIONS,
num_stages=len(pandemic_regulations)),
RewardFunctionFactory.default(RewardFunctionType.INFECTION_SUMMARY_ABSOLUTE,summary_type=InfectionSummary.CRITICAL,
num_stages=len(pandemic_regulations)),
RewardFunctionFactory.default(RewardFunctionType.INFECTION_SUMMARY_INCREASE,summary_type=InfectionSummary.CRITICAL,
num_stages=len(pandemic_regulations)),
RewardFunctionFactory.default(RewardFunctionType.LOWER_STAGE,
num_stages=len(pandemic_regulations)),

],
weights=[.4, 1, .1, 0.02, 1, 1 ,1 ,1]
)

return PandemicGymEnvCusObs(pandemic_sim=sim,
pandemic_regulations=pandemic_regulations,
sim_steps_per_regulation=sim_opts.sim_steps_per_regulation,
reward_fn=reward_fn,
done_fn=ORDone([InfectionSummaryAboveThresholdDone(summary_type=InfectionSummary.INFECTED,threshold=10),NoMoreInfectionsDone(),NoPandemicDone(num_days=5)]),
obs_history_size=obs_history_size,
non_essential_business_location_ids=non_essential_business_location_ids)

@property
def observation(self) -> PanObsOptional:
return self._last_observation

def step(self, action: int) -> Tuple[PanObsOptional, float, bool, Dict]:
assert self.action_space.contains(action), "%r (%s) invalid" % (action, type(action))

# execute the action if different from the current stage
if action != self._pandemic_sim.state.regulation_stage: # stage has a TNC layout
regulation = self._stage_to_regulation[action]
self._pandemic_sim.impose_regulation(regulation=regulation)

# update the sim until next regulation interval trigger and construct obs from state hist
obs = PanObsOptional.create_empty(
use_gis = self._last_observation.use_gis,
use_gts = self._last_observation.use_gts,
use_stage = self._last_observation.use_stage,
use_infection_above_threshold = self._last_observation.use_infection_above_threshold,
#if crit flag accepted add it here
use_time_day = self._last_observation.use_time_day,
use_unlocked_non_essential_business_locations = self._last_observation.use_unlocked_non_essential_business_locations,
history_size=self._obs_history_size,
num_non_essential_business=len(self._non_essential_business_loc_ids)
if self._non_essential_business_loc_ids is not None else None)

hist_index = 0
for i in range(self._sim_steps_per_regulation):
# step sim
self._pandemic_sim.step()

# store only the last self._history_size state values
if i >= (self._sim_steps_per_regulation - self._obs_history_size):
obs.update_obs_with_sim_state(self._pandemic_sim.state, hist_index,
self._non_essential_business_loc_ids)
hist_index += 1

prev_state = self._last_state
self._last_reward = self._reward_fn.calculate_reward_from_state(prev_state, action, self._pandemic_sim.state, self._non_essential_business_loc_ids) if self._reward_fn else 0.
done = self._done_fn.calculate_done_from_state(self._pandemic_sim.state, action) if self._done_fn else False
self._last_observation = obs
self._last_state = self._pandemic_sim.state

return self._last_observation, self._last_reward, done, {}

def reset(self) -> PanObsOptional:
self._pandemic_sim.reset()
self._last_state = self._pandemic_sim.state
self._last_observation = PanObsOptional.create_empty(
use_gis = self._last_observation.use_gis,
use_gts = self._last_observation.use_gts,
use_stage = self._last_observation.use_stage,
use_infection_above_threshold = self._last_observation.use_infection_above_threshold,
#if crit flag accepted add it here
use_time_day = self._last_observation.use_time_day,
use_unlocked_non_essential_business_locations = self._last_observation.use_unlocked_non_essential_business_locations,
history_size=self._obs_history_size,
num_non_essential_business=len(self._non_essential_business_loc_ids)
if self._non_essential_business_loc_ids is not None else None)
self._last_reward = 0.0
if self._done_fn is not None:
self._done_fn.reset()
return self._last_observation

Loading