From df47b2e2448e396cdacde3af81e6553e12142bfe Mon Sep 17 00:00:00 2001 From: Eric Vin Date: Wed, 30 Oct 2024 09:14:35 -0700 Subject: [PATCH] Client rewrite and METS-R model expansion. --- examples/metsr/test.scenic | 21 +- src/scenic/simulators/metsr/client.py | 531 ++++--------------- src/scenic/simulators/metsr/model.scenic | 37 +- src/scenic/simulators/metsr/simulator.py | 65 +-- src/scenic/simulators/metsr/traffic_flows.py | 41 +- 5 files changed, 221 insertions(+), 474 deletions(-) diff --git a/examples/metsr/test.scenic b/examples/metsr/test.scenic index 6f2d59f9a..ea6658fc1 100644 --- a/examples/metsr/test.scenic +++ b/examples/metsr/test.scenic @@ -1,9 +1,24 @@ +param startTime = 0 +param map = "CARLA_TOWN5" model scenic.simulators.metsr.model -zone_2_center = (-0.0024190, -0.0000165, 0) -zone_9_center = (0.0013876, -0.0000135, 0) +scenario CustomCommuterTrafficStream(origin, destination): + setup: + num_commuters = Range(100, 200) + morning_peak_time = 1*60*60 # Normal(9*60*60, 30*60) + evening_peak_time = 2*60*60 # Normal(17*60*60, 30*60) + traffic_stddev = 15*60 # Normal(1*60*60, 10*60) + compose: + do CommuterTrafficStream(origin, destination, num_commuters, + morning_peak_time, evening_peak_time, traffic_stddev) scenario Main(): compose: - do ConstantTrafficStream(2,9,60) + ts_2_21 = CustomCommuterTrafficStream(2, 21) + ts_3_21 = CustomCommuterTrafficStream(3, 21) + ts_4_21 = CustomCommuterTrafficStream(4, 21) + ts_7_21 = CustomCommuterTrafficStream(7, 21) + ts_11_21 = CustomCommuterTrafficStream(11, 21) + + do ts_2_21, ts_3_21, ts_4_21, ts_7_21, ts_11_21 for 3*60*60 seconds # 16*60*60 seconds diff --git a/src/scenic/simulators/metsr/client.py b/src/scenic/simulators/metsr/client.py index d5ec39b38..4ff508f65 100644 --- a/src/scenic/simulators/metsr/client.py +++ b/src/scenic/simulators/metsr/client.py @@ -1,444 +1,143 @@ -from contextlib import closing +import datetime import json -import os -from os import path -import platform -import shutil -import socket -import subprocess -import sys -import threading -from threading import Lock import time -from types import SimpleNamespace -import zipfile -import ujson as json -import websocket +from websockets.sync.client import connect -def check_socket(host, port): - flag = True - with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock: - if sock.connect_ex((host, port)) == 0: - flag = True - else: - flag = False - time.sleep(1) - return flag - - -# Factory for processsing a str list with a given func -def str_list_mapper_gen(func): - def str_list_mapper(str_list): - return [func(str) for str in str_list] - - return str_list_mapper - - -str_list_to_int_list = str_list_mapper_gen(int) -str_list_to_float_list = str_list_mapper_gen(float) - - -""" -Implementation of the remote data client - -A client directly communicates with a specific METSR-SIM server. -""" - - -class METSRClient(threading.Thread): - def __init__( - self, host, port, index, manager=None, retry_threshold=20, verbose=False - ): - super().__init__() - - # Websocket config +class METSRClient: + def __init__(self, host, port, max_connection_attempts=5, timeout=30, verbose=False): self.host = host self.port = port self.uri = f"ws://{host}:{port}" - self.index = index - self.state = "connecting" - self.retry_threshold = ( - retry_threshold # time out for resending the same message if no response - ) - self.verbose = verbose - # self.docker_id = docker_id - - # a pointer to the manager - self.manager = manager - - # Track the tick of the corresponding simulator - self.current_tick = -1 - self.prev_tick = -1 - self.prev_time = time.time() - # latest message from the server - self.latest_ans_message = None - self.latest_ctrl_message = None - - # Create a listener socket - self.ws = websocket.WebSocketApp( - self.uri, - on_open=self.on_open, - on_message=self.on_message, - on_error=self.on_error, - on_close=self.on_close, - ) - - # A flag to indicate whether the simulation is ready - self.ready = False - - # Data maps can be accessed by both main thread (for ML algorithms) - # and RDClient class. Therefore synchronization is needed to avoid data races. - # If the main thread only does reads, the lock can be removed safely - self.lock = Lock() - - # on_message is automatically called when the sever sends a msg - def on_message(self, ws, message): - # for debugging - if self.verbose: - print(f"{self.uri} : {message[0:200]}") - - # Decode the json string - decoded_msg = json.loads(str(message)) - - # Every decoded msg must have a MSG_TYPE field - assert "TYPE" in decoded_msg.keys(), "No TYPE field in received json string!" - - # handle decoded msg based on MSG_TYPE - if decoded_msg["TYPE"] == "STEP": - self.handle_step_message(decoded_msg) - elif decoded_msg["TYPE"].split("_")[0] == "ANS": - self.handle_answer_message(ws, decoded_msg) - self.latest_ans_message = decoded_msg - elif decoded_msg["TYPE"].split("_")[0] == "ATK": - self.handle_attack_message(ws, decoded_msg) - elif decoded_msg["TYPE"].split("_")[0] == "CTRL": - self.latest_ctrl_message = decoded_msg - - def on_error(self, ws, error): - self.state = "error" - print(error) - - def on_close(self, ws, status_code, close_msg): - self.state = "closed" - print(f"{self.uri} : connection closed") - - def on_open(self, ws): - self.state = "connected" - print(f"{self.uri} : connection opened") - - # run() method implements what RemoteDataClient will be doing during its lifetime - def run(self): - print(f"Waiting until the server is up at {self.uri}") - - # all clients are disconnected - wait_time = 0 + self.current_tick = None + self.timeout = timeout + self.verbose = verbose + self._messagesLog = [] - # wait until the server is up, or timeout after 30 seconds - while not check_socket(self.host, self.port): - # count the real-world seconds - wait_time += 1 - if wait_time > 20: - print( - f"Waiting for the server to be up at {self.uri}.. time out in {30-wait_time} seconds.." - ) - if wait_time > 30: + # Establish connection + failed_attempts = 0 + while True: + try: + self.ws = connect(self.uri) + break + except ConnectionRefusedError: print( - "Waiting overtime, please check the connection and restart the simulation." + f"Attempt to connect to {self.uri} failed. " + f"Waiting for 10 seconds before trying again... " + f"({max_connection_attempts - failed_attempts} attempts remaining)" ) - # close the connection - self.ws.close() - os.chdir("docker") - os.system("docker-compose down") - break - - if wait_time <= 30: - print(f"Sever is active at {self.uri}, running client..") - self.ws.run_forever() - - # Method for handle messages - def handle_step_message(self, decoded_msg): - with self.lock: # for thread safety - tick = decoded_msg["TICK"] - if not self.ready: - self.send_step_message( - 0 - ) # This will initialize the simulator data structures if it is first ran, otherwise, it will be ignored - self.ready = True - if ( - tick > self.current_tick - ): # tick smaller or equal to the current_tick is ignored - self.current_tick = tick - - def handle_answer_message(self, ws, decoded_msg): - # if decoded_msg['TYPE'] == "ANS_ready": - # print("SIM is ready!!") - # self.ready = True - if decoded_msg["TYPE"] == "ANS_TaxiUCB": - size = int(decoded_msg["SIZE"]) - candidate_paths = {} - od = decoded_msg["OD"] - candidate_paths[od] = decoded_msg["road_lists"] - self.manager.mab_manager.initialize(candidate_paths, size, type="taxi") - elif decoded_msg["TYPE"] == "ANS_BusUCB": - size = int(decoded_msg["SIZE"]) - candidate_paths = {} - bod = decoded_msg["BOD"] - candidate_paths[bod] = decoded_msg["road_lists"] - self.manager.mab_manager.initialize(candidate_paths, size, type="bus") - - def handle_attack_message(self, ws, decoded_msg): - # placeholder for handling attacker's control - pass - - def send_step_message(self, tick): # helper function for sending step message - self.prev_tick = tick - self.prev_time = time.time() - msg = {"TYPE": "STEP", "TICK": tick} - self.ws.send(json.dumps(msg)) - - def tick( - self, - ): # synchronized, wait until the simulator finish the corresponding step - while self.current_tick <= self.prev_tick: - time.sleep(0.001) - - self.send_step_message(self.current_tick) - - while self.current_tick + 1 <= self.prev_tick: - time.sleep(0.001) - if time.time() - self.prev_time > self.retry_threshold: - return False, f"Tick time out, the current tick is {self.current_tick}" - - def send_query_message( - self, msg - ): # asynchronized, other tasks can be done while waiting for the answer - # time.sleep(0.005) # wait for some time to avoid blocking the message pending - while not self.ready: - time.sleep(1) - - self.prev_time = time.time() - self.ws.send(json.dumps(msg)) + failed_attempts += 1 + if failed_attempts >= max_connection_attempts: + raise RuntimeError("Could not connect to METS-R Sim") + time.sleep(10) - ans_type = msg["TYPE"].replace("QUERY", "ANS") - while ( - self.latest_ans_message is None or self.latest_ans_message["TYPE"] != ans_type - ): - time.sleep(0.001) - if time.time() - self.prev_time > self.retry_threshold: - return False, f"Query time out, the message is {msg}" - res = self.latest_ans_message.copy() - self.latest_ans_message = None - return True, res + # Ensure server is initialized by waiting to receive an initial packet + # (could be ANS_ready or a heartbeat) + self.receive_msg(ignore_heartbeats=False) - def send_control_message(self, msg): # synchronized, wait until receive the answer - # time.sleep(0.005) # wait for some time to avoid blocking the message pending - while not self.ready: - time.sleep(1) + def send_msg(self, msg): + if self.verbose: + self._logMessage("SENT", msg) self.ws.send(json.dumps(msg)) - sent_time = time.time() - # wait until receive the answer or time out - while ( - self.latest_ctrl_message is None - or self.latest_ctrl_message["TYPE"] != msg["TYPE"] - ): - time.sleep(0.001) - if time.time() - sent_time > self.retry_threshold: - return False, f"Control time out, the message is {msg}" - res = self.latest_ctrl_message.copy() - self.latest_ctrl_message = None - if res["CODE"] == "OK": - if msg["TYPE"] == "CTRL_reset": - self.current_tick = -1 - self.prev_tick = -1 - return True, res - else: - return False, f"Control failed, the reply is {res}" - - # QUERY: inspect the state of the simulator - # By default query public vehicles - def query_vehicle(self, id=None, private_veh=False, transform_coords=False): - my_msg = {} - my_msg["TYPE"] = "QUERY_vehicle" - if id is not None: - my_msg["ID"] = id - my_msg["PRV"] = private_veh - my_msg["TRAN"] = transform_coords - return self.send_query_message(my_msg) - - # query taxi - def query_taxi(self, id=None): - my_msg = {} - my_msg["TYPE"] = "QUERY_taxi" - if id is not None: - my_msg["ID"] = id - return self.send_query_message(my_msg) - - # query bus - def query_bus(self, id=None): - my_msg = {} - my_msg["TYPE"] = "QUERY_bus" - if id is not None: - my_msg["ID"] = id - return self.send_query_message(my_msg) - - # query road - def query_road(self, id=None): - my_msg = {} - my_msg["TYPE"] = "QUERY_road" - if id is not None: - my_msg["ID"] = id - return self.send_query_message(my_msg) - - # query zone - def query_zone(self, id=None): - my_msg = {} - my_msg["TYPE"] = "QUERY_zone" - if id is not None: - my_msg["ID"] = id - return self.send_query_message(my_msg) - - # query signal - def query_signal(self, id=None): - my_msg = {} - my_msg["TYPE"] = "QUERY_signal" - if id is not None: - my_msg["ID"] = id - return self.send_query_message(my_msg) - # query chargingStation - def query_chargingStation(self, id=None): - my_msg = {} - my_msg["TYPE"] = "QUERY_chargingStation" - if id is not None: - my_msg["ID"] = id - return self.send_query_message(my_msg) - - # query vehicleID within the co-sim road - def query_coSimVehicle(self): - my_msg = {} - my_msg["TYPE"] = "QUERY_coSimVehicle" - return self.send_query_message(my_msg) - - # CONTROL: change the state of the simulator - # set the road for co-simulation - def set_cosim_road(self, roadID): - my_msg = {} - my_msg["TYPE"] = "CTRL_setCoSimRoad" - my_msg["roadID"] = roadID - return self.send_control_message(my_msg) - - # release the road for co-simulation - def release_cosim_road(self, roadID): - my_msg = {} - my_msg["TYPE"] = "CTRL_releaseCoSimRoad" - my_msg["roadID"] = roadID - return self.send_control_message(my_msg) - - # teleport vehicle to a target location specified by road, lane, and distance to the downstream junction - def teleport_vehicle( - self, vehID, roadID, laneID, dist, x, y, private_veh=False, transform_coords=False - ): - my_msg = {} - my_msg["TYPE"] = "CTRL_teleportVeh" - my_msg["vehID"] = vehID - my_msg["roadID"] = roadID - my_msg["laneID"] = laneID - my_msg["dist"] = dist - my_msg["prv"] = private_veh - my_msg["x"] = x - my_msg["y"] = y - my_msg["TRAN"] = transform_coords - return self.send_control_message(my_msg) - - # enter the next road - def enter_next_road(self, vehID, private_veh=False): - my_msg = {} - my_msg["TYPE"] = "CTRL_enterNextRoad" - my_msg["vehID"] = vehID - my_msg["prv"] = private_veh - return self.send_control_message(my_msg) + def receive_msg(self, ignore_heartbeats): + while True: + raw_msg = self.ws.recv(timeout=self.timeout) + + # Decode the json string + msg = json.loads(str(raw_msg)) + + if self.verbose: + self._logMessage("RCVD", msg) + + # Every decoded msg must have a MSG_TYPE field + assert "TYPE" in msg.keys(), "No type field in received message" + assert msg["TYPE"].split("_")[0] in { + "STEP", + "ANS", + "CTRL", + "ATK", + }, "Uknown message type: " + str(msg["TYPE"]) + + # Ignore certain message types entirely + if msg["TYPE"] in {"ANS_ready"}: + continue + + # Return decoded message, if it's not an ignored heartbeat + if not ignore_heartbeats or msg["TYPE"] != "STEP": + return msg + + def send_receive_msg(self, msg, ignore_heartbeats): + self.send_msg(msg) + return self.receive_msg(ignore_heartbeats=ignore_heartbeats) + + def tick(self): + assert ( + self.current_tick is not None + ), "self.current_tick is None. Reset should be called first" + msg = {"TYPE": "STEP", "TICK": self.current_tick} + self.send_msg(msg) + + while True: + # Move through messages until we get to an up to date heartbeat + res = self.receive_msg(ignore_heartbeats=False) + + assert res["TYPE"] == "STEP", res["TYPE"] + if res["TICK"] == self.current_tick + 1: + break - # generate a vehicle trip - def generate_trip(self, vehID, origin=None, destination=None): - my_msg = {} - my_msg["TYPE"] = "CTRL_generateTrip" - my_msg[ - "vehID" - ] = vehID # if not exists, the sim will generate a new vehicle with this vehID - if origin is not None: - my_msg["origin"] = origin - else: - my_msg["origin"] = -1 - if destination is not None: - my_msg["destination"] = destination - else: - my_msg["destination"] = -1 - return self.send_control_message(my_msg) + self.current_tick = res["TICK"] - # control vehicle with specified acceleration - def control_vehicle(self, vehID, acc, private_veh=False): - my_msg = {} - my_msg["TYPE"] = "CTRL_controlVeh" - my_msg["vehID"] = vehID - my_msg["acc"] = acc - my_msg["prv"] = private_veh - return self.send_control_message(my_msg) + def generate_trip(self, vehID, origin, destination): + msg = { + "TYPE": "CTRL_generateTrip", + "vehID": vehID, + "origin": origin, + "destination": destination, + } - # reset the simulation with a property file - def reset(self, prop_file): - # print current working directory - # print(f"Current working directory: {os.getcwd()}") - # print(f"Docker ID: {self.docker_id}") - # copy prop_file (a file) to the sim folder - # docker_cp_command = f"docker cp data/{prop_file} {self.docker_id}:/home/test/data/" - # subprocess.run(docker_cp_command, shell=True, check=True) + res = self.send_receive_msg(msg, ignore_heartbeats=True) - my_msg = {} - my_msg["TYPE"] = "CTRL_reset" - my_msg["propertyFile"] = prop_file - self.latest_ans_message = None - self.latest_ctrl_message = None + assert res["TYPE"] == "CTRL_generateTrip", res["TYPE"] + assert res["CODE"] == "OK", res["CODE"] - return self.send_control_message(my_msg) + def query_vehicle(self, vehID, private_veh=False, transform_coords=False): + msg = { + "TYPE": "QUERY_vehicle", + "ID": vehID, + "PRV": private_veh, + "TRAN": transform_coords, + } - # reset the simulation with a map name - def reset_map(self, map_name): - # find the property file for the map - if map_name == "CARLA": - # copy CARLA data in the sim folder - # source_path = "data/CARLA" - # specify the property file - prop_file = "Data.properties.CARLA" - elif map_name == "NYC": - # copy NYC data in the sim folder - # source_path = "data/NYC" - # specify the property file - prop_file = "Data.properties.NYC" + res = self.send_receive_msg(msg, ignore_heartbeats=True) - # docker_cp_command = f"docker cp {source_path} {self.docker_id}:/home/test/data/" - # subprocess.run(docker_cp_command, shell=True, check=True) + assert res["TYPE"] == "ANS_vehicle", res["TYPE"] + return res - # reset the simulation with the property file - self.reset(prop_file) + def reset(self, props_file): + msg = {"TYPE": "CTRL_reset", "propertyFile": props_file} + res = self.send_receive_msg(msg, ignore_heartbeats=True) - # terminate the simulation - def terminate(self): - my_msg = {} - my_msg["TYPE"] = "CTRL_end" - return self.send_control_message(my_msg) + assert res["TYPE"] == "CTRL_reset", res["TYPE"] + assert res["CODE"] == "OK", res["CODE"] - # override __str__ for logging - def __str__(self): - s = ( - f"-----------\n" - f"Client INFO\n" - f"-----------\n" - f"index :\t {self.index}\n" - f"address :\t {self.uri}\n" - f"state :\t {self.state}\n" + self.current_tick = -1 + self.tick() + assert self.current_tick == 0 + + def close(self): + if self.ws is not None: + self.ws.close() + self.ws = None + + def _logMessage(self, direction, msg): + self._messagesLog.append( + ( + datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"), + direction, + tuple(msg.items()), + ) ) - return s diff --git a/src/scenic/simulators/metsr/model.scenic b/src/scenic/simulators/metsr/model.scenic index a8f5bacc6..af09ad77a 100644 --- a/src/scenic/simulators/metsr/model.scenic +++ b/src/scenic/simulators/metsr/model.scenic @@ -1,22 +1,25 @@ +import warnings + from scenic.simulators.metsr.simulator import METSRSimulator from scenic.simulators.metsr.traffic_flows import * # Default start time is 6:00 AM param startTime = 6*60*60 -param timestep = 0.1 +param timestep = 1 +param simTimestep = 0.1 simulator METSRSimulator( host="localhost", port=4000, map_name="Data.properties.CARLA", - timestep=globalParameters.timestep + timestep=globalParameters.timestep, + sim_timestep=globalParameters.simTimestep, ) -# Internal time in seconds -_internalTime = globalParameters.startTime - _DAY_MOD = 24*60*60 +def currentTOD(): + return (simulation().currentTime * simulation().timestep + globalParameters.startTime)%_DAY_MOD class PrivateCar: pass @@ -28,11 +31,29 @@ scenario GeneratePrivateTrip(origin, destination): scenario TrafficStream(origin, destination, traffic_flow): compose: while True: - if Range(0,1) < traffic_flow.probSpawn(_internalTime%_DAY_MOD, globalParameters.timestep): + raw_prob_spawn = traffic_flow.expected_vehs( + currentTOD(), currentTOD()+simulation().timestep) + if raw_prob_spawn < 0 or raw_prob_spawn > 1: + warnings.warn(f"raw_prob_spawn (={raw_prob_spawn}) fell outside [0,1] and will be clamped.") + prob_spawn = min(1, max(raw_prob_spawn, 0)) + if Range(0,1) < prob_spawn: do GeneratePrivateTrip(origin, destination) else: wait -scenario ConstantTrafficStream(origin, destination, vph): +scenario ConstantTrafficStream(origin, destination, num_vehicles, stime=None, etime=None): + compose: + tf = ConstantTrafficFlow(num_vehicles, stime, etime) + do TrafficStream(origin, destination, tf) + +scenario NormalTrafficStream(origin, destination, num_vehicles, peak_time, stddev): + compose: + tf = NormalTrafficFlow(num_vehicles, peak_time, stddev) + do TrafficStream(origin, destination, tf) + +scenario CommuterTrafficStream(origin, destination, num_vehicles, + peak_time_1, peak_time_2, stddev): compose: - do TrafficStream(origin, destination, ConstantTrafficFlow(vph)) + tf1 = NormalTrafficFlow(num_vehicles, peak_time_1, stddev) + tf2 = NormalTrafficFlow(num_vehicles, peak_time_2, stddev) + do TrafficStream(origin, destination, tf1), TrafficStream(destination, origin, tf2) diff --git a/src/scenic/simulators/metsr/simulator.py b/src/scenic/simulators/metsr/simulator.py index e79b79316..4ba38da19 100644 --- a/src/scenic/simulators/metsr/simulator.py +++ b/src/scenic/simulators/metsr/simulator.py @@ -1,50 +1,54 @@ """Simulator interface for METS-R Sim.""" -import datetime import math -import time from scenic.core.simulators import Simulation, Simulator from scenic.core.vectors import Orientation, Vector from scenic.simulators.metsr.client import METSRClient -_LOG_CLIENT_CALLS = True - class METSRSimulator(Simulator): - def __init__(self, host, port, map_name, timestep=0.1): + def __init__(self, host, port, map_name, timestep, sim_timestep): super().__init__() - self.client = METSRClient(host=host, port=port, index=42, verbose=True) - self.client.start() + self.client = METSRClient(host=host, port=port) self.map_name = map_name self.timestep = timestep + self.sim_timestep = sim_timestep def createSimulation(self, scene, timestep, **kwargs): assert timestep is None or timestep == self.timestep - return METSRSimulation(scene, self.client, self.map_name, self.timestep, **kwargs) + return METSRSimulation( + scene, self.client, self.map_name, self.timestep, self.sim_timestep, **kwargs + ) def destroy(self): - self.client.ws.close() + self.client.close() super().destroy() class METSRSimulation(Simulation): - def __init__(self, scene, client, map_name, timestep, **kwargs): + def __init__(self, scene, client, map_name, timestep, sim_timestep, **kwargs): self.client = client self.map_name = map_name + self.timestep = timestep + self.sim_timestep = sim_timestep + self.sim_ticks_per = int(timestep / sim_timestep) + assert self.sim_ticks_per == timestep / sim_timestep self.next_pv_id = 0 self.pv_id_map = {} self._client_calls = [] + self.count = 0 + super().__init__(scene, timestep=timestep, **kwargs) def setup(self): # Reset map - self.client.reset_map("CARLA") + self.client.reset("Data.properties.CARLA") super().setup() # Calls createObjectInSimulator for each object @@ -52,40 +56,29 @@ def createObjectInSimulator(self, obj): assert obj.origin assert obj.destination - import time - - start_time = time.time() - call_kwargs = { "vehID": self.getPrivateVehId(obj), "origin": obj.origin, "destination": obj.destination, } - if _LOG_CLIENT_CALLS: - self._logClientCall("GENERATE_TRIP", tuple(call_kwargs.items())) - - success = self.client.generate_trip(**call_kwargs) - assert success + self.client.generate_trip(**call_kwargs) def step(self): - if _LOG_CLIENT_CALLS: - self._logClientCall("TICK", tuple()) - - self.client.tick() + self.count += 1 + if self.count % 100 == 0: + print(".", end="", flush=True) + for _ in range(self.sim_ticks_per): + self.client.tick() def getProperties(self, obj, properties): call_kwargs = { - "id": self.getPrivateVehId(obj), + "vehID": self.getPrivateVehId(obj), "private_veh": True, "transform_coords": True, } - if _LOG_CLIENT_CALLS: - self._logClientCall("QUERY_VEHICLE", tuple(call_kwargs.items())) - - success, raw_data = self.client.query_vehicle(**call_kwargs) - assert success + raw_data = self.client.query_vehicle(**call_kwargs) position = Vector(raw_data["x"], raw_data["y"], 0) speed = raw_data["speed"] @@ -109,13 +102,12 @@ def getProperties(self, obj, properties): return values def destroy(self): - if _LOG_CLIENT_CALLS: - print("Client Calls:") + if self.client.verbose: + print("Client Messages Log:") print("[") - for call in self._client_calls: + for call in self.client._messagesLog: print(f" {call},") print("]") - pass def getPrivateVehId(self, obj): if obj not in self.pv_id_map: @@ -123,8 +115,3 @@ def getPrivateVehId(self, obj): self.next_pv_id += 1 return self.pv_id_map[obj] - - def _logClientCall(self, type, args): - self._client_calls.append( - (datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"), type, args) - ) diff --git a/src/scenic/simulators/metsr/traffic_flows.py b/src/scenic/simulators/metsr/traffic_flows.py index 0a6b39bb2..38b5f0359 100644 --- a/src/scenic/simulators/metsr/traffic_flows.py +++ b/src/scenic/simulators/metsr/traffic_flows.py @@ -1,18 +1,43 @@ from abc import abstractmethod +import math + +from scenic.core.utils import sqrt2 class TrafficFlow: @abstractmethod - def vps(self, time): + def expected_vehs(self, stime, etime): pass - def probSpawn(self, time, timestep): - return self.vps(time) * timestep - class ConstantTrafficFlow(TrafficFlow): - def __init__(self, vph): - self.vph = vph + def __init__(self, num_vehs, stime=None, etime=None): + self.num_vehs = num_vehs + self.stime = stime if stime is not None else 0 + self.etime = etime if stime is not None else 24 * 60 * 60 + if etime <= stime: + raise ValueError("etime must be greater than stime.") + + self.vps = self.num_vehs / (etime - stime) + + def expected_vehs(self, stime, etime): + if etime <= stime: + raise ValueError("etime must be greater than stime.") + + clamped_stime = min(self.etime, max(stime, self.stime)) + clamped_etime = min(self.etime, max(etime, self.stime)) + + return (clamped_etime - clamped_stime) * self.vps + + +class NormalTrafficFlow(TrafficFlow): + def __init__(self, num_vehs, mean, stddev): + self.num_vehs = num_vehs + self.mean = mean + self.stddev = stddev + + def expected_vehs(self, stime, etime): + return self.num_vehs * (self.cdf(etime) - self.cdf(stime)) - def vps(self, time): - return self.vph / (60 * 60) + def cdf(self, x): + return (1 + math.erf((x - self.mean) / (sqrt2 * self.stddev))) / 2