From cbaebb6058b2e932ae1c7ce38ef538748bdfd06b Mon Sep 17 00:00:00 2001 From: kbonney Date: Fri, 3 Nov 2023 11:15:02 -0400 Subject: [PATCH] reverting to returning fig rather than ax to be consistent with the rest of the repo --- pvops/tests/test_text.py | 4 ++-- pvops/text/visualize.py | 11 ++++------- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/pvops/tests/test_text.py b/pvops/tests/test_text.py index 1fc39ec..a451bea 100644 --- a/pvops/tests/test_text.py +++ b/pvops/tests/test_text.py @@ -278,7 +278,7 @@ def test_visualize_attribute_connectivity(): om_col_dict = {"attribute1_col": "Attr1", "attribute2_col": "Attr2"} - ax, G = visualize.visualize_attribute_connectivity( + fig, G = visualize.visualize_attribute_connectivity( df, om_col_dict, figsize=(10, 8), @@ -289,7 +289,7 @@ def test_visualize_attribute_connectivity(): }, ) - assert isinstance(ax, matplotlib.axes.Axes) + assert isinstance(fig, matplotlib.pyplot.Figure) assert list(G.edges()) == [("A", "X"), ("B", "X"), ("C", "Y"), ("C", "Z")] matplotlib.pyplot.close() diff --git a/pvops/text/visualize.py b/pvops/text/visualize.py index a224057..117d4ec 100644 --- a/pvops/text/visualize.py +++ b/pvops/text/visualize.py @@ -28,7 +28,6 @@ def visualize_attribute_connectivity( attribute_colors=["lightgreen", "cornflowerblue"], edge_width_scalar=10, graph_aargs={}, - ax=None, ): """Visualize a knowledge graph which shows the frequency of combinations between attributes ``ATTRIBUTE1_COL`` and ``ATTRIBUTE2_COL`` @@ -64,17 +63,15 @@ def visualize_attribute_connectivity( - font_weight='bold' - node_size=19000 - font_size=35 - ax : axis - axis to draw on, defaults to None and will create a new figure and axis in this case. Returns ------- Matplotlib axis, networkx graph """ - if ax is None: # create a new figure - fig = plt.figure(facecolor='w', edgecolor='k') - ax = plt.gca() + # initialize figure + fig = plt.figure(figsize=figsize,facecolor='w', edgecolor='k') + ax = plt.gca() # attribute column names ATTRIBUTE1_COL = om_col_dict["attribute1_col"] @@ -136,7 +133,7 @@ def visualize_attribute_connectivity( plt.show(block=False) - return ax, G + return fig, G def visualize_attribute_timeseries(