Skip to content

Commit

Permalink
Always return plot. Decide to show with flag. Default to style .
Browse files Browse the repository at this point in the history
  • Loading branch information
flaport committed Dec 17, 2023
1 parent f7f1db4 commit 8561dbb
Showing 1 changed file with 30 additions and 24 deletions.
54 changes: 30 additions & 24 deletions fdtd/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,12 @@ def visualize(
srccolor="C0",
detcolor="C2",
norm="linear",
show=False, # default False to allow animate to be true
animate=False, # True to see frame by frame states of grid while running simulation
index=None, # index for each frame of animation (visualize fn runs in a loop, loop variable is passed as index)
save=False, # True to save frames (requires parameters index, folder)
folder=None, # folder path to save frames
ret_plot=False, # True to return figure for gradio support
pltstyle="https://raw.githubusercontent.com/dracula/matplotlib/master/dracula.mplstyle",

show=False, # default False to allow animate to be true
style=None,
):
"""visualize a projection of the grid and the optical energy inside the grid
Expand All @@ -64,11 +62,10 @@ def visualize(
index: index for each frame of animation (typically a loop variable is passed)
save: save frames in a folder
folder: path to folder to save frames
ret_plot: return figure instead of showing it
pltstyle: Matplotlib style sheet to use for plotting. Default "https://raw.githubusercontent.com/dracula/matplotlib/master/dracula.mplstyle".
style: Matplotlib style sheet to use for plotting. e.g. "https://raw.githubusercontent.com/dracula/matplotlib/master/dracula.mplstyle".
"""
if ret_plot:
plt.style.use(pltstyle)
if style is not None:
plt.style.use(style)
if norm not in ("linear", "lin", "log"):
raise ValueError("Color map normalization should be 'linear' or 'log'.")
# imports (placed here to circumvent circular imports)
Expand Down Expand Up @@ -123,7 +120,7 @@ def visualize(
plt.plot([], lw=3, color=detcolor, label="Detectors")

# Grid energy
grid_energy = bd.sum(grid.E ** 2 + grid.H ** 2, -1)
grid_energy = bd.sum(grid.E**2 + grid.H**2, -1)
if x is not None:
assert grid.Ny > 1 and grid.Nz > 1
xlabel, ylabel = "y", "z"
Expand Down Expand Up @@ -316,7 +313,9 @@ def visualize(
cmap_norm = None
if norm == "log":
cmap_norm = LogNorm(vmin=1e-4, vmax=grid_energy.max() + 1e-4)
plt.imshow(abs(bd.numpy(grid_energy)), cmap=cmap, interpolation="sinc", norm=cmap_norm)
plt.imshow(
abs(bd.numpy(grid_energy)), cmap=cmap, interpolation="sinc", norm=cmap_norm
)

# finalize the plot
plt.ylabel(xlabel)
Expand All @@ -334,11 +333,12 @@ def visualize(
if show:
plt.show()

if ret_plot:
return plt.gcf() # return figure for gradio support
return plt.gcf() # return figure for gradio support


def dB_map_2D(block_det=None, choose_axis=2, interpolation="spline16", ret_plot=False, pltstyle="https://raw.githubusercontent.com/dracula/matplotlib/master/dracula.mplstyle"):
def dB_map_2D(
block_det=None, choose_axis=2, interpolation="spline16", show=True, style=None
):
"""
Displays detector readings from an 'fdtd.BlockDetector' in a decibel map spanning a 2D slice region inside the BlockDetector.
Compatible with continuous sources (not pulse).
Expand All @@ -348,8 +348,8 @@ def dB_map_2D(block_det=None, choose_axis=2, interpolation="spline16", ret_plot=
block_det (numpy array): 5 axes numpy array (timestep, row, column, height, {x, y, z} parameter) created by 'fdtd.BlockDetector'.
(optional) choose_axis (int): Choose between {0, 1, 2} to display {x, y, z} data. Default 2 (-> z).
(optional) interpolation (string): Preferred 'matplotlib.pyplot.imshow' interpolation. Default "spline16".
ret_plot (bool): True to return figure instead of showing it.
pltstyle (string): Matplotlib style sheet to use for plotting. Default "https://raw.githubusercontent.com/dracula/matplotlib/master/dracula.mplstyle".
show (bool): automatically call plt.show at the end of the plotting function
style (string): Matplotlib style sheet to use for plotting. e.g. "https://raw.githubusercontent.com/dracula/matplotlib/master/dracula.mplstyle".
"""
if block_det is None:
raise ValueError(
Expand All @@ -361,8 +361,8 @@ def dB_map_2D(block_det=None, choose_axis=2, interpolation="spline16", ret_plot=
)

# TODO: convert all 2D slices (y-z, x-z plots) into x-y plot data structure
if ret_plot:
plt.style.use(pltstyle)
if style is not None:
plt.style.use(style)
plt.ioff()
plt.close()
a = [] # array to store wave intensities
Expand All @@ -373,27 +373,28 @@ def dB_map_2D(block_det=None, choose_axis=2, interpolation="spline16", ret_plot=
a[i].append(max(temp) - min(temp))

peakVal, minVal = max(map(max, a)), min(map(min, a))
#print(
# print(
# "Peak at:",
# [
# [[i, j] for j, y in enumerate(x) if y == peakVal]
# for i, x in enumerate(a)
# if peakVal in x
# ],
#)
# )
a = 10 * log10([[y / minVal for y in x] for x in a])

plt.title("dB map of Electrical waves in detector region")
plt.imshow(a, cmap="inferno", interpolation=interpolation)
cbar = plt.colorbar()
cbar.ax.set_ylabel("dB scale", rotation=270)
if ret_plot:
return plt.gcf()
else:

if show:
plt.show()

return plt.gcf()


def plot_detection(detector_dict=None, specific_plot=None):
def plot_detection(detector_dict=None, specific_plot=None, show=True, style=None):
"""
1. Plots intensity readings on array of 'fdtd.LineDetector' as a function of timestep.
2. Plots time of arrival of pulse at different LineDetector in array.
Expand All @@ -403,6 +404,8 @@ def plot_detection(detector_dict=None, specific_plot=None):
detector_dict (dictionary): Dictionary of detector readings, as created by 'fdtd.Grid.save_data()'.
(optional) specific_plot (string): Plot for a specific axis data. Choose from {"Ex", "Ey", "Ez", "Hx", "Hy", "Hz"}.
"""
if style is not None:
plt.style.use(style)
if detector_dict is None:
raise Exception(
"Function plotDetection() requires a dictionary of detector readings as 'detector_dict' parameter."
Expand Down Expand Up @@ -480,7 +483,10 @@ def plot_detection(detector_dict=None, specific_plot=None):
plt.xlabel("Time of arrival (time steps)")
plt.legend()
plt.suptitle("Time-of-arrival plot")
plt.show()
if show:
plt.show()

return plt.gcf()


#
Expand Down

0 comments on commit 8561dbb

Please sign in to comment.