Skip to content

Commit

Permalink
Disable Literal support for Python 3.8
Browse files Browse the repository at this point in the history
  • Loading branch information
mattt committed Aug 12, 2024
1 parent d7cc128 commit bb064f8
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 13 deletions.
18 changes: 5 additions & 13 deletions python/cog/command/ast_openapi_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,19 +357,11 @@ def get_annotation(node: "ast.AST | None") -> str:
return node.value # e.g. arg: "Path"
if isinstance(node, ast.Subscript):
value = get_annotation(node.value)
if value == "Literal":
if sys.version_info < (3, 9):
elts = (
node.slice.elts
if isinstance(node.slice, ast.Tuple)
else [node.slice]
)
else:
elts = (
node.slice.elts
if isinstance(node.slice, ast.Tuple)
else [node.slice]
)
# Literal is unsupported in Python 3.8 and earlier
if value == "Literal" and sys.version_info > (3, 8):
elts = (
node.slice.elts if isinstance(node.slice, ast.Tuple) else [node.slice]
)
return f"Literal[{','.join(repr(get_value(e)) for e in elts)}]"
# ignore other Subscript (Optional[str]), BinOp (str | int), and stuff like that
raise ValueError("Unexpected annotation type", type(node))
Expand Down
5 changes: 5 additions & 0 deletions python/tests/server/test_http.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import base64
import io
import sys
import time
import unittest.mock as mock

import pytest
import responses
from PIL import Image
from responses import matchers
Expand Down Expand Up @@ -370,6 +372,9 @@ def test_train_openapi_specification(client):
}


@pytest.mark.skipif(
sys.version_info < (3, 9), reason="Literal is unsupported in Python 3.8 and earlier"
)
@uses_predictor("input_literal")
def test_openapi_specification_with_literal(client, static_schema):
resp = client.get("/openapi.json")
Expand Down
8 changes: 8 additions & 0 deletions python/tests/server/test_http_input.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import base64
import os
import sys
import threading

import pytest
import responses

from cog import schema
Expand Down Expand Up @@ -196,6 +198,9 @@ def test_choices_int(client):
assert resp.status_code == 422


@pytest.mark.skipif(
sys.version_info < (3, 9), reason="Literal is unsupported in Python 3.8 and earlier"
)
@uses_predictor("input_literal")
def test_literal_str(client):
resp = client.post("/predictions", json={"input": {"text": "foo"}})
Expand All @@ -204,6 +209,9 @@ def test_literal_str(client):
assert resp.status_code == 422


@pytest.mark.skipif(
sys.version_info < (3, 9), reason="Literal is unsupported in Python 3.8 and earlier"
)
@uses_predictor("input_literal_integer")
def test_literal_int(client):
resp = client.post("/predictions", json={"input": {"x": 1}})
Expand Down

0 comments on commit bb064f8

Please sign in to comment.