diff --git a/fdtd/visualization.py b/fdtd/visualization.py index adffc42..6e939f7 100644 --- a/fdtd/visualization.py +++ b/fdtd/visualization.py @@ -37,11 +37,12 @@ def visualize( srccolor="C0", detcolor="C2", norm="linear", - show=False, # default False to allow animate to be true animate=False, # True to see frame by frame states of grid while running simulation index=None, # index for each frame of animation (visualize fn runs in a loop, loop variable is passed as index) save=False, # True to save frames (requires parameters index, folder) folder=None, # folder path to save frames + show=False, # default False to allow animate to be true + style=None, ): """visualize a projection of the grid and the optical energy inside the grid @@ -61,7 +62,10 @@ def visualize( index: index for each frame of animation (typically a loop variable is passed) save: save frames in a folder folder: path to folder to save frames + style: Matplotlib style sheet to use for plotting. e.g. "https://raw.githubusercontent.com/dracula/matplotlib/master/dracula.mplstyle". """ + if style is not None: + plt.style.use(style) if norm not in ("linear", "lin", "log"): raise ValueError("Color map normalization should be 'linear' or 'log'.") # imports (placed here to circumvent circular imports) @@ -116,7 +120,7 @@ def visualize( plt.plot([], lw=3, color=detcolor, label="Detectors") # Grid energy - grid_energy = bd.sum(grid.E ** 2 + grid.H ** 2, -1) + grid_energy = bd.sum(grid.E**2 + grid.H**2, -1) if x is not None: assert grid.Ny > 1 and grid.Nz > 1 xlabel, ylabel = "y", "z" @@ -309,7 +313,9 @@ def visualize( cmap_norm = None if norm == "log": cmap_norm = LogNorm(vmin=1e-4, vmax=grid_energy.max() + 1e-4) - plt.imshow(abs(bd.numpy(grid_energy)), cmap=cmap, interpolation="sinc", norm=cmap_norm) + plt.imshow( + abs(bd.numpy(grid_energy)), cmap=cmap, interpolation="sinc", norm=cmap_norm + ) # finalize the plot plt.ylabel(xlabel) @@ -327,8 +333,12 @@ def visualize( if show: plt.show() + return plt.gcf() # return figure for gradio support + -def dB_map_2D(block_det=None, choose_axis=2, interpolation="spline16"): +def dB_map_2D( + block_det=None, choose_axis=2, interpolation="spline16", show=True, style=None +): """ Displays detector readings from an 'fdtd.BlockDetector' in a decibel map spanning a 2D slice region inside the BlockDetector. Compatible with continuous sources (not pulse). @@ -338,6 +348,8 @@ def dB_map_2D(block_det=None, choose_axis=2, interpolation="spline16"): block_det (numpy array): 5 axes numpy array (timestep, row, column, height, {x, y, z} parameter) created by 'fdtd.BlockDetector'. (optional) choose_axis (int): Choose between {0, 1, 2} to display {x, y, z} data. Default 2 (-> z). (optional) interpolation (string): Preferred 'matplotlib.pyplot.imshow' interpolation. Default "spline16". + show (bool): automatically call plt.show at the end of the plotting function + style (string): Matplotlib style sheet to use for plotting. e.g. "https://raw.githubusercontent.com/dracula/matplotlib/master/dracula.mplstyle". """ if block_det is None: raise ValueError( @@ -349,7 +361,8 @@ def dB_map_2D(block_det=None, choose_axis=2, interpolation="spline16"): ) # TODO: convert all 2D slices (y-z, x-z plots) into x-y plot data structure - + if style is not None: + plt.style.use(style) plt.ioff() plt.close() a = [] # array to store wave intensities @@ -360,24 +373,28 @@ def dB_map_2D(block_det=None, choose_axis=2, interpolation="spline16"): a[i].append(max(temp) - min(temp)) peakVal, minVal = max(map(max, a)), min(map(min, a)) - #print( + # print( # "Peak at:", # [ # [[i, j] for j, y in enumerate(x) if y == peakVal] # for i, x in enumerate(a) # if peakVal in x # ], - #) + # ) a = 10 * log10([[y / minVal for y in x] for x in a]) plt.title("dB map of Electrical waves in detector region") plt.imshow(a, cmap="inferno", interpolation=interpolation) cbar = plt.colorbar() cbar.ax.set_ylabel("dB scale", rotation=270) - plt.show() + if show: + plt.show() + + return plt.gcf() -def plot_detection(detector_dict=None, specific_plot=None): + +def plot_detection(detector_dict=None, specific_plot=None, show=True, style=None): """ 1. Plots intensity readings on array of 'fdtd.LineDetector' as a function of timestep. 2. Plots time of arrival of pulse at different LineDetector in array. @@ -387,6 +404,8 @@ def plot_detection(detector_dict=None, specific_plot=None): detector_dict (dictionary): Dictionary of detector readings, as created by 'fdtd.Grid.save_data()'. (optional) specific_plot (string): Plot for a specific axis data. Choose from {"Ex", "Ey", "Ez", "Hx", "Hy", "Hz"}. """ + if style is not None: + plt.style.use(style) if detector_dict is None: raise Exception( "Function plotDetection() requires a dictionary of detector readings as 'detector_dict' parameter." @@ -464,7 +483,10 @@ def plot_detection(detector_dict=None, specific_plot=None): plt.xlabel("Time of arrival (time steps)") plt.legend() plt.suptitle("Time-of-arrival plot") - plt.show() + if show: + plt.show() + + return plt.gcf() #