Skip to content

Commit

Permalink
Merge branch 'main' into austin/low-memory-lo
Browse files Browse the repository at this point in the history
  • Loading branch information
awalker4 authored Dec 8, 2023
2 parents 6b78984 + f306062 commit 7b0ae88
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 2 deletions.
8 changes: 6 additions & 2 deletions prepline_general/api/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ def pipeline_api(
elements = partition(**partition_kwargs)

except OSError as e:
if (
if isinstance(e.args[0], str) and (
"chipper-fast-fine-tuning is not a local folder" in e.args[0]
or "ved-fine-tuning is not a local folder" in e.args[0]
):
Expand All @@ -442,7 +442,11 @@ def pipeline_api(
detail="The Chipper model is not available for download. It can be accessed via the official hosted API.",
)

raise e
# OSError isn't caught by our top level handler, so convert it here
raise HTTPException(
status_code=500,
detail=str(e),
)
except ValueError as e:
if "Invalid file" in e.args[0]:
raise HTTPException(
Expand Down
45 changes: 45 additions & 0 deletions test_general/api/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -828,3 +828,48 @@ def test_invalid_strategy_for_image_file():
)
assert resp.status_code == 400
assert "fast strategy is not available for image files" in resp.text


@pytest.mark.parametrize(
("exception", "status_code", "message"),
[
(
OSError("chipper-fast-fine-tuning is not a local folder"),
400,
"The Chipper model is not available for download. "
"It can be accessed via the official hosted API.",
),
(
OSError("ved-fine-tuning is not a local folder"),
400,
"The Chipper model is not available for download. "
"It can be accessed via the official hosted API.",
),
(OSError(1, "An error happened"), 500, "[Errno 1] An error happened"),
],
)
def test_chipper_not_available_errors(monkeypatch, mocker, exception, status_code, message):
"""
Assert that we return the right error if Chipper is not downloaded.
OSError can have an int as the first arg, do not blow up if that happens.
"""

mock_partition = Mock(side_effect=exception)

monkeypatch.setattr(
general,
"partition",
mock_partition,
)

client = TestClient(app)
test_file = Path("sample-docs") / "layout-parser-paper-fast.pdf"

resp = client.post(
MAIN_API_ROUTE,
files=[("files", (str(test_file), open(test_file, "rb"), "application/pdf"))],
data={"strategy": "hi_res", "hi_res_model_name": "chipper"},
)

assert resp.status_code == status_code
assert resp.json().get("detail") == message

0 comments on commit 7b0ae88

Please sign in to comment.