Skip to content

Commit

Permalink
cleanups
Browse files Browse the repository at this point in the history
  • Loading branch information
Jason committed Mar 27, 2024
1 parent f71f0b4 commit 57c36a0
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 4 deletions.
10 changes: 8 additions & 2 deletions src/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@
load_dotenv()

class Data:
def setup_redis(self):
if os.environ.get("REDIS_URL"):
return redis.Redis.from_url(os.environ.get("REDIS_URL"))
else:
return None
def setup_database(self):
db_settings = {
"database_host": os.environ.get("DB_HOST"),
Expand All @@ -27,6 +32,7 @@ def setup_database(self):
"ssh_private_key": os.environ.get("SSH_PRIVATE_KEY")
}
cursor = None
conn = None
if ssh_host:
with SSHTunnelForwarder(
(ssh_settings["ssh_host"], ssh_settings["ssh_port"]),
Expand All @@ -43,7 +49,7 @@ def setup_database(self):
database=db_settings["database_name"],
)
cursor = conn.cursor()
else:
elif db_settings["database_host"]:
conn = psycopg2.connect(
user=db_settings["database_username"],
password=db_settings["database_password"],
Expand Down Expand Up @@ -89,6 +95,6 @@ def add_to_logs(self,obj):
self.cursor.execute(
"INSERT INTO timelines (sim_id, step, substep, step_type, data) VALUES (%s, %s, %s, %s, %s)",
(self.id, self.cur_step, self.current_substep, obj["step_type"], json.dumps(obj)),)
print("insert")
#print("insert")

self.current_substep += 1
2 changes: 1 addition & 1 deletion src/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def __init__(self, config={}):
self.perception_range = PERCEPTION_RANGE
self.allow_movement = ALLOW_MOVEMENT
self.model = MODEL
self.redis_connection = None
self.redis_connection = self.setup_redis()

self.replay = None
self.cursor = self.setup_database()
Expand Down
2 changes: 1 addition & 1 deletion src/reporting.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def all_env_vars(self):
"total_dead": sum(1 for agent in self.agents if agent.status == 'dead'),
"total_alive": sum(1 for agent in self.agents if agent.status != 'dead'),
"llm_call_counter": llm.call_counter,
"avg_llm_calls_per_step": llm.call_counter / self.steps
"avg_llm_calls_per_step": llm.call_counter / self.steps,
"avg_runtime_per_step": total_seconds / self.steps,
}
def run_interviews(self):
Expand Down

0 comments on commit 57c36a0

Please sign in to comment.