Merge pull request #3 from explosion/reset-stream
Add `--allow-reset` to recipes
koaning authored Dec 13, 2023
2 parents 9033eb8 + 4e9a022 commit 33d52f7
Showing 8 changed files with 351 additions and 59 deletions.
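For orientation, a hypothetical session with the new flag might be started as below. The recipe name matches the log output in this diff, but the dataset, source, and index names are placeholders, and the dataset-first argument order is assumed from Prodigy's usual conventions.

    prodigy textcat.lunr.manual my_dataset examples.jsonl index.gz.json --labels pos,neg --query "battery life" --allow-reset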
15 changes: 11 additions & 4 deletions .github/workflows/unit_tests.yml
@@ -21,7 +21,7 @@ jobs:
uses: actions/checkout@v3
with:
repository: explosion/prodigy
- ref: v1.14.0
+ ref: v1.14.11
path: ./prodigy
ssh-key: ${{ secrets.GHA_PRODIGY_READ }}

@@ -34,7 +34,8 @@ jobs:
run: |
pip install --upgrade pip
pip install -e .
- pip install ruff pytest
+ pip install ruff pytest playwright
+ playwright install
- name: Run help
if: always()
@@ -47,7 +48,13 @@ jobs:
shell: bash
run: python -m ruff prodigy_lunr tests

- - name: Run pytest
+ - name: Run pytest unit tests
if: always()
shell: bash
- run: python -m pytest tests
+ run: python -m pytest tests -m "not e2e" -vvv

+ - name: Run e2e tests
+ if: always()
+ shell: bash
+ run: python -m pytest tests -m "e2e" -vvv
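The -m "e2e" / -m "not e2e" split assumes an e2e marker is registered with pytest; the registration itself isn't shown in this diff. A minimal sketch of one, e.g. in the repository's existing setup.cfg, could be:

    [tool:pytest]
    markers =
        e2e: end-to-end tests that drive the annotation UI via playwright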

135 changes: 95 additions & 40 deletions prodigy_lunr/__init__.py
@@ -1,4 +1,4 @@
- from tempfile import NamedTemporaryFile
+ import spacy
from pathlib import Path
from typing import Optional

@@ -8,8 +8,8 @@
from prodigy.recipes.textcat import manual as textcat_manual
from prodigy.recipes.ner import manual as ner_manual
from prodigy.recipes.spans import manual as spans_manual
- from lunr import lunr
- from lunr.index import Index
+ from prodigy.util import log
+ from .util import SearchIndex, JS, CSS, HTML, stream_reset_calback


@recipe(
@@ -20,13 +20,12 @@
# fmt: on
)
def index(source: Path, index_path: Path):
"""Builds an HSNWLIB index on example text data."""
"""Builds a LUNR index on example text data."""
# Store sentences as a list, not perfect, but works.
documents = [{"idx": i, **ex} for i, ex in enumerate(srsly.read_jsonl(source))]
# Create the index
index = lunr(ref='idx', fields=('text',), documents=documents)
# Store it on disk
srsly.write_gzip_json(index_path, index.serialize(), indent=0)
log("RECIPE: Calling `lunr.text.index`")
index = SearchIndex(source, index_path=index_path)
index.build_index()
index.store_index(index_path)


@recipe(
@@ -35,28 +34,19 @@ def index(source: Path, index_path: Path):
source=("Path to text source that has been indexed", "positional", None, str),
index_path=("Path to index", "positional", None, Path),
out_path=("Path to write examples into", "positional", None, Path),
query=("ANN query to run", "option", "q", str),
query=("Query to run", "option", "q", str),
n=("Max number of results to return", "option", "n", int),
# fmt: on
)
def fetch(source: Path, index_path: Path, out_path: Path, query:str, n:int=200):
"""Fetch a relevant subset using a HNSWlib index."""
"""Fetch a relevant subset using a LUNR index."""
log("RECIPE: Calling `lunr.text.fetch`")
if not query:
raise ValueError("must pass query")

documents = [{"idx": i, **ex} for i, ex in enumerate(srsly.read_jsonl(source))]
index = Index.load(srsly.read_gzip_json(index_path))
results = index.search(query)[:n]

def to_prodigy_examples(results):
for res in results:
ex = documents[int(res['ref'])]
ex['meta'] = {
'score': res['score'], 'query': query
}
yield ex

srsly.write_jsonl(out_path, to_prodigy_examples(results))
index = SearchIndex(source, index_path=index_path)
new_examples = index.new_stream(query=query, n=n)
srsly.write_jsonl(out_path, new_examples)


@recipe(
@@ -66,8 +56,10 @@ def to_prodigy_examples(results):
examples=("Examples that have been indexed", "positional", None, str),
index_path=("Path to trained index", "positional", None, Path),
labels=("Comma seperated labels to use", "option", "l", str),
query=("ANN query to run", "option", "q", str),
query=("Query to run", "option", "q", str),
exclusive=("Labels are exclusive", "flag", "e", bool),
n=("Number of items to retreive via query", "option", "n", int),
allow_reset=("Allow the user to restart the query", "flag", "r", bool)
# fmt: on
)
def textcat_lunr_manual(
@@ -76,13 +68,30 @@ def textcat_lunr_manual(
index_path: Path,
labels:str,
query:str,
- exclusive:bool = False
+ exclusive:bool = False,
+ n:int = 200,
+ allow_reset: bool = False
):
"""Run textcat.manual using a query to populate the stream."""
- with NamedTemporaryFile(suffix=".jsonl") as tmpfile:
-     fetch(examples, index_path, out_path=tmpfile.name, query=query)
-     stream = list(srsly.read_jsonl(tmpfile.name))
- return textcat_manual(dataset, stream, label=labels.split(","), exclusive=exclusive)
+ log("RECIPE: Calling `textcat.lunr.manual`")
+ index = SearchIndex(source=examples, index_path=index_path)
+ stream = index.new_stream(query, n=n)
+ components = textcat_manual(dataset, stream, label=labels.split(","), exclusive=exclusive)
+
+ # Only update the components if the user should be able to reset the stream
+ if allow_reset:
+     blocks = [
+         {"view_id": components["view_id"]},
+         {"view_id": "html", "html_template": HTML}
+     ]
+     components["event_hooks"] = {
+         "stream-reset": stream_reset_calback(index, n=n)
+     }
+     components["view_id"] = "blocks"
+     components["config"]["javascript"] = JS
+     components["config"]["global_css"] = CSS
+     components["config"]["blocks"] = blocks
+ return components


@recipe(
@@ -93,8 +102,10 @@ def textcat_lunr_manual(
examples=("Examples that have been indexed", "positional", None, str),
index_path=("Path to trained index", "positional", None, Path),
labels=("Comma seperated labels to use", "option", "l", str),
query=("ANN query to run", "option", "q", str),
query=("Query to run", "option", "q", str),
patterns=("Path to match patterns file", "option", "pt", Path),
n=("Number of items to retreive via query", "option", "n", int),
allow_reset=("Allow the user to restart the query", "flag", "r", bool)
# fmt: on
)
def ner_lunr_manual(
@@ -105,12 +116,33 @@ def ner_lunr_manual(
labels:str,
query:str,
patterns: Optional[Path] = None,
+ n:int = 200,
+ allow_reset:bool = False,
):
"""Run ner.manual using a query to populate the stream."""
- with NamedTemporaryFile(suffix=".jsonl") as tmpfile:
-     fetch(examples, index_path, out_path=tmpfile.name, query=query)
-     stream = list(srsly.read_jsonl(tmpfile.name))
- return ner_manual(dataset, nlp, stream, label=labels, patterns=patterns)
+ log("RECIPE: Calling `ner.lunr.manual`")
+ if "blank" in nlp:
+     spacy_mod = spacy.blank(nlp.replace("blank:", ""))
+ else:
+     spacy_mod = spacy.load(nlp)
+ index = SearchIndex(source=examples, index_path=index_path)
+ stream = index.new_stream(query, n=n)
+
+ # Only update the components if the user should be able to reset the stream
+ components = ner_manual(dataset, spacy_mod, stream, label=labels.split(","), patterns=patterns)
+ if allow_reset:
+     blocks = [
+         {"view_id": components["view_id"]},
+         {"view_id": "html", "html_template": HTML}
+     ]
+     components["event_hooks"] = {
+         "stream-reset": stream_reset_calback(index, n=n)
+     }
+     components["view_id"] = "blocks"
+     components["config"]["javascript"] = JS
+     components["config"]["global_css"] = CSS
+     components["config"]["blocks"] = blocks
+ return components


@recipe(
@@ -121,8 +153,10 @@ def ner_lunr_manual(
examples=("Examples that have been indexed", "positional", None, str),
index_path=("Path to trained index", "positional", None, Path),
labels=("Comma seperated labels to use", "option", "l", str),
query=("ANN query to run", "option", "q", str),
query=("Query to run", "option", "q", str),
patterns=("Path to match patterns file", "option", "pt", Path),
n=("Number of items to retreive via query", "option", "n", int),
allow_reset=("Allow the user to restart the query", "flag", "r", bool)
# fmt: on
)
def spans_lunr_manual(
@@ -133,9 +167,30 @@ def spans_lunr_manual(
labels:str,
query:str,
patterns: Optional[Path] = None,
+ n:int = 200,
+ allow_reset: bool = False
):
"""Run spans.manual using a query to populate the stream."""
- with NamedTemporaryFile(suffix=".jsonl") as tmpfile:
-     fetch(examples, index_path, out_path=tmpfile.name, query=query)
-     stream = list(srsly.read_jsonl(tmpfile.name))
- return spans_manual(dataset, nlp, stream, label=labels, patterns=patterns)
+ log("RECIPE: Calling `spans.lunr.manual`")
+ if "blank" in nlp:
+     spacy_mod = spacy.blank(nlp.replace("blank:", ""))
+ else:
+     spacy_mod = spacy.load(nlp)
+ index = SearchIndex(source=examples, index_path=index_path)
+ stream = index.new_stream(query, n=n)
+
+ # Only update the components if the user should be able to reset the stream
+ components = spans_manual(dataset, spacy_mod, stream, label=labels.split(","), patterns=patterns)
+ if allow_reset:
+     blocks = [
+         {"view_id": components["view_id"]},
+         {"view_id": "html", "html_template": HTML}
+     ]
+     components["event_hooks"] = {
+         "stream-reset": stream_reset_calback(index, n=n)
+     }
+     components["view_id"] = "blocks"
+     components["config"]["javascript"] = JS
+     components["config"]["global_css"] = CSS
+     components["config"]["blocks"] = blocks
+ return components
140 changes: 140 additions & 0 deletions prodigy_lunr/util.py
@@ -0,0 +1,140 @@
import srsly
from pathlib import Path
from typing import List, Optional, Dict
import textwrap
from lunr import lunr
from lunr.index import Index
from prodigy.util import set_hashes
from prodigy.util import log
from prodigy.components.stream import Stream
from prodigy.components.stream import get_stream
from prodigy.core import Controller

HTML = """
<link
rel="stylesheet"
href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.1.2/css/all.min.css"
integrity="sha512-1sCRPdkRXhBV2PBLUdRb4tMg1w2YPf37qatUFeS7zlBy7jJI8Lf4VHwWfZZfpXtYSLy85pkm9GaYVYMfw5BC1A=="
crossorigin="anonymous"
referrerpolicy="no-referrer"
/>
<details>
<summary id="reset">Reset stream?</summary>
<div class="prodigy-content">
<label class="label" for="query">New query:</label>
<input class="prodigy-text-input text-input" type="text" id="query" name="query" value="">
<br><br>
<button id="refreshButton" onclick="refreshData()">
Refresh Stream
<i
id="loadingIcon"
class="fa-solid fa-spinner fa-spin"
style="display: none;"
></i>
</button>
</div>
</details>
"""

# We need to dedent to prevent a bunch of stray whitespace from appearing.
HTML = textwrap.dedent(HTML).replace("\n", "")

CSS = """
.inner-div{
border: 1px solid #ddd;
text-align: left;
border-radius: 4px;
}
.label{
top: -3px;
opacity: 0.75;
position: relative;
font-size: 12px;
font-weight: bold;
padding-left: 10px;
}
.text-input{
width: 100%;
border: 1px solid #cacaca;
border-radius: 5px;
padding: 10px;
font-size: 20px;
background: transparent;
font-family: "Lato", "Trebuchet MS", Roboto, Helvetica, Arial, sans-serif;
}
#reset{
font-size: 16px;
}
"""

JS = """
function refreshData() {
document.querySelector('#loadingIcon').style.display = 'inline-block'
event_data = {
query: document.getElementById("query").value
}
window.prodigy
.event('stream-reset', event_data)
.then(updated_example => {
console.log('Updating Current Example with new data:', updated_example)
window.prodigy.resetQueue();
window.prodigy.update(updated_example)
document.querySelector('#loadingIcon').style.display = 'none'
})
.catch(err => {
console.error('Error in Event Handler:', err)
})
}
"""

def add_hashes(examples):
for ex in examples:
yield set_hashes(ex)


class SearchIndex:
def __init__(self, source: Path, index_path: Optional[Path] = None):
log(f"INDEX: Using {index_path=} and source={str(source)}.")
stream = get_stream(source)
stream.apply(add_hashes)
# Storing this as a list isn't scalable, but it is fine for medium-sized datasets.
self.documents = [ex for ex in stream]
self.index_path = index_path
self.index = None
if self.index_path and self.index_path.exists():
self.index = Index.load(srsly.read_gzip_json(index_path))

def build_index(self) -> "SearchIndex":
# Store documents as a list; not perfect, but it works.
documents = [{"idx": i, 'text': ex['text']} for i, ex in enumerate(self.documents)]
# Create the index
self.index = lunr(ref='idx', fields=('text',), documents=documents)
return self

def store_index(self, path: Path):
srsly.write_gzip_json(str(path), self.index.serialize(), indent=0)
log(f"INDEX: Index file stored at {path}.")

def _to_prodigy_examples(self, examples: List[Dict], query:str):
for res in examples:
ex = self.documents[int(res['ref'])]
ex['meta'] = {
'score': res['score'], 'query': query, "index_ref": int(res['ref'])
}
yield set_hashes(ex)

def new_stream(self, query:str, n:int=100):
log(f"INDEX: Creating new stream of {n} examples using {query=}.")
results = self.index.search(query)[:n]
return self._to_prodigy_examples(results, query=query)


def stream_reset_calback(index_obj: SearchIndex, n:int=100):
def stream_reset(ctrl: Controller, *, query: str):
new_stream = Stream.from_iterable(index_obj.new_stream(query, n=n))
ctrl.reset_stream(new_stream, prepend_old_wrappers=True)
return next(ctrl.stream)
return stream_reset
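
Taken together, a minimal sketch of driving SearchIndex directly, outside a recipe, could look like the following; the file names and query are placeholders, and Prodigy must be installed for the imports to resolve.

    from pathlib import Path

    from prodigy_lunr.util import SearchIndex

    # Build an index over a JSONL source with a "text" field and persist it.
    index = SearchIndex(Path("examples.jsonl"), index_path=Path("index.gz.json"))
    index.build_index()
    index.store_index(Path("index.gz.json"))

    # Query it to get a stream of hashed Prodigy examples with match scores.
    for example in index.new_stream("battery life", n=5):
        print(example["meta"]["score"], example["text"])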
2 changes: 1 addition & 1 deletion setup.cfg
@@ -1,5 +1,5 @@
[metadata]
- version = 0.1.1
+ version = 0.2.0
description = Recipes for finding interesting subsets using Lunr
url = https://github.com/explosion/prodigy-lunr
author = Explosion
Binary file added tests/datasets/index.gz.json