diff --git a/wntr/graphics/network.py b/wntr/graphics/network.py index 10a96618..42282d10 100644 --- a/wntr/graphics/network.py +++ b/wntr/graphics/network.py @@ -58,7 +58,7 @@ def plot_network_gis( node_size=20, node_range=None, node_alpha=1, node_cmap=None, node_labels=False, link_width=1, link_range=None, link_alpha=1, link_cmap=None, link_labels=False, add_colorbar=True, node_colorbar_label='Node', link_colorbar_label='Link', - directed=False, ax=None, filename=None): + directed=False, ax=None, show_plot=True, filename=None): """ Plot network graphic @@ -137,6 +137,9 @@ def plot_network_gis( ax: matplotlib axes object, optional Axes for plotting (None indicates that a new figure with a single axes will be used) + + show_plot: bool, optional + If True, show plot with plt.show() filename : str, optional Filename used to save the figure @@ -153,10 +156,22 @@ def plot_network_gis( ax.set_title(title) # set aspect setting - aspect = None + # aspect = None + aspect = "auto" + # aspect = "equal" + # initialize gis objects wn_gis = wn.to_gis() + link_gdf = pd.concat((wn_gis.pipes, wn_gis.pumps, wn_gis.valves)) + + node_gdf = pd.concat((wn_gis.junctions, wn_gis.tanks, wn_gis.reservoirs)) + + # missing keyword args + # these are used for elements that do not have a value for the link_attribute + missing_kwds = {"color": "black"} + + # colormap if link_cmap is None: link_cmap = plt.get_cmap('Spectral_r') @@ -171,36 +186,36 @@ def plot_network_gis( # prepare pipe plotting keywords - pipes_kwds = {} + link_kwds = {} if link_attribute is not None: - pipes_kwds["column"] = link_attribute - pipes_kwds["cmap"] = link_cmap + link_kwds["column"] = link_attribute + link_kwds["cmap"] = link_cmap if add_colorbar: - pipes_kwds["legend"] = True + link_kwds["legend"] = True else: - pipes_kwds["color"] = "black" - pipes_kwds["alpha"] = link_alpha + link_kwds["color"] = "black" + link_kwds["alpha"] = link_alpha - pipes_cbar_kwds = {} - pipes_cbar_kwds["shrink"] = 0.5 - pipes_cbar_kwds["pad"] = 0.0 - pipes_cbar_kwds["label"] = link_colorbar_label + link_cbar_kwds = {} + link_cbar_kwds["shrink"] = 0.5 + link_cbar_kwds["pad"] = 0.0 + link_cbar_kwds["label"] = link_colorbar_label # prepare junctin plotting keywords - junction_kwds = {} + node_kwds = {} if node_attribute is not None: - junction_kwds["column"] = node_attribute - junction_kwds["cmap"] = node_cmap + node_kwds["column"] = node_attribute + node_kwds["cmap"] = node_cmap if add_colorbar: - junction_kwds["legend"] = True + node_kwds["legend"] = True else: - junction_kwds["color"] = "black" - junction_kwds["alpha"] = node_alpha + node_kwds["color"] = "black" + node_kwds["alpha"] = node_alpha - junction_cbar_kwds = {} - junction_cbar_kwds["shrink"] = 0.5 - junction_cbar_kwds["pad"] = 0.0 - junction_cbar_kwds["label"] = node_colorbar_label + node_cbar_kwds = {} + node_cbar_kwds["shrink"] = 0.5 + node_cbar_kwds["pad"] = 0.0 + node_cbar_kwds["label"] = node_colorbar_label # TODO handle node/link labels @@ -208,48 +223,42 @@ def plot_network_gis( # plot junctions - wn_gis.junctions.plot( - ax=ax, aspect=aspect, markersize=node_size, zorder=1, - vmax=node_range[0], vmin=node_range[1],legend_kwds=junction_cbar_kwds, **junction_kwds) + # node_gdf.plot( + # ax=ax, aspect=aspect, markersize=node_size, zorder=1, + # vmax=node_range[0], vmin=node_range[1], legend_kwds=node_cbar_kwds, **node_kwds) - # plot tanks - wn_gis.tanks.plot(ax=ax, marker="P", aspect=aspect, zorder=1) + # # plot tanks + # wn_gis.tanks.plot(ax=ax, marker="P", aspect=aspect, zorder=1) - # plot reservoirs - wn_gis.reservoirs.plot(ax=ax, marker="s", aspect=aspect, zorder=1) + # # plot reservoirs + # wn_gis.reservoirs.plot(ax=ax, marker="s", aspect=aspect, zorder=1) # plot pipes - wn_gis.pipes.plot( + minx, miny, maxx, maxy = link_gdf.total_bounds + link_gdf.plot( ax=ax, aspect=aspect, zorder=0, linewidth=link_width, - vmax=link_range[0], vmin=link_range[1], legend_kwds=pipes_cbar_kwds, **pipes_kwds) - - # plot pumps - if len(wn_gis.pumps) >0: - wn_gis.pumps.plot(ax=ax, color="purple", aspect=aspect) - wn_gis.pumps["midpoint"] = wn_gis.pumps.geometry.interpolate(0.5, normalized=True) - wn_gis.pumps["angle"] = wn_gis.pumps.apply(lambda row: _get_angle(row.geometry), axis=1) - # valve_midpoints.plot(ax=ax, marker=">", aspect=aspect) - for idx , row in wn_gis.pumps.iterrows(): - x,y = row["midpoint"].x, row["midpoint"].y - # dx = math.cos(math.radians(row["angle"])) - # dy = math.sin(math.radians(row["angle"])) - angle = row["angle"] - ax.scatter(x,y, color="purple", s=100, marker=(3,0, angle-90)) - # ax.arrow(x,y, dx*0.1, dy*0.1, head_width=0.05, head_length=0.1, fc="blue", ec="blue") + vmax=link_range[0], vmin=link_range[1], missing_kwds=missing_kwds, legend_kwds=link_cbar_kwds, **link_kwds) + ax.set_xlim([minx, maxx]) + ax.set_ylim([miny, maxy]) + # # plot pumps + # if len(wn_gis.pumps) >0: + # wn_gis.pumps.plot(ax=ax, color="purple", aspect=aspect) + # wn_gis.pumps["midpoint"] = wn_gis.pumps.geometry.interpolate(0.5, normalized=True) + # wn_gis.pumps["angle"] = wn_gis.pumps.apply(lambda row: _get_angle(row.geometry), axis=1) + # for idx , row in wn_gis.pumps.iterrows(): + # x,y = row["midpoint"].x, row["midpoint"].y + # angle = row["angle"] + # ax.scatter(x,y, color="purple", s=100, marker=(3,0, angle-90)) # plot valves - if len(wn_gis.valves) >0: - wn_gis.valves.plot(ax=ax, color="green", aspect=aspect) - wn_gis.valves["midpoint"] = wn_gis.valves.geometry.interpolate(0.5, normalized=True) - wn_gis.valves["angle"] = wn_gis.valves.apply(lambda row: _get_angle(row.geometry), axis=1) - # valve_midpoints.plot(ax=ax, marker=">", aspect=aspect) - for idx , row in wn_gis.valves.iterrows(): - x,y = row["midpoint"].x, row["midpoint"].y - # dx = math.cos(math.radians(row["angle"])) - # dy = math.sin(math.radians(row["angle"])) - angle = row["angle"] - ax.scatter(x,y, color="green", s=100, marker=(3,0, angle-90)) - # ax.arrow(x,y, dx*0.1, dy*0.1, head_width=0.05, head_length=0.1, fc="blue", ec="blue") + # if len(wn_gis.valves) >0: + # # wn_gis.valves.plot(ax=ax, color="green", aspect=aspect) + # wn_gis.valves["midpoint"] = wn_gis.valves.geometry.interpolate(0.5, normalized=True) + # wn_gis.valves["angle"] = wn_gis.valves.apply(lambda row: _get_angle(row.geometry), axis=1) + # for idx , row in wn_gis.valves.iterrows(): + # x,y = row["midpoint"].x, row["midpoint"].y + # angle = row["angle"] + # ax.scatter(x,y, color="green", s=100, marker=(3,0, angle-90)) # annotation if node_labels: @@ -261,7 +270,7 @@ def plot_network_gis( for x, y, label in zip(midpoints.geometry.x, midpoints.geometry.y, wn_gis.pipes.index): ax.annotate(label, xy=(x, y))#, xytext=(3, 3),)# textcoords="offset points") - ax.axis('off') + # ax.axis('off') if filename: plt.savefig(filename)