Skip to content

Commit

Permalink
Bump version from 0.0.1 to 0.0.2
Browse files Browse the repository at this point in the history
  • Loading branch information
snehankekre committed Feb 1, 2022
1 parent 05f4188 commit 3b3ddd8
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 10 deletions.
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

setuptools.setup(
name="streamlit-shap",
version="0.0.1",
version="0.0.2",
author="Snehan Kekre",
author_email="[email protected]",
description="Streamlit component for SHAP",
Expand All @@ -13,5 +13,5 @@
include_package_data=True,
classifiers=[],
python_requires=">=3.6",
install_requires=["streamlit >= 1.4.0", "shap >= 0.4.0"],
install_requires=["streamlit >= 1.0.0", "shap >= 0.4.0"],
)
73 changes: 65 additions & 8 deletions streamlit_shap/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,28 +8,85 @@
import base64
from io import BytesIO


def st_shap(plot, height=None):
"""Takes a SHAP plot as input, and returns a streamlit.delta_generator.DeltaGenerator as output.
It is recommended to set the height and omit the width
parameter to have the plot fit to the window.
Parameters
----------
plot : None or matplotlib.figure.Figure or SHAP plot object
The SHAP plot object.
height: int or None
The height of the plot in pixels.
Returns
-------
streamlit.delta_generator.DeltaGenerator
A SHAP plot as a streamlit.delta_generator.DeltaGenerator object.
"""

# Plots such as waterfall and bar have no return value
# They create a new figure and call plt.show()
if plot is None:
fig = plt.gcf()
ax = plt.gca()

# Test whether there is currently a Figure on the pyplot figure stack
# A Figure exists if the shap plot called plt.show()
if plt.get_fignums():
fig = plt.gcf()
ax = plt.gca()
plt.tight_layout()

# Save it to a temporary buffer
buf = BytesIO()
fig.savefig(buf, format="png")
fig_width, fig_height = fig.get_size_inches() * fig.dpi

# Embed the result in the HTML output
data = base64.b64encode(buf.getbuffer()).decode("ascii")
html_str = f"<img src='data:image/png;base64,{data}'/>"

# Enable pyplot to properly clean up the memory
plt.cla()
plt.close(fig)

fig = components.html(html_str, height=fig_height, width=fig_width)
else:
fig = components.html(
"<p>[Error] No plot to display. Received object of type &lt;class 'NoneType'&gt;.</p>"
)

# SHAP plots return a matplotlib.figure.Figure object when passed show=False as an argument
elif isinstance(plot, Figure):
fig = plot
plt.tight_layout()

# Save it to a temporary buffer
buf = BytesIO()
fig.savefig(buf, format="png")
fig_width, fig_height = fig.get_size_inches()*fig.dpi
fig_width, fig_height = fig.get_size_inches() * fig.dpi

# Embed the result in the HTML output
data = base64.b64encode(buf.getbuffer()).decode("ascii")
html_str = f"<img src='data:image/png;base64,{data}'/>"

# Enable pyplot to properly clean up the memory
plt.cla()
plt.close(fig)

fig = components.html(html_str, height=fig_height, width=fig_width)

elif isinstance(plot, Figure):
fig = plot

else:
# SHAP plots containing JS/HTML have one or more of the following callable attributes
elif hasattr(plot, "html") or hasattr(plot, "data") or hasattr(plot, "matplotlib"):

shap_html = f"<head>{shap.getjs()}</head><body>{plot.html()}</body>"
fig = components.html(shap_html, height=height)


else:
fig = components.html(
"<p>[Error] No plot to display. Unable to understand input.</p>"
)

return fig

0 comments on commit 3b3ddd8

Please sign in to comment.