diff --git a/tests/test_hf.py b/tests/test_hf.py index c45949a..8a4da78 100644 --- a/tests/test_hf.py +++ b/tests/test_hf.py @@ -11,10 +11,16 @@ hub_id = "ultralyticsplus/yolov8s" +# for ultralytics < 8.0.44 def test_load_from_hub(): path = download_from_hub(hub_id) +# for ultralytics >= 8.0.44 +def test_load_from_hub_yolo_8_0_44(): + model = YOLO("keremberke/yolov8n-table-extraction") + + def test_yolo_from_hub(): model = YOLO(hub_id) diff --git a/ultralyticsplus/__init__.py b/ultralyticsplus/__init__.py index 9af3d16..817dc16 100644 --- a/ultralyticsplus/__init__.py +++ b/ultralyticsplus/__init__.py @@ -1,4 +1,4 @@ from .hf_utils import download_from_hub, push_to_hfhub from .ultralytics_utils import YOLO, postprocess_classify_output, render_result -__version__ = "0.0.29" +__version__ = "0.1.0" diff --git a/ultralyticsplus/ultralytics_utils.py b/ultralyticsplus/ultralytics_utils.py index fa3def5..acaef38 100644 --- a/ultralyticsplus/ultralytics_utils.py +++ b/ultralyticsplus/ultralytics_utils.py @@ -83,12 +83,28 @@ def _load_from_hf_hub(self, weights: str, hf_token=None): self.task = self.model.args["task"] self.overrides = self.model.args self._reset_ckpt_args(self.overrides) - ( - self.ModelClass, - self.TrainerClass, - self.ValidatorClass, - self.PredictorClass, - ) = self._assign_ops_from_task() + + # for loading a model with ultralytics <8.0.44 + if hasattr(self, "_assign_ops_from_task"): + ( + self.ModelClass, + self.TrainerClass, + self.ValidatorClass, + self.PredictorClass, + ) = self._assign_ops_from_task() + + # for loading a model with ultralytics >=8.0.44 + else: + if self.task not in self.task_map: + raise ValueError( + f"Task '{self.task}' not supported. Supported tasks: {list(self.task_map.keys())}" + ) + ( + self.ModelClass, + self.TrainerClass, + self.ValidatorClass, + self.PredictorClass, + ) = self.task_map[self.task] def render_result(