Skip to content

Commit

Permalink
Add support for Pydantic 2 (#1858)
Browse files Browse the repository at this point in the history
* Update pydantic and fastapi requirements

Signed-off-by: Mattt Zmuda <[email protected]>

* Define PYDANTIC_V2

Signed-off-by: Mattt Zmuda <[email protected]>

* Replace use of ErrorWrapper removed in pydantic v2

Signed-off-by: Mattt Zmuda <[email protected]>

* Update predictor model config and output types

Signed-off-by: Mattt Zmuda <[email protected]>

* Pass string literal for extra value

* Set explicit None values for optional fields

Signed-off-by: Mattt Zmuda <[email protected]>

* Give output field an explicit default of None

Signed-off-by: Mattt Zmuda <[email protected]>

* Import types from typing instead of accessing through namespace

Signed-off-by: Mattt Zmuda <[email protected]>

* Set extra in model config

Set default description for enumerations

Signed-off-by: Mattt Zmuda <[email protected]>

* Actually, don't allow additional properties

Signed-off-by: Mattt Zmuda <[email protected]>

* Fix generated schema for webhook event filter

Signed-off-by: Mattt Zmuda <[email protected]>

* Fixup generated OpenAPI schema

Signed-off-by: Mattt Zmuda <[email protected]>

* Update expectations for test_gt_lt response JSON

Signed-off-by: Mattt Zmuda <[email protected]>

* Flatten allOf for component schemas

Signed-off-by: Mattt Zmuda <[email protected]>

* Fix validation of weights_url

Signed-off-by: Mattt Zmuda <[email protected]>

* Use __get_pydantic_json_schema__ instead of __modify_schema__

* Fix test_make_encodeable_ignores_files

Signed-off-by: Mattt Zmuda <[email protected]>

* Fix expectations for test_bad_int_input

Signed-off-by: Mattt Zmuda <[email protected]>

* Skip numpy tests for pydantic v2

Signed-off-by: Mattt Zmuda <[email protected]>

* Fix test_make_encodeable_encodes_pydantic_models for pydantic v2

Signed-off-by: Mattt Zmuda <[email protected]>

* Fix use of __get_pydantic_core_schema__

* Fix warnings for pydantic.Field construction

* Fix use of deprecated dict instead of model_dump

* Limit application of OpenAPI processing for default enum description

* Passing test_openapi_specification_with_int_choices

* Fix declarations of __get_pydantic_json_schema__ and __get_pydantic_core_schema__

* Update validate_input_type to handle annotated cog.Types

* Update __get_pydantic_core_schema__ to start with is_instance_schema

* Conditionalize use of deprecated dict method

* Avoid re-encoding of PredictionResponse in _predict

* Remove no-op if TYPE_CHECKING

* Unwrap SerializationIterators after model_dump

* Also unwrap SerializationIterator in make_encodeable

* Replace use of choices=[] with typing.Literal

* Make output schema optional so validation doesn't fail on errors

* Restore choices tests and fix validation of choices in v2

* Refactor, move, and rename unwrap_pydantic_serialization_iterators to helpers module

* Update AST OpenAPI generator to support Literal

* Fix ast.Str and ast.Num deprecation warnings

* Disable Literal support for Python 3.8

* Apply suggestions from code review

Co-authored-by: Nick Stenning <[email protected]>
Signed-off-by: Mattt <[email protected]>

* Document unwrap_pydantic_serialization_iterators

* Test on Pydantic v1 and v2

Adds a Pydantic axis to the test matrix

* Fix import of Literal for Python 3.7

* Reenable and fix test_numpy

* Customize openapi_schema for Pydantic v1 and v2 to get consistent spec version of 3.1.0

Extract openapi manipulation functions to helpers

* Conditionalize call to unwrap_pydantic_serialization_iterators

* Correctly unset title for webhook_events_filter

* Skip literal tests for Python<3.9 and Pydantic<2

* Remove pyupgrade checks for ruff

* Fix pyright warning in useragent

* Manually remove webhook_events_filter title from OpenAPI schema

* Add support for Literal to Python <=3.9

* Run typecheck with Pydantic v1 and Pydantic v2

* Replace use of deprecated dict method with model_dump

* Attempt to resolve test failures related to OpenAPI generation

* Fix pydantic{1,2} typecheck

Copy the code to {envtmpdir} for pydantic1, since there's no way
to override pyright configuration temporarily.

* Fix type errors on pydantic1

* Install pydantic<2 in the cog wheel installation step.

Installing pydantic>2 won't easily downgrade when the user packages
require it, but upgrading works fine.

---------

Signed-off-by: Mattt Zmuda <[email protected]>
Signed-off-by: Mattt <[email protected]>
Signed-off-by: Morgan Fainberg <[email protected]>
Co-authored-by: Yorick van Pelt <[email protected]>
Co-authored-by: Nick Stenning <[email protected]>
Co-authored-by: Morgan Fainberg <[email protected]>
  • Loading branch information
4 people authored Oct 9, 2024
1 parent 135599f commit a6219fa
Show file tree
Hide file tree
Showing 24 changed files with 854 additions and 231 deletions.
9 changes: 7 additions & 2 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,18 @@ jobs:
run: make test-go

test-python:
name: "Test Python ${{ matrix.python-version }}"
name: "Test Python ${{ matrix.python-version }} + Pydantic v${{ matrix.pydantic }}"
needs: build-python
runs-on: ubuntu-latest-8-cores
strategy:
fail-fast: false
matrix:
pydantic: ["1", "2"]
python-version: ${{ fromJson(needs.build-python.outputs.python-versions) }}
exclude:
# Pydantic 2 is not supported on Python 3.7
- pydantic: "2"
python-version: "3.7"
steps:
- name: Download pre-built packages
uses: actions/download-artifact@v4
Expand All @@ -111,7 +116,7 @@ jobs:
- name: Remove src to ensure tests run against wheel
run: rm -rf python/cog
- name: Test
run: python -Im tox run --installpkg "$COG_WHEEL" -e ${{ env.TOX_PYTHON }}-tests
run: python -Im tox run --installpkg "$COG_WHEEL" -e ${{ env.TOX_PYTHON }}-pydantic${{ matrix.pydantic }}-tests

# cannot run this on mac due to licensing issues: https://github.com/actions/virtual-environments/issues/2150
test-integration:
Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ check-fmt:
.PHONY: lint
lint: $(COG_EMBEDDED_WHEEL) $(COG_WHEEL) check-fmt vet
$(GO) run github.com/golangci/golangci-lint/cmd/golangci-lint run ./...
$(TOX) run --installpkg $(COG_WHEEL) -e lint,typecheck
$(TOX) run --installpkg $(COG_WHEEL) -e lint,typecheck-pydantic2

.PHONY: run-docs-server
run-docs-server:
Expand Down
4 changes: 3 additions & 1 deletion pkg/dockerfile/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,9 @@ func (g *Generator) installCog() (string, error) {
if err != nil {
return "", err
}
pipInstallLine := fmt.Sprintf("RUN --mount=type=cache,target=/root/.cache/pip pip install -t /dep %s", containerPath)
// Install pydantic<2 for now, installing pydantic>2 wouldn't allow a downgrade later,
// but upgrading works fine
pipInstallLine := fmt.Sprintf("RUN --mount=type=cache,target=/root/.cache/pip pip install --no-cache-dir -t /dep %s 'pydantic<2'", containerPath)
if g.strip {
pipInstallLine += " && " + StripDebugSymbolsCommand
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/dockerfile/generator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func testInstallCog(relativeTmpDir string, stripped bool) string {
}
return fmt.Sprintf(`COPY %s/%s /tmp/%s
ENV CFLAGS="-O3 -funroll-loops -fno-strict-aliasing -flto -S"
RUN --mount=type=cache,target=/root/.cache/pip pip install -t /dep /tmp/%s%s
RUN --mount=type=cache,target=/root/.cache/pip pip install --no-cache-dir -t /dep /tmp/%s 'pydantic<2'%s
ENV CFLAGS=`, relativeTmpDir, wheel, wheel, wheel, strippedCall)
}

Expand Down
17 changes: 6 additions & 11 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ requires-python = ">=3.7"
dependencies = [
# intentionally loose. perhaps these should be vendored to not collide with user code?
"attrs>=20.1,<24",
"fastapi>=0.75.2,<0.99.0",
"pydantic>=1.9,<2",
"fastapi>=0.75.2,<0.111.0",
"pydantic>=1.9,<3",
"PyYAML",
"requests>=2,<3",
"structlog>=20,<25",
Expand All @@ -36,12 +36,7 @@ dependencies = [
dynamic = ["version"]

[project.optional-dependencies]
dev = [
"build",
"setuptools_scm",
"tox",
"tox-uv",
]
dev = ["build", "setuptools_scm", "tox", "tox-uv"]

tests = [
"httpx",
Expand Down Expand Up @@ -76,6 +71,9 @@ reportUnneesssaryContains = "warning"
reportMissingTypeArgument = "error"
reportUnusedExpression = "warning"

[tool.pyright.defineConstant]
PYDANTIC_V2 = true

[tool.setuptools]
include-package-data = false

Expand Down Expand Up @@ -105,7 +103,6 @@ lint.select = [
"F", # Pyflakes
"I", # isort
"W", # pycodestyle warning
"UP", # pyupgrade
"S", # flake8-bandit
"B", # flake8-bugbear
"ANN", # flake8-annotations
Expand All @@ -121,8 +118,6 @@ lint.ignore = [
"ANN101", # Missing type annotation for self in method
"ANN102", # Missing type annotation for cls in classmethod
"ANN401", # Dynamically typed expressions are disallowed
# recently ruff added checks for pyupgrade, which is bad for back-compat
"UP037", # Remove quotes from type annotation
]
extend-exclude = [
"python/tests/server/fixtures/*",
Expand Down
11 changes: 10 additions & 1 deletion python/cog/code_xforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,16 @@
import types
from typing import Optional, Set, Union

COG_IMPORT_MODULES = {"cog", "typing", "sys", "os", "functools", "pydantic", "numpy"}
COG_IMPORT_MODULES = {
"cog",
"typing",
"typing_extensions",
"sys",
"os",
"functools",
"pydantic",
"numpy",
}


def load_module_from_string(
Expand Down
55 changes: 45 additions & 10 deletions python/cog/command/ast_openapi_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,15 +336,18 @@ def get_value(node: ast.AST) -> "AstVal":
"""Return the value of constant or list of constants"""
if isinstance(node, ast.Constant):
return node.value
# for python3.7, were deprecated for Constant
if isinstance(node, (ast.Str, ast.Bytes)):
return node.s
if isinstance(node, ast.Num):
return node.n
# DeprecationWarning: ast.Str | ast.Num is deprecated and will be removed in Python 3.14; use ast.Constant instead
if sys.version_info < (3, 8):
if isinstance(node, (ast.Str, ast.Bytes)):
return node.s
if isinstance(node, ast.Num):
return node.n
if isinstance(node, (ast.List, ast.Tuple)):
return [get_value(e) for e in node.elts]
if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
return -typing.cast(typing.Union[int, float, complex], get_value(node.operand))
if isinstance(node, ast.Index):
return node.value # type: ignore
raise ValueError("Unexpected node type", type(node))


Expand All @@ -354,8 +357,22 @@ def get_annotation(node: "ast.AST | None") -> str:
return node.id
if isinstance(node, ast.Constant):
return node.value # e.g. arg: "Path"
# ignore Subscript (Optional[str]), BinOp (str | int), and stuff like that
# except we may need to care about list/List[str]
if isinstance(node, ast.Subscript):
value = get_annotation(node.value)
if value == "Literal":
if sys.version_info < (3, 9):
if isinstance(node.slice, ast.Index):
elts = [node.slice.value]
else:
elts = node.slice.elts
else:
elts = (
node.slice.elts
if isinstance(node.slice, ast.Tuple)
else [node.slice]
)
return f"Literal[{','.join(repr(get_value(e)) for e in elts)}]"
# ignore other Subscript (Optional[str]), BinOp (str | int), and stuff like that
raise ValueError("Unexpected annotation type", type(node))


Expand Down Expand Up @@ -519,17 +536,31 @@ def extract_info(code: str) -> "JSONDict": # pylint: disable=too-many-branches,
msg = "unknown argument for Input"
raise ValueError(msg)
kws[kw.arg] = to_serializable(get_value(kw.value))
elif isinstance(default, (ast.Constant, ast.List, ast.Tuple, ast.Str, ast.Num)):
elif isinstance(default, (ast.Constant, ast.List, ast.Tuple)) or (
# DeprecationWarning: ast.Str | ast.Num is deprecated and will be removed in Python 3.14; use ast.Constant instead
sys.version_info < (3, 8) and isinstance(default, (ast.Str, ast.Num))
):
kws = {"default": to_serializable(get_value(default))} # could be None
elif default == ...: # no default
kws = {}
else:
raise ValueError("Unexpected default value", default)
input: JSONDict = {"x-order": len(properties)}
# need to handle other types?
arg_type = OPENAPI_TYPES.get(get_annotation(arg.annotation), "string")
if get_annotation(arg.annotation) in ("Path", "File"):

annotation = get_annotation(arg.annotation)
arg_type = OPENAPI_TYPES.get(annotation, "string")
if annotation in ("Path", "File"):
input["format"] = "uri"
elif annotation.startswith("Literal["):
input["enum"] = list(
ast.literal_eval(annotation[7:]) # Safely eval the literal values
)
arg_type = (
OPENAPI_TYPES.get(type(input["enum"][0]).__name__, "string")
if input["enum"]
else "string"
)
for attr in KEPT_ATTRS:
if attr in kws:
input[attr] = kws[attr]
Expand All @@ -544,6 +575,10 @@ def extract_info(code: str) -> "JSONDict": # pylint: disable=too-many-branches,
"type": arg_type,
"description": "An enumeration.",
}
elif "enum" in input:
input["title"] = arg.arg.replace("_", " ").title()
input["type"] = arg_type
input["description"] = "An enumeration."
else:
input["title"] = arg.arg.replace("_", " ").title()
input["type"] = arg_type
Expand Down
7 changes: 5 additions & 2 deletions python/cog/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
except ImportError:
np = None

from .types import Path
from .types import PYDANTIC_V2, Path


def make_encodeable(obj: Any) -> Any: # pylint: disable=too-many-return-statements
Expand All @@ -26,7 +26,10 @@ def make_encodeable(obj: Any) -> Any: # pylint: disable=too-many-return-stateme
"""

if isinstance(obj, BaseModel):
return make_encodeable(obj.dict(exclude_unset=True))
if PYDANTIC_V2:
return make_encodeable(obj.model_dump(exclude_unset=True))
else:
return make_encodeable(obj.dict())
if isinstance(obj, dict):
return {key: make_encodeable(value) for key, value in obj.items()}
if isinstance(obj, (list, set, frozenset, GeneratorType, tuple)):
Expand Down
Loading

0 comments on commit a6219fa

Please sign in to comment.