Skip to content

Commit

Permalink
Add support for Literal to Python <=3.9
Browse files Browse the repository at this point in the history
  • Loading branch information
mattt committed Sep 16, 2024
1 parent c681a21 commit 17162fe
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 17 deletions.
11 changes: 10 additions & 1 deletion python/cog/code_xforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,16 @@
import types
from typing import Optional, Set, Union

COG_IMPORT_MODULES = {"cog", "typing", "sys", "os", "functools", "pydantic", "numpy"}
COG_IMPORT_MODULES = {
"cog",
"typing",
"typing_extensions",
"sys",
"os",
"functools",
"pydantic",
"numpy",
}


def load_module_from_string(
Expand Down
19 changes: 14 additions & 5 deletions python/cog/command/ast_openapi_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,8 @@ def get_value(node: ast.AST) -> "AstVal":
return [get_value(e) for e in node.elts]
if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
return -typing.cast(typing.Union[int, float, complex], get_value(node.operand))
if isinstance(node, ast.Index):
return node.value # type: ignore
raise ValueError("Unexpected node type", type(node))


Expand All @@ -357,11 +359,18 @@ def get_annotation(node: "ast.AST | None") -> str:
return node.value # e.g. arg: "Path"
if isinstance(node, ast.Subscript):
value = get_annotation(node.value)
# Literal is unsupported in Python 3.8 and earlier
if value == "Literal" and sys.version_info > (3, 8):
elts = (
node.slice.elts if isinstance(node.slice, ast.Tuple) else [node.slice]
)
if value == "Literal":
if sys.version_info < (3, 9):
if isinstance(node.slice, ast.Index):
elts = [node.slice.value]
else:
elts = node.slice.elts
else:
elts = (
node.slice.elts
if isinstance(node.slice, ast.Tuple)
else [node.slice]
)
return f"Literal[{','.join(repr(get_value(e)) for e in elts)}]"
# ignore other Subscript (Optional[str]), BinOp (str | int), and stuff like that
raise ValueError("Unexpected annotation type", type(node))
Expand Down
5 changes: 2 additions & 3 deletions python/cog/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,11 @@
cast,
)

from typing_extensions import Literal # Python 3.7

try:
from typing import get_args, get_origin
from typing import Literal, get_args, get_origin
except ImportError: # Python < 3.8
from typing_compat import get_args, get_origin # type: ignore
from typing_extensions import Literal

from unittest.mock import patch

Expand Down
5 changes: 2 additions & 3 deletions python/tests/server/test_http.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import base64
import io
import sys
import time
import unittest.mock as mock

Expand Down Expand Up @@ -375,8 +374,8 @@ def test_train_openapi_specification(client):


@pytest.mark.skipif(
not (PYDANTIC_V2 and sys.version_info >= (3, 9)),
reason="Literal is used for enums in Pydantic v2 and Python 3.9+",
not PYDANTIC_V2,
reason="Literal is used for enums only in Pydantic v2",
)
@uses_predictor("input_literal")
def test_openapi_specification_with_literal(client, static_schema):
Expand Down
9 changes: 4 additions & 5 deletions python/tests/server/test_http_input.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import base64
import os
import sys
import threading
import time

Expand Down Expand Up @@ -224,8 +223,8 @@ def test_choices_int(client):


@pytest.mark.skipif(
not (PYDANTIC_V2 and sys.version_info >= (3, 9)),
reason="Literal is used for enums in Pydantic v2 and Python 3.9+",
not PYDANTIC_V2,
reason="Literal is used for enums only in Pydantic v2",
)
@uses_predictor("input_literal")
def test_literal_str(client):
Expand All @@ -236,8 +235,8 @@ def test_literal_str(client):


@pytest.mark.skipif(
not (PYDANTIC_V2 and sys.version_info >= (3, 9)),
reason="Literal is used for enums in Pydantic v2 and Python 3.9+",
not PYDANTIC_V2,
reason="Literal is used for enums only in Pydantic v2",
)
@uses_predictor("input_literal_integer")
def test_literal_int(client):
Expand Down

0 comments on commit 17162fe

Please sign in to comment.