Skip to content

Commit

Permalink
Merge pull request #2502 from Trusted-AI/dev_1.18.2
Browse files Browse the repository at this point in the history
Update to ART 1.18.2
  • Loading branch information
beat-buesser authored Oct 2, 2024
2 parents 1207d0a + 8738a5a commit ba01c04
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

import logging
import math
from packaging.version import parse
from typing import Any, TYPE_CHECKING

import numpy as np
Expand Down Expand Up @@ -121,8 +122,8 @@ def __init__(
import torch
import torchvision

torch_version = list(map(int, torch.__version__.lower().split("+", maxsplit=1)[0].split(".")))
torchvision_version = list(map(int, torchvision.__version__.lower().split("+", maxsplit=1)[0].split(".")))
torch_version = list(parse(torch.__version__.lower()).release)
torchvision_version = list(parse(torchvision.__version__.lower()).release)
assert (
torch_version[0] >= 1 and torch_version[1] >= 7 or (torch_version[0] >= 2)
), "AdversarialPatchPyTorch requires torch>=1.7.0"
Expand Down
3 changes: 2 additions & 1 deletion art/attacks/evasion/pixel_threshold.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

import logging
from itertools import product
from packaging.version import parse
from typing import TYPE_CHECKING

import numpy as np
Expand All @@ -42,7 +43,7 @@
import scipy
from scipy._lib._util import check_random_state

scipy_version = list(map(int, scipy.__version__.lower().split(".")))
scipy_version = list(parse(scipy.__version__.lower()).release)
if scipy_version[1] >= 8:
from scipy.optimize._optimize import _status_message
else:
Expand Down
5 changes: 3 additions & 2 deletions art/estimators/object_detection/pytorch_object_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from __future__ import annotations

import logging
from packaging.version import parse
from typing import Any, TYPE_CHECKING

import numpy as np
Expand Down Expand Up @@ -96,8 +97,8 @@ def __init__(
import torch
import torchvision

torch_version = list(map(int, torch.__version__.lower().split("+", maxsplit=1)[0].split(".")))
torchvision_version = list(map(int, torchvision.__version__.lower().split("+", maxsplit=1)[0].split(".")))
torch_version = list(parse(torch.__version__.lower()).release)
torchvision_version = list(parse(torchvision.__version__.lower()).release)
assert not (torch_version[0] == 1 and (torch_version[1] == 8 or torch_version[1] == 9)), (
"PyTorchObjectDetector does not support torch==1.8 and torch==1.9 because of "
"https://github.com/pytorch/vision/issues/4153. Support will return for torch==1.10."
Expand Down

0 comments on commit ba01c04

Please sign in to comment.