Skip to content

Commit

Permalink
adding plots
Browse files Browse the repository at this point in the history
  • Loading branch information
workstation1 gpu committed May 13, 2022
1 parent c4acb45 commit cc8f334
Show file tree
Hide file tree
Showing 13 changed files with 555 additions and 420 deletions.
56 changes: 42 additions & 14 deletions cmvae/plot_cmvae_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,55 @@
import matplotlib.pyplot as plt
import pandas as pd


def plot_cmvae_error_progress(epochs, errors, Type='Training'):
def plot_cmvae_error_progress(data):
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
ax.plot(epochs, errors[0], color='tab:blue')
ax.plot(epochs, errors[1], color='tab:orange')
ax.plot(epochs, errors[2], color='tab:red')
ax.set_title('cmvae ' + Type + ' losses')
plt.legend(["Img recon", "Gate recon", "kl"])
ax.plot(data[0], data[1], color='tab:blue')
ax.plot(data[0], data[4], color='tab:red')
ax.set_title('image reconstruction loss')
plt.legend(["train losses", "test losses"])
plt.xlabel("epoch")
plt.ylabel("img_rec_loss")
plt.savefig("../figs/cmvae_img_rec_losses.png")

fig2 = plt.figure()
ax = fig2.add_subplot(1, 1, 1)
zipped = zip(errors[0], errors[1], errors[2])
Sum = [x + y + z for (x, y, z) in zipped]
ax.plot(epochs, Sum, color='tab:blue')
ax.set_title('CMVAE Total ' + Type + ' loss')
ax.plot(data[0], data[2], color='tab:blue')
ax.plot(data[0], data[5], color='tab:red')
ax.set_title('gate reconstruction loss')
plt.legend(["train losses", "test losses"])
plt.xlabel("epoch")
plt.ylabel("gate_rec_loss")
plt.savefig("../figs/cmvae_gate_rec_losses.png")

fig3 = plt.figure()
ax = fig3.add_subplot(1, 1, 1)
ax.plot(data[0], data[3], color='tab:blue')
ax.plot(data[0], data[6], color='tab:red')
ax.set_title('K-L loss')
plt.legend(["train losses", "test losses"])
plt.xlabel("epoch")
plt.ylabel("kl_loss")
plt.savefig("../figs/cmvae_kl_losses.png")

fig4 = plt.figure()
ax = fig4.add_subplot(1, 1, 1)
zipped1 = zip(data[1], data[2], data[3])
Sum1 = [x + y + z for (x, y, z) in zipped1]
zipped2 = zip(data[4], data[5], data[6])
Sum2 = [x + y + z for (x, y, z) in zipped2]
ax.plot(data[0], Sum1, color='tab:blue')
ax.plot(data[0], Sum2, color='tab:red')
ax.set_title('CMVAE Total Loss')
plt.legend(["train losses", "test losses"])
plt.xlabel("epoch")
plt.ylabel("total_loss")
plt.savefig("../figs/cmvae_total_losses.png")

plt.show()


df = pd.read_csv('training_cmvae_losses.csv', delimiter=',').T
# User list comprehension to create a list of lists from Dataframe rows
data = [list(row) for row in df.values]
plot_cmvae_error_progress(data[0], data[1:4])
plot_cmvae_error_progress(data[0], data[4:7], Type="Testing")
plot_cmvae_error_progress(data)

Binary file removed cmvae_test_losses.png
Binary file not shown.
Binary file removed cmvae_total_test_loss.png
Binary file not shown.
Binary file removed cmvae_total_training_loss.png
Binary file not shown.
Binary file removed cmvae_training_losses.png
Binary file not shown.
Binary file added figs/bc_losses.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added figs/cmvae_gate_rec_losses.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added figs/cmvae_img_rec_losses.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added figs/cmvae_kl_losses.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added figs/cmvae_total_losses.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading

0 comments on commit cc8f334

Please sign in to comment.