diff --git a/python/cog/command/ast_openapi_schema.py b/python/cog/command/ast_openapi_schema.py index 646b933052..2b80c694cf 100644 --- a/python/cog/command/ast_openapi_schema.py +++ b/python/cog/command/ast_openapi_schema.py @@ -357,19 +357,11 @@ def get_annotation(node: "ast.AST | None") -> str: return node.value # e.g. arg: "Path" if isinstance(node, ast.Subscript): value = get_annotation(node.value) - if value == "Literal": - if sys.version_info < (3, 9): - elts = ( - node.slice.elts - if isinstance(node.slice, ast.Tuple) - else [node.slice] - ) - else: - elts = ( - node.slice.elts - if isinstance(node.slice, ast.Tuple) - else [node.slice] - ) + # Literal is unsupported in Python 3.8 and earlier + if value == "Literal" and sys.version_info > (3, 8): + elts = ( + node.slice.elts if isinstance(node.slice, ast.Tuple) else [node.slice] + ) return f"Literal[{','.join(repr(get_value(e)) for e in elts)}]" # ignore other Subscript (Optional[str]), BinOp (str | int), and stuff like that raise ValueError("Unexpected annotation type", type(node)) diff --git a/python/tests/server/test_http.py b/python/tests/server/test_http.py index 27a317ddfa..8903a2db2f 100644 --- a/python/tests/server/test_http.py +++ b/python/tests/server/test_http.py @@ -1,8 +1,10 @@ import base64 import io +import sys import time import unittest.mock as mock +import pytest import responses from PIL import Image from responses import matchers @@ -370,6 +372,9 @@ def test_train_openapi_specification(client): } +@pytest.mark.skipif( + sys.version_info < (3, 9), reason="Literal is unsupported in Python 3.8 and earlier" +) @uses_predictor("input_literal") def test_openapi_specification_with_literal(client, static_schema): resp = client.get("/openapi.json") diff --git a/python/tests/server/test_http_input.py b/python/tests/server/test_http_input.py index d04b7f557e..73b9867a20 100644 --- a/python/tests/server/test_http_input.py +++ b/python/tests/server/test_http_input.py @@ -1,7 +1,9 @@ import base64 import os +import sys import threading +import pytest import responses from cog import schema @@ -196,6 +198,9 @@ def test_choices_int(client): assert resp.status_code == 422 +@pytest.mark.skipif( + sys.version_info < (3, 9), reason="Literal is unsupported in Python 3.8 and earlier" +) @uses_predictor("input_literal") def test_literal_str(client): resp = client.post("/predictions", json={"input": {"text": "foo"}}) @@ -204,6 +209,9 @@ def test_literal_str(client): assert resp.status_code == 422 +@pytest.mark.skipif( + sys.version_info < (3, 9), reason="Literal is unsupported in Python 3.8 and earlier" +) @uses_predictor("input_literal_integer") def test_literal_int(client): resp = client.post("/predictions", json={"input": {"x": 1}})