forked from IntelLabs/coach
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtruncated_normal.py
115 lines (98 loc) · 5.35 KB
/
truncated_normal.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import List
import numpy as np
from scipy.stats import truncnorm
from rl_coach.core_types import RunPhase, ActionType
from rl_coach.exploration_policies.exploration_policy import ExplorationParameters, ContinuousActionExplorationPolicy
from rl_coach.schedules import Schedule, LinearSchedule
from rl_coach.spaces import ActionSpace, BoxActionSpace
class TruncatedNormalParameters(ExplorationParameters):
    """
    Parameter preset for the TruncatedNormal exploration policy.

    The attributes set here carry the same names as the keyword arguments accepted by
    TruncatedNormal.__init__ (apart from the action space).
    """

    def __init__(self):
        super().__init__()

        # noise used outside of evaluation; both value arguments are 0.1
        # (presumably initial value, final value and number of decay steps - confirm against LinearSchedule)
        self.noise_schedule = LinearSchedule(0.1, 0.1, 50000)
        # noise used while evaluating
        self.evaluation_noise = 0.05

        # lower / upper truncation bounds of the sampling distribution
        self.clip_low = 0
        self.clip_high = 1

        # when True, the noise is multiplied by the action space range (high - low)
        self.noise_as_percentage_from_action_space = True

    @property
    def path(self):
        """
        :return: the '<module>:<class>' locator string of the exploration policy these parameters belong to
        """
        return 'rl_coach.exploration_policies.truncated_normal:TruncatedNormal'
class TruncatedNormal(ContinuousActionExplorationPolicy):
    """
    The TruncatedNormal exploration policy is intended for continuous action spaces. It samples the action from a
    normal distribution truncated to the range [clip_low, clip_high], where the mean action is given by the agent,
    and the standard deviation can be given in two different ways:
    1. Specified by the user as a noise schedule, which is optionally scaled by the action space range
    2. Specified by the agent's action. In case the agent's action is a list with 2 values, the 1st one is assumed
       to be the mean of the action, and the 2nd is assumed to be its standard deviation.
    The sample is drawn directly from the truncated distribution (scipy.stats.truncnorm), so it always lies within
    the truncation bounds.
    """
    def __init__(self, action_space: ActionSpace, noise_schedule: Schedule,
                 evaluation_noise: float, clip_low: float, clip_high: float,
                 noise_as_percentage_from_action_space: bool = True):
        """
        :param action_space: the action space used by the environment
        :param noise_schedule: the schedule for the noise standard deviation used outside of evaluation phases
        :param evaluation_noise: the noise standard deviation that will be used during evaluation phases
        :param clip_low: the lower truncation bound of the distribution the action is sampled from
        :param clip_high: the upper truncation bound of the distribution the action is sampled from
        :param noise_as_percentage_from_action_space: whether to consider the noise as a fraction of the action
                                                      space range (high - low) or as an absolute value
        :raises ValueError: if the action space is not a BoxActionSpace, or if any of its bounds is not finite
        """
        super().__init__(action_space)
        self.noise_schedule = noise_schedule
        self.evaluation_noise = evaluation_noise
        self.noise_as_percentage_from_action_space = noise_as_percentage_from_action_space
        self.clip_low = clip_low
        self.clip_high = clip_high

        if not isinstance(action_space, BoxActionSpace):
            raise ValueError("Truncated normal exploration works only for continuous controls. "
                             "The given action space is of type: {}".format(action_space.__class__.__name__))

        # the noise may be scaled by (high - low), so both bounds must be finite (this also rejects NaN)
        if not (np.all(np.isfinite(action_space.high)) and np.all(np.isfinite(action_space.low))):
            raise ValueError("Truncated normal exploration requires bounded actions")

    def get_action(self, action_values: List[ActionType]) -> ActionType:
        """
        Sample an exploratory action around the action mean given by the agent.

        :param action_values: either a numpy array holding the action mean, or a list whose 1st element is the
                              action mean and whose optional 2nd element is the action standard deviation
        :return: a numpy array sampled from a normal distribution truncated to [clip_low, clip_high]
        """
        is_test_phase = self.phase == RunPhase.TEST

        # set the current noise
        current_noise = self.evaluation_noise if is_test_phase else self.noise_schedule.current_value

        # scale the noise to the action space range
        if self.noise_as_percentage_from_action_space:
            action_values_std = current_noise * (self.action_space.high - self.action_space.low)
        else:
            action_values_std = current_noise

        # extract the mean values
        if isinstance(action_values, list):
            # the action values are expected to be a list with the action mean and optionally the action stdev
            action_values_mean = action_values[0].squeeze()
        else:
            # the action values are expected to be a numpy array representing the action mean
            action_values_mean = action_values.squeeze()

        # step the noise schedule (the schedule is frozen during evaluation)
        if not is_test_phase:
            self.noise_schedule.step()

        # the second element of the list is assumed to be the standard deviation, and overrides the scheduled noise
        if isinstance(action_values, list) and len(action_values) > 1:
            action_values_std = action_values[1].squeeze()

        # truncnorm expects its bounds in units of standard deviations relative to the mean
        normalized_low = (self.clip_low - action_values_mean) / action_values_std
        normalized_high = (self.clip_high - action_values_mean) / action_values_std
        distribution = truncnorm(normalized_low, normalized_high, loc=action_values_mean, scale=action_values_std)

        # draw one value per action dimension. the requested size must match the broadcast shape of the
        # distribution parameters, otherwise scipy raises for actions with more than one dimension.
        # a scalar mean still yields an array of shape (1,)
        sample_shape = np.broadcast(normalized_low, normalized_high).shape or (1,)
        action = distribution.rvs(size=sample_shape)

        return action

    def get_control_param(self):
        """
        :return: an array with the shape of the action space, filled with the current value of the noise schedule
        """
        return np.ones(self.action_space.shape)*self.noise_schedule.current_value