Skip to content

Commit

Permalink
Check if directory exists
Browse files Browse the repository at this point in the history
  • Loading branch information
HokageM committed Dec 8, 2023
1 parent c8ec0c4 commit f12dfac
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 17 deletions.
4 changes: 2 additions & 2 deletions src/irlwpython/MaxEntropyDeepIRL.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch.optim as optim
import torch.nn as nn

from irlwpython.FigurePrinter import FigurePrinter
from irlwpython.OutputHandler import OutputHandler


class QNetwork(nn.Module):
Expand All @@ -17,7 +17,7 @@ def __init__(self, input_size, output_size):
self.relu2 = nn.ReLU()
self.output_layer = nn.Linear(32, output_size)

self.printer = FigurePrinter()
self.output_hand = OutputHandler()

def forward(self, state):
x = self.fc1(state)
Expand Down
16 changes: 8 additions & 8 deletions src/irlwpython/MaxEntropyDeepRL.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch.optim as optim
import torch.nn as nn

from irlwpython.FigurePrinter import FigurePrinter
from irlwpython.OutputHandler import OutputHandler


class QNetwork(nn.Module):
Expand All @@ -17,7 +17,7 @@ def __init__(self, input_size, output_size):
self.relu2 = nn.ReLU()
self.output_layer = nn.Linear(32, output_size)

self.printer = FigurePrinter()
self.output_hand = OutputHandler()

def forward(self, state):
x = self.fc1(state)
Expand All @@ -42,7 +42,7 @@ def __init__(self, target, state_dim, action_size, feature_matrix, one_feature,

self.gamma = gamma

self.printer = FigurePrinter()
self.output_hand = OutputHandler()

def select_action(self, state, epsilon):
"""
Expand Down Expand Up @@ -150,16 +150,16 @@ def train(self, n_states, episodes=30000, max_steps=200,
if (episode + 1) % 1000 == 0:
score_avg = np.mean(scores)
print('{} episode average score is {:.2f}'.format(episode, score_avg))
self.printer.save_plot_as_png(episode_arr, scores,
self.output_hand.save_plot_as_png(episode_arr, scores,
f"../learning_curves/maxent_{episodes}_{episode}_qnetwork_RL.png")
self.printer.save_heatmap_as_png(learner.reshape((20, 20)), f"../heatmap/learner_{episode}_deep_RL.png")
self.printer.save_heatmap_as_png(self.theta.reshape((20, 20)),
self.output_hand.save_heatmap_as_png(learner.reshape((20, 20)), f"../heatmap/learner_{episode}_deep_RL.png")
self.output_hand.save_heatmap_as_png(self.theta.reshape((20, 20)),
f"../heatmap/theta_{episode}_deep_RL.png")

torch.save(self.q_network.state_dict(), f"../results/maxent_{episodes}_{episode}_network_main.pth")

if episode == episodes - 1:
self.printer.save_plot_as_png(episode_arr, scores,
self.output_hand.save_plot_as_png(episode_arr, scores,
f"../learning_curves/maxentdeep_{episodes}_qdeep_RL.png")

torch.save(self.q_network.state_dict(), f"src/irlwpython/results/maxentdeep_{episodes}_q_network_RL.pth")
Expand Down Expand Up @@ -192,6 +192,6 @@ def test(self, model_path, epsilon=0.01, repeats=100):
if episode % 1 == 0:
print('{} episode score is {:.2f}'.format(episode, score))

self.printer.save_plot_as_png(episodes, scores,
self.output_hand.save_plot_as_png(episodes, scores,
"src/irlwpython/learning_curves"
"/test_maxentropydeep_best_model_RL_results.png")
12 changes: 6 additions & 6 deletions src/irlwpython/MaxEntropyIRL.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import numpy as np

from irlwpython.FigurePrinter import FigurePrinter
from irlwpython.OutputHandler import OutputHandler


class MaxEntropyIRL:
Expand All @@ -20,7 +20,7 @@ def __init__(self, target, feature_matrix, one_feature, q_table, q_learning_rate
self.gamma = gamma
self.n_states = n_states

self.printer = FigurePrinter()
self.output_hand = OutputHandler()

def get_feature_matrix(self):
"""
Expand Down Expand Up @@ -133,12 +133,12 @@ def train(self, theta_learning_rate, episode_count=30000):
if (episode + 1) % 1000 == 0:
score_avg = np.mean(scores)
print('{} episode score is {:.2f}'.format(episode, score_avg))
self.printer.save_plot_as_png(episodes, scores,
self.output_hand.save_plot_as_png(episodes, scores,
f"src/irlwpython/learning_curves/"
f"maxent_{episode_count}_{episode}_qtable.png")
self.printer.save_heatmap_as_png(learner.reshape((20, 20)),
self.output_hand.save_heatmap_as_png(learner.reshape((20, 20)),
f"src/irlwpython/heatmap/learner_{episode}_flat.png")
self.printer.save_heatmap_as_png(self.theta.reshape((20, 20)),
self.output_hand.save_heatmap_as_png(self.theta.reshape((20, 20)),
f"src/irlwpython/heatmap/theta_{episode}_flat.png")

np.save(f"src/irlwpython/results/maxent_{episode}_qtable", arr=self.q_table)
Expand Down Expand Up @@ -172,5 +172,5 @@ def test(self, repeats=100):
if episode % 1 == 0:
print('{} episode score is {:.2f}'.format(episode, score))

self.printer.save_plot_as_png(episodes, scores,
self.output_hand.save_plot_as_png(episodes, scores,
"src/irlwpython/learning_curves/test_maxentropy_flat.png")
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import matplotlib.pyplot as plt
import os


class FigurePrinter:
class OutputHandler:
def __int__(self):
pass

Expand All @@ -25,6 +26,11 @@ def save_heatmap_as_png(self, data, output_path, title=None, xlabel="Position",
if title:
plt.title(title)

target_dir = os.path.basename(output_path)
if not os.path.isdir(target_dir):
print(f"Creating directory {target_dir}")
os.mkdir(target_dir)

plt.savefig(output_path, format='png')
plt.close(fig)

Expand All @@ -48,5 +54,22 @@ def save_plot_as_png(self, x, y, output_path, title=None, xlabel="Episodes", yla
if title:
plt.title(title)

target_dir = os.path.basename(output_path)
if not os.path.isdir(target_dir):
print(f"Creating directory {target_dir}")
os.mkdir(target_dir)

plt.savefig(output_path, format='png')
plt.close(fig)

def save_network(self, network, output_path):
    """
    Make sure the directory that will contain ``output_path`` exists.

    The directory component of ``output_path`` is created, including any
    missing parent directories, when it is not already present.

    NOTE(review): this method only prepares the directory; ``network`` is
    not written to ``output_path`` here. Confirm whether persisting the
    network is meant to happen in this method.

    :param network: The network object associated with ``output_path``
        (currently unused).
    :param output_path: Path of the file the network is meant to be stored at.
    """
    # dirname yields the containing directory; basename would yield the file
    # name and create a directory named after the file in the working dir.
    target_dir = os.path.dirname(output_path)

    # A bare file name has no directory component, so there is nothing to create.
    if target_dir and not os.path.isdir(target_dir):
        print(f"Creating directory {target_dir}")
        # makedirs also creates missing parents; exist_ok tolerates the
        # directory appearing between the check and the call.
        os.makedirs(target_dir, exist_ok=True)

def save_qtable(self, qtable, output_path):
    """
    Make sure the directory that will contain ``output_path`` exists.

    The directory component of ``output_path`` is created, including any
    missing parent directories, when it is not already present.

    NOTE(review): this method only prepares the directory; ``qtable`` is
    not written to ``output_path`` here. Confirm whether persisting the
    table is meant to happen in this method.

    :param qtable: The Q-table associated with ``output_path``
        (currently unused).
    :param output_path: Path of the file the Q-table is meant to be stored at.
    """
    # dirname yields the containing directory; basename would yield the file
    # name and create a directory named after the file in the working dir.
    target_dir = os.path.dirname(output_path)

    # A bare file name has no directory component, so there is nothing to create.
    if target_dir and not os.path.isdir(target_dir):
        print(f"Creating directory {target_dir}")
        # makedirs also creates missing parents; exist_ok tolerates the
        # directory appearing between the check and the call.
        os.makedirs(target_dir, exist_ok=True)

0 comments on commit f12dfac

Please sign in to comment.