From 3b3ddd86d17732699e34fb46a67510e2c7282b6d Mon Sep 17 00:00:00 2001 From: snehankekre Date: Tue, 1 Feb 2022 16:21:36 +0530 Subject: [PATCH] Bump version from 0.0.1 to 0.0.2 --- setup.py | 4 +-- streamlit_shap/__init__.py | 73 +++++++++++++++++++++++++++++++++----- 2 files changed, 67 insertions(+), 10 deletions(-) diff --git a/setup.py b/setup.py index 671eede..0b811a8 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setuptools.setup( name="streamlit-shap", - version="0.0.1", + version="0.0.2", author="Snehan Kekre", author_email="snehan@streamlit.io", description="Streamlit component for SHAP", @@ -13,5 +13,5 @@ include_package_data=True, classifiers=[], python_requires=">=3.6", - install_requires=["streamlit >= 1.4.0", "shap >= 0.4.0"], + install_requires=["streamlit >= 1.0.0", "shap >= 0.4.0"], ) \ No newline at end of file diff --git a/streamlit_shap/__init__.py b/streamlit_shap/__init__.py index 1da560a..4f73412 100644 --- a/streamlit_shap/__init__.py +++ b/streamlit_shap/__init__.py @@ -8,28 +8,85 @@ import base64 from io import BytesIO + def st_shap(plot, height=None): + """Takes a SHAP plot as input, and returns a streamlit.delta_generator.DeltaGenerator as output. + + It is recommended to set the height and omit the width + parameter to have the plot fit to the window. + + Parameters + ---------- + plot : None or matplotlib.figure.Figure or SHAP plot object + The SHAP plot object. + height: int or None + The height of the plot in pixels. + Returns + ------- + streamlit.delta_generator.DeltaGenerator + A SHAP plot as a streamlit.delta_generator.DeltaGenerator object. + """ + + # Plots such as waterfall and bar have no return value + # They create a new figure and call plt.show() if plot is None: - fig = plt.gcf() - ax = plt.gca() + # Test whether there is currently a Figure on the pyplot figure stack + # A Figure exists if the shap plot called plt.show() + if plt.get_fignums(): + fig = plt.gcf() + ax = plt.gca() + plt.tight_layout() + + # Save it to a temporary buffer + buf = BytesIO() + fig.savefig(buf, format="png") + fig_width, fig_height = fig.get_size_inches() * fig.dpi + + # Embed the result in the HTML output + data = base64.b64encode(buf.getbuffer()).decode("ascii") + html_str = f"" + + # Enable pyplot to properly clean up the memory + plt.cla() + plt.close(fig) + + fig = components.html(html_str, height=fig_height, width=fig_width) + else: + fig = components.html( + "

[Error] No plot to display. Received object of type <class 'NoneType'>.

" + ) + + # SHAP plots return a matplotlib.figure.Figure object when passed show=False as an argument + elif isinstance(plot, Figure): + fig = plot + plt.tight_layout() + + # Save it to a temporary buffer buf = BytesIO() fig.savefig(buf, format="png") - fig_width, fig_height = fig.get_size_inches()*fig.dpi + fig_width, fig_height = fig.get_size_inches() * fig.dpi + # Embed the result in the HTML output data = base64.b64encode(buf.getbuffer()).decode("ascii") html_str = f"" + + # Enable pyplot to properly clean up the memory plt.cla() plt.close(fig) fig = components.html(html_str, height=fig_height, width=fig_width) - elif isinstance(plot, Figure): - fig = plot - - else: + # SHAP plots containing JS/HTML have one or more of the following callable attributes + elif hasattr(plot, "html") or hasattr(plot, "data") or hasattr(plot, "matplotlib"): + shap_html = f"{shap.getjs()}{plot.html()}" fig = components.html(shap_html, height=height) - + + else: + fig = components.html( + "

[Error] No plot to display. Unable to understand input.

" + ) + return fig