diff --git a/src/data.py b/src/data.py index 24a0106..b7ce219 100644 --- a/src/data.py +++ b/src/data.py @@ -9,6 +9,11 @@ load_dotenv() class Data: + def setup_redis(self): + if os.environ.get("REDIS_URL"): + return redis.Redis.from_url(os.environ.get("REDIS_URL")) + else: + return None def setup_database(self): db_settings = { "database_host": os.environ.get("DB_HOST"), @@ -27,6 +32,7 @@ def setup_database(self): "ssh_private_key": os.environ.get("SSH_PRIVATE_KEY") } cursor = None + conn = None if ssh_host: with SSHTunnelForwarder( (ssh_settings["ssh_host"], ssh_settings["ssh_port"]), @@ -43,7 +49,7 @@ def setup_database(self): database=db_settings["database_name"], ) cursor = conn.cursor() - else: + elif db_settings["database_host"]: conn = psycopg2.connect( user=db_settings["database_username"], password=db_settings["database_password"], @@ -89,6 +95,6 @@ def add_to_logs(self,obj): self.cursor.execute( "INSERT INTO timelines (sim_id, step, substep, step_type, data) VALUES (%s, %s, %s, %s, %s)", (self.id, self.cur_step, self.current_substep, obj["step_type"], json.dumps(obj)),) - print("insert") + #print("insert") self.current_substep += 1 diff --git a/src/matrix.py b/src/matrix.py index ad41218..75d8e7d 100644 --- a/src/matrix.py +++ b/src/matrix.py @@ -71,7 +71,7 @@ def __init__(self, config={}): self.perception_range = PERCEPTION_RANGE self.allow_movement = ALLOW_MOVEMENT self.model = MODEL - self.redis_connection = None + self.redis_connection = self.setup_redis() self.replay = None self.cursor = self.setup_database() diff --git a/src/reporting.py b/src/reporting.py index 62443a0..2bfa314 100644 --- a/src/reporting.py +++ b/src/reporting.py @@ -60,7 +60,7 @@ def all_env_vars(self): "total_dead": sum(1 for agent in self.agents if agent.status == 'dead'), "total_alive": sum(1 for agent in self.agents if agent.status != 'dead'), "llm_call_counter": llm.call_counter, - "avg_llm_calls_per_step": llm.call_counter / self.steps + "avg_llm_calls_per_step": llm.call_counter / self.steps, "avg_runtime_per_step": total_seconds / self.steps, } def run_interviews(self):