Skip to content

Commit

Permalink
Improved documentation. Removed unused code
Browse files Browse the repository at this point in the history
  • Loading branch information
spencerteetaert committed Oct 9, 2023
1 parent 40ce90d commit d186f63
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 96 deletions.
93 changes: 33 additions & 60 deletions gym_pybullet_drones/envs/CFAviary.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,22 +31,20 @@ class CFAviary(BaseAviary):
################################################################################

def __init__(self,
firmware_freq: int = 500,
ctrl_freq: int = 25,
verbose=False,
drone_model: DroneModel=DroneModel.CF2X,
num_drones: int=1,
neighbourhood_radius: float=np.inf,
initial_xyzs=None,
initial_rpys=None,
physics: Physics=Physics.PYB,
pyb_freq: int = 500,
ctrl_freq: int = 25,
gui=False,
record=False,
obstacles=False,
user_debug_gui=True,
output_folder='results',
udp_ip="127.0.0.1"
verbose=False,
):
"""Initialization of an aviary environment for use of BetaFlight controller.
Expand Down Expand Up @@ -79,15 +77,19 @@ def __init__(self,
udp_ip : base ip for betaflight controller emulator
"""
assert (pyb_freq % firmware_freq == 0), f"pyb_freq must be a multiple of firmware_freq for CFAviary {pyb_freq} {firmware_freq}"
firmware_freq = 500 if self.CONTROLLER == "mellinger" else 1000
assert (pyb_freq % firmware_freq == 0), f"pyb_freq ({pyb_freq}) must be a multiple of firmware_freq ({firmware_freq}) for CFAviary."
if num_drones != 1:
raise NotImplementedError("Multi-agent support for CF Aviary is not yet implemented.")

super().__init__(drone_model=drone_model,
num_drones=num_drones,
neighbourhood_radius=neighbourhood_radius,
initial_xyzs=initial_xyzs,
initial_rpys=initial_rpys,
physics=physics,
pyb_freq=pyb_freq,
ctrl_freq=pyb_freq,
ctrl_freq=firmware_freq, # ctrl_freq in this variable (self.CTRL_FREQ) corresponds to ctrl rate from aviary, in this case, firmware_freq
gui=gui,
record=record,
obstacles=obstacles,
Expand All @@ -107,7 +109,6 @@ def __init__(self,

self._initalize_cffirmware()


def _initalize_cffirmware(self):
"""Resets the firmware_wrapper object.
Expand Down Expand Up @@ -163,7 +164,6 @@ def _initalize_cffirmware(self):

# Reset environment
init_obs, init_info = super().reset()
print(init_obs, init_obs.shape)
init_pos=np.array([init_obs[0][0], init_obs[0][1], init_obs[0][2]]) # global coord, m
init_vel=np.array([init_obs[0][10], init_obs[0][11], init_obs[0][12]]) # global coord, m/s
init_rpy = np.array([init_obs[0][7], init_obs[0][8], init_obs[0][9]]) # body coord, rad
Expand All @@ -181,21 +181,9 @@ def _initalize_cffirmware(self):

# Initialize visualization tools
self.first_motor_killed_print = True
print(init_info)
# self.pyb_clinet = init_info['pyb_client']
self.last_visualized_setpoint = None

self.results_dict = { 'obs': [],
'reward': [],
'done': [],
'info': [],
'action': [],
}

return init_obs, init_info

################################################################################

def step(self, i):
'''Step the firmware_wrapper class and its environment.
This function should be called once at the rate of ctrl_freq. Step processes and high level commands,
Expand Down Expand Up @@ -232,7 +220,7 @@ def step(self, i):
# Update state
state_timestamp = int(self.tick / self.firmware_freq * 1e3)
if self.STATE_DELAY:
raise NotImplementedError("State delay is not implemented. Leave at 0.")
raise NotImplementedError("State delay is not yet implemented. Leave at 0.")
self._update_state(state_timestamp, *self.state_history[0])
self.state_history = self.state_history[1:] + [[cur_pos, cur_vel, cur_acc, cur_rpy * self.RAD_TO_DEG]]
else:
Expand Down Expand Up @@ -267,7 +255,6 @@ def step(self, i):
if self.first_motor_killed_print:
print("Drone firmware error. Motors are killed.")
self.first_motor_killed_print = False
done = True

self.action = action

Expand All @@ -278,19 +265,10 @@ def _update_initial_state(self, obs):
self.prev_rpy = np.array([obs[7], obs[8], obs[9]])


def close_results_dict(self):
"""Cleanup the rtesults dict and munchify it.
##################################
########## Sensor Data ###########
##################################

"""
self.results_dict['obs'] = np.vstack(self.results_dict['obs'])
self.results_dict['reward'] = np.vstack(self.results_dict['reward'])
self.results_dict['done'] = np.vstack(self.results_dict['done'])
self.results_dict['info'] = np.vstack(self.results_dict['info'])
self.results_dict['action'] = np.vstack(self.results_dict['action'])

self.results_dict = munchify(self.results_dict)

#region Sensor update
def _update_sensorData(self, timestamp, acc_vals, gyro_vals, baro_vals=[1013.25, 25]):
'''
Axis3f acc; // Gs
Expand All @@ -312,19 +290,16 @@ def _update_sensorData(self, timestamp, acc_vals, gyro_vals, baro_vals=[1013.25,
self.sensorData.interruptTimestamp = timestamp
self.sensorData_set = True


def _update_gyro(self, x, y, z):
self.sensorData.gyro.x = firm.lpf2pApply(self.gyrolpf[0], x)
self.sensorData.gyro.y = firm.lpf2pApply(self.gyrolpf[1], y)
self.sensorData.gyro.z = firm.lpf2pApply(self.gyrolpf[2], z)


def _update_acc(self, x, y, z):
self.sensorData.acc.x = firm.lpf2pApply(self.acclpf[0], x)
self.sensorData.acc.y = firm.lpf2pApply(self.acclpf[1], y)
self.sensorData.acc.z = firm.lpf2pApply(self.acclpf[2], z)


def _update_baro(self, baro, pressure, temperature):
'''
pressure: hPa
Expand All @@ -334,9 +309,12 @@ def _update_baro(self, baro, pressure, temperature):
baro.pressure = pressure #* 0.01 Best guess is this is because the sensor encodes raw reading two decimal places and stores as int
baro.temperature = temperature
baro.asl = (((1015.7 / baro.pressure)**0.1902630958 - 1) * (25 + 273.15)) / 0.0065
#endregion


##################################
######### State Update ###########
##################################

#region State update
def _update_state(self, timestamp, pos, vel, acc, rpy, quat=None):
'''
attitude_t attitude; // deg (legacy CF2 body coordinate system, where pitch is inverted)
Expand All @@ -354,14 +332,12 @@ def _update_state(self, timestamp, pos, vel, acc, rpy, quat=None):
self._update_3D_vec(self.state.acc, timestamp, *acc)
self.state_set = True


def _update_3D_vec(self, point, timestamp, x, y, z):
point.x = x
point.y = y
point.z = z
point.timestamp = timestamp


def _update_attitudeQuaternion(self, quaternion_t, timestamp, qx, qy, qz, qw=None):
'''Updates attitude quaternion.
Expand All @@ -378,16 +354,17 @@ def _update_attitudeQuaternion(self, quaternion_t, timestamp, qx, qy, qz, qw=Non
quaternion_t.z = qz
quaternion_t.w = qw


def _update_attitude_t(self, attitude_t, timestamp, roll, pitch, yaw):
attitude_t.timestamp = timestamp
attitude_t.roll = roll
attitude_t.pitch = -pitch # Legacy representation in CF firmware
attitude_t.yaw = yaw
#endregion

################################################################################
#region Controller

##################################
########### Controller ###########
##################################

def _step_controller(self):
if not (self.sensorData_set):
print("WARNING: sensorData has not been updated since last controller call.")
Expand Down Expand Up @@ -471,7 +448,6 @@ def sendFullStateCmd(self, pos, vel, acc, yaw, rpy_rate, timestep):
"""
self.command_queue += [['_sendFullStateCmd', [pos, vel, acc, yaw, rpy_rate, timestep]]]


def _sendFullStateCmd(self, pos, vel, acc, yaw, rpy_rate, timestep):
# print(f"INFO_{self.tick}: Full state command sent.")
self.setpoint.position.x = pos[0]
Expand Down Expand Up @@ -525,7 +501,6 @@ def _sendTakeoffCmd(self, height, duration):
firm.crtpCommanderHighLevelTakeoff(height, duration)
self.full_state_cmd_override = False


def sendTakeoffYawCmd(self, height, duration, yaw):
"""Adds a takeoffyaw command to command processing queue.
Expand All @@ -540,7 +515,6 @@ def _sendTakeoffYawCmd(self, height, duration, yaw):
firm.crtpCommanderHighLevelTakeoffYaw(height, duration, yaw)
self.full_state_cmd_override = False


def sendTakeoffVelCmd(self, height, vel, relative):
"""Adds a takeoffvel command to command processing queue.
Expand All @@ -555,7 +529,6 @@ def _sendTakeoffVelCmd(self, height, vel, relative):
firm.crtpCommanderHighLevelTakeoffWithVelocity(height, vel, relative)
self.full_state_cmd_override = False


def sendLandCmd(self, height, duration):
"""Adds a land command to command processing queue.
Expand All @@ -569,7 +542,6 @@ def _sendLandCmd(self, height, duration):
firm.crtpCommanderHighLevelLand(height, duration)
self.full_state_cmd_override = False


def sendLandYawCmd(self, height, duration, yaw):
"""Adds a landyaw command to command processing queue.
Expand All @@ -584,7 +556,6 @@ def _sendLandYawCmd(self, height, duration, yaw):
firm.crtpCommanderHighLevelLandYaw(height, duration, yaw)
self.full_state_cmd_override = False


def sendLandVelCmd(self, height, vel, relative):
"""Adds a landvel command to command processing queue.
Expand All @@ -599,7 +570,6 @@ def _sendLandVelCmd(self, height, vel, relative):
firm.crtpCommanderHighLevelLandWithVelocity(height, vel, relative)
self.full_state_cmd_override = False


def sendStopCmd(self):
"""Adds a stop command to command processing queue.
"""
Expand All @@ -609,7 +579,6 @@ def _sendStopCmd(self):
firm.crtpCommanderHighLevelStop()
self.full_state_cmd_override = False


def sendGotoCmd(self, pos, yaw, duration_s, relative):
"""Adds a goto command to command processing queue.
Expand Down Expand Up @@ -637,8 +606,12 @@ def _notifySetpointStop(self):
self.full_state_cmd_override = False


##################################
###### Hardware Functions ########
##################################

BRUSHED = True
SUPPLY_VOLTAGE = 3 # QUESTION: Is change of battery life worth simulating?
SUPPLY_VOLTAGE = 3
def _motorsGetPWM(self, thrust):
if (self.BRUSHED):
thrust = thrust / 65536 * 60
Expand All @@ -650,15 +623,13 @@ def _motorsGetPWM(self, thrust):
else:
raise NotImplementedError("Emulator does not support the brushless motor configuration at this time.")


def _limitThrust(self, val):
if val > self.MAX_PWM:
return self.MAX_PWM
elif val < 0:
return 0
return val


def _powerDistribution(self, control_t):
motor_pwms = []
if self.QUAD_FORMATION_X:
Expand All @@ -679,7 +650,11 @@ def _powerDistribution(self, control_t):
self.pwms = motor_pwms
else:
self.pwms = np.clip(motor_pwms, self.MIN_PWM).tolist()
#endregion


##################################
##### Base Aviary Overrides ######
##################################

def _actionSpace(self):
"""Returns the action space of the environment.
Expand Down Expand Up @@ -810,7 +785,6 @@ def _computeInfo(self):
"""
return {"answer": 42} #### Calculated by the Deep Thought supercomputer in 7.5M years

#region Utils
def _get_quaternion_from_euler(roll, pitch, yaw):
"""Convert an Euler angle to a quaternion.
Expand All @@ -827,5 +801,4 @@ def _get_quaternion_from_euler(roll, pitch, yaw):
qz = np.cos(roll/2) * np.cos(pitch/2) * np.sin(yaw/2) - np.sin(roll/2) * np.sin(pitch/2) * np.cos(yaw/2)
qw = np.cos(roll/2) * np.cos(pitch/2) * np.cos(yaw/2) + np.sin(roll/2) * np.sin(pitch/2) * np.sin(yaw/2)

return [qx, qy, qz, qw]
#endregion
return [qx, qy, qz, qw]
Loading

0 comments on commit d186f63

Please sign in to comment.