Skip to content

Commit

Permalink
Better colors for NER
Browse files Browse the repository at this point in the history
  • Loading branch information
Thilina Rajapakse committed Aug 9, 2020
1 parent db9ab1e commit 3d790ae
Show file tree
Hide file tree
Showing 7 changed files with 58 additions and 24 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -1044,7 +1044,9 @@ Model checkpoint is now saved for all epochs again.

- This CHANGELOG file to hopefully serve as an evolving example of a standardized open source project CHANGELOG.

[0.46.5]: https://github.com/ThilinaRajapakse/simpletransformers/compare/2cc77f7...HEAD
[0.47.0]: https://github.com/ThilinaRajapakse/simpletransformers/compare/d405b4a...HEAD

[0.46.5]: https://github.com/ThilinaRajapakse/simpletransformers/compare/2cc77f7...d405b4a

[0.46.3]: https://github.com/ThilinaRajapakse/simpletransformers/compare/7f37cb7...2cc77f7

Expand Down
9 changes: 0 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,6 @@ Supports

**New documentation is now live at [simpletransformers.ai](https://simpletransformers.ai/)**

Currently added:

- Text classification
- NER
- Question answering
- Language model training
- Language model fine-tuning
- Training language models from scratch

Any feedback will be immensely helpful in improving the documentation! If you have any feedback, please leave a comment in the [issue](https://github.com/ThilinaRajapakse/simpletransformers/issues/342) I've opened for this.


Expand Down
12 changes: 10 additions & 2 deletions simpletransformers/streamlit/classification_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import numpy as np
from scipy.special import softmax

from simpletransformers.streamlit.streamlit_utils import get
from simpletransformers.classification import ClassificationModel, MultiLabelClassificationModel
from simpletransformers.streamlit.streamlit_utils import get, simple_transformers_model


def get_states(model, session_state=None):
Expand All @@ -27,6 +28,13 @@ def get_states(model, session_state=None):
return session_state, model


@st.cache(hash_funcs={ClassificationModel: simple_transformers_model, MultiLabelClassificationModel: simple_transformers_model})
def get_prediction(model, input_text):
prediction, raw_values = model.predict([input_text])

return prediction, raw_values


def classification_viewer(model, model_class):
st.subheader("Enter text: ")
input_text = st.text_area("")
Expand Down Expand Up @@ -73,7 +81,7 @@ def classification_viewer(model, model_class):
)

if input_text:
prediction, raw_values = model.predict([input_text])
prediction, raw_values = get_prediction(model, input_text)
raw_values = [list(np.squeeze(raw_values))]

if model.args.sliding_window and isinstance(raw_values[0][0], np.ndarray):
Expand Down
17 changes: 12 additions & 5 deletions simpletransformers/streamlit/ner_view.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import streamlit as st
import pandas as pd

from simpletransformers.streamlit.streamlit_utils import get
from simpletransformers.ner import NERModel
from simpletransformers.streamlit.streamlit_utils import get, simple_transformers_model, get_color


ENTITY_WRAPPER = (
"""<mark style="background: {}; border-radius: 0.25rem; padding: 0.25rem; display: inline-block">{} {}</mark>"""
"""<mark style="background: rgba{}; font-weight: 450; border-radius: 0.5rem; margin: 0.1em; padding: 0.25rem; display: inline-block">{} {}</mark>"""
)
ENTITY_LABEL_WRAPPER = """<span style="background: #fff; font-size: 0.56em; font-weight: bold; padding: 0.3em 0.3em; vertical-align: middle; margin: 0 0 0.15rem 0.5rem; line-height: 1; display: inline-block">{}</span>"""

Expand All @@ -17,6 +18,13 @@ def format_word(word, entity, entity_checkboxes, entity_color_map):
return word


@st.cache(hash_funcs={NERModel: simple_transformers_model})
def get_prediction(model, input_text):
predictions, _ = model.predict([input_text])

return predictions


def ner_viewer(model):
session_state = get(
max_seq_length=model.args.max_seq_length,
Expand All @@ -27,7 +35,7 @@ def ner_viewer(model):

st.sidebar.subheader("Entities")
entity_checkboxes = {entity: st.sidebar.checkbox(entity, value=True) for entity in entity_list}
entity_color_map = {entity: "#a6e22d" for entity in entity_list}
entity_color_map = {entity: get_color(i) for i, entity in enumerate(entity_list)}

st.sidebar.subheader("Parameters")
model.args.max_seq_length = st.sidebar.slider(
Expand All @@ -37,8 +45,7 @@ def ner_viewer(model):
st.subheader("Enter text: ")
input_text = st.text_area("")

predictions, _ = model.predict([input_text])
prediction = predictions[0]
prediction = get_prediction(model, input_text)[0]

to_write = " ".join([format_word(word, entity, entity_checkboxes, entity_color_map) for pred in prediction for word, entity in pred.items()])

Expand Down
21 changes: 16 additions & 5 deletions simpletransformers/streamlit/qa_view.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import streamlit as st
import pandas as pd

from simpletransformers.streamlit.streamlit_utils import get
from simpletransformers.question_answering import QuestionAnsweringModel
from simpletransformers.streamlit.streamlit_utils import get, simple_transformers_model


QA_ANSWER_WRAPPER = """{} <span style="overflow-x: auto; border: 1px solid #e6e9ef; border-radius: 0.25rem; padding: 0.25rem; background: #a6e22d">{}</span> {}"""
Expand All @@ -25,6 +26,15 @@ def get_states(model, session_state=None):
return session_state, model


@st.cache(hash_funcs={QuestionAnsweringModel: simple_transformers_model})
def get_prediction(model, context_text, question_text):
to_predict = [{"context": context_text, "qas": [{"id": 0, "question": question_text}]}]

answers, probabilities = model.predict(to_predict)

return answers, probabilities


def qa_viewer(model):
st.sidebar.subheader("Parameters")
try:
Expand Down Expand Up @@ -58,17 +68,18 @@ def qa_viewer(model):
question_text = st.text_area("", key="question")

if context_text and question_text:
to_predict = [{"context": context_text, "qas": [{"id": 0, "question": question_text}]}]

answers, probabilities = model.predict(to_predict)
answers, probabilities = get_prediction(model, context_text, question_text)

st.subheader(f"Predictions")
answers = answers[0]["answer"]

context_pieces = context_text.split(answers[0])

if answers[0] != "empty":
st.write(QA_ANSWER_WRAPPER.format(context_pieces[0], answers[0], context_pieces[-1]), unsafe_allow_html=True)
if len(context_pieces) == 2:
st.write(QA_ANSWER_WRAPPER.format(context_pieces[0], answers[0], context_pieces[-1]), unsafe_allow_html=True)
else:
st.write(QA_ANSWER_WRAPPER.format(context_pieces[0], answers[0], answers[0].join(context_pieces[1:])), unsafe_allow_html=True)
else:
st.write(QA_EMPTY_ANSWER_WRAPPER.format("", answers[0], ""), unsafe_allow_html=True)

Expand Down
4 changes: 2 additions & 2 deletions simpletransformers/streamlit/simple_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,8 @@ def streamlit_runner(
if manual_model:
st.sidebar.subheader("Model Details")
fill_info = st.empty()
fill_info.markdown("Please fill the Model details on the sidebar.")
model_class = st.sidebar.selectbox("Model Class", list(model_class_map.keys()))
fill_info.markdown("Please fill the Model details on the sidebar and click `Load Model`.")
model_class = st.sidebar.selectbox("Simple Transformers task", list(model_class_map.keys()))
model_type = st.sidebar.text_input("Model type (e.g. bert, roberta, xlnet)")
model_name = st.sidebar.text_input("Model name (e.g. bert-base-cased, roberta-base)")

Expand Down
15 changes: 15 additions & 0 deletions simpletransformers/streamlit/streamlit_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import numpy as np

import streamlit as st
import collections
import functools
Expand Down Expand Up @@ -158,3 +160,16 @@ def get(**kwargs):
this_session._custom_session_state = SessionState(**kwargs)

return this_session._custom_session_state


def simple_transformers_model(model):
return (type(model).__name__, model.args)


def get_color(i):
# Colors taken from Sasha Trubetskoy's list of colors - https://sashamaps.net/docs/tools/20-colors/
colors = [(60, 180, 75, 0.4), (255, 225, 25, 0.4), (0, 130, 200, 0.4), (245, 130, 48, 0.4), (145, 30, 180, 0.4), (70, 240, 240, 0.4), (240, 50, 230, 0.4), (210, 245, 60, 0.4), (250, 190, 212, 0.4), (0, 128, 128, 0.4), (220, 190, 255, 0.4), (170, 110, 40, 0.4), (255, 250, 200, 0.4), (128, 0, 0, 0.4), (170, 255, 195, 0.4), (128, 128, 0, 0.4), (255, 215, 180, 0.4), (0, 0, 128, 0.4), (128, 128, 128, 0.4), (255, 255, 255, 0.4), (0, 0, 0, 0.4), (230, 25, 75, 0.4)]
try:
return str(colors[i])
except IndexError:
return str(tuple(np.random.rand(3,).tolist() + 0.7))

0 comments on commit 3d790ae

Please sign in to comment.