diff --git a/pyomo/core/base/block.py b/pyomo/core/base/block.py index a0948c693d7..2918ef78b00 100644 --- a/pyomo/core/base/block.py +++ b/pyomo/core/base/block.py @@ -9,6 +9,7 @@ # This software is distributed under the 3-clause BSD License. # ___________________________________________________________________________ +from __future__ import annotations import copy import logging import sys @@ -21,6 +22,7 @@ from io import StringIO from itertools import filterfalse, chain from operator import itemgetter, attrgetter +from typing import Union, Any, Type from pyomo.common.autoslots import AutoSlots from pyomo.common.collections import Mapping @@ -44,6 +46,7 @@ from pyomo.core.base.indexed_component import ( ActiveIndexedComponent, UnindexedComponent_set, + IndexedComponent, ) from pyomo.opt.base import ProblemFormat, guess_format @@ -539,7 +542,7 @@ def __init__(self, component): super(_BlockData, self).__setattr__('_decl_order', []) self._private_data = None - def __getattr__(self, val): + def __getattr__(self, val) -> Union[Component, IndexedComponent, Any]: if val in ModelComponentFactory: return _component_decorator(self, ModelComponentFactory.get_class(val)) # Since the base classes don't support getattr, we can just @@ -548,7 +551,7 @@ def __getattr__(self, val): "'%s' object has no attribute '%s'" % (self.__class__.__name__, val) ) - def __setattr__(self, name, val): + def __setattr__(self, name: str, val: Union[Component, IndexedComponent, Any]): """ Set an attribute of a block data object. """ @@ -2007,6 +2010,17 @@ class Block(ActiveIndexedComponent): _ComponentDataClass = _BlockData _private_data_initializers = defaultdict(lambda: dict) + @overload + def __new__( + cls: Type[Block], *args, **kwds + ) -> Union[ScalarBlock, IndexedBlock]: ... + + @overload + def __new__(cls: Type[ScalarBlock], *args, **kwds) -> ScalarBlock: ... + + @overload + def __new__(cls: Type[IndexedBlock], *args, **kwds) -> IndexedBlock: ... + def __new__(cls, *args, **kwds): if cls != Block: return super(Block, cls).__new__(cls) @@ -2251,6 +2265,11 @@ class IndexedBlock(Block): def __init__(self, *args, **kwds): Block.__init__(self, *args, **kwds) + @overload + def __getitem__(self, index) -> _BlockData: ... + + __getitem__ = IndexedComponent.__getitem__ # type: ignore + # # Deprecated functions. diff --git a/pyomo/core/base/constraint.py b/pyomo/core/base/constraint.py index 8cf3c48ad0a..fde1160e563 100644 --- a/pyomo/core/base/constraint.py +++ b/pyomo/core/base/constraint.py @@ -9,10 +9,12 @@ # This software is distributed under the 3-clause BSD License. # ___________________________________________________________________________ +from __future__ import annotations import sys import logging from weakref import ref as weakref_ref from pyomo.common.pyomo_typing import overload +from typing import Union, Type from pyomo.common.deprecation import RenamedClass from pyomo.common.errors import DeveloperError @@ -42,6 +44,7 @@ ActiveIndexedComponent, UnindexedComponent_set, rule_wrapper, + IndexedComponent, ) from pyomo.core.base.set import Set from pyomo.core.base.disable_methods import disable_methods @@ -728,6 +731,17 @@ class Infeasible(object): Violated = Infeasible Satisfied = Feasible + @overload + def __new__( + cls: Type[Constraint], *args, **kwds + ) -> Union[ScalarConstraint, IndexedConstraint]: ... + + @overload + def __new__(cls: Type[ScalarConstraint], *args, **kwds) -> ScalarConstraint: ... + + @overload + def __new__(cls: Type[IndexedConstraint], *args, **kwds) -> IndexedConstraint: ... + def __new__(cls, *args, **kwds): if cls != Constraint: return super(Constraint, cls).__new__(cls) @@ -1020,6 +1034,11 @@ def add(self, index, expr): """Add a constraint with a given index.""" return self.__setitem__(index, expr) + @overload + def __getitem__(self, index) -> _GeneralConstraintData: ... + + __getitem__ = IndexedComponent.__getitem__ # type: ignore + @ModelComponentFactory.register("A list of constraint expressions.") class ConstraintList(IndexedConstraint): diff --git a/pyomo/core/base/indexed_component.py b/pyomo/core/base/indexed_component.py index 0d498da091d..e1be613d666 100644 --- a/pyomo/core/base/indexed_component.py +++ b/pyomo/core/base/indexed_component.py @@ -18,7 +18,7 @@ import pyomo.core.base as BASE from pyomo.core.base.indexed_component_slice import IndexedComponent_slice from pyomo.core.base.initializer import Initializer -from pyomo.core.base.component import Component, ActiveComponent +from pyomo.core.base.component import Component, ActiveComponent, ComponentData from pyomo.core.base.config import PyomoOptions from pyomo.core.base.enums import SortComponents from pyomo.core.base.global_set import UnindexedComponent_set @@ -606,7 +606,7 @@ def iteritems(self): """Return a list (index,data) tuples from the dictionary""" return self.items() - def __getitem__(self, index): + def __getitem__(self, index) -> ComponentData: """ This method returns the data corresponding to the given index. """ diff --git a/pyomo/core/base/param.py b/pyomo/core/base/param.py index 3ef33b9ee45..5fcaf92b25a 100644 --- a/pyomo/core/base/param.py +++ b/pyomo/core/base/param.py @@ -9,11 +9,13 @@ # This software is distributed under the 3-clause BSD License. # ___________________________________________________________________________ +from __future__ import annotations import sys import types import logging from weakref import ref as weakref_ref from pyomo.common.pyomo_typing import overload +from typing import Union, Type from pyomo.common.autoslots import AutoSlots from pyomo.common.deprecation import deprecation_warning, RenamedClass @@ -291,6 +293,17 @@ class NoValue(object): pass + @overload + def __new__( + cls: Type[Param], *args, **kwds + ) -> Union[ScalarParam, IndexedParam]: ... + + @overload + def __new__(cls: Type[ScalarParam], *args, **kwds) -> ScalarParam: ... + + @overload + def __new__(cls: Type[IndexedParam], *args, **kwds) -> IndexedParam: ... + def __new__(cls, *args, **kwds): if cls != Param: return super(Param, cls).__new__(cls) @@ -983,7 +996,7 @@ def _create_objects_for_deepcopy(self, memo, component_list): # between potentially variable GetItemExpression objects and # "constant" GetItemExpression objects. That will need to wait for # the expression rework [JDS; Nov 22]. - def __getitem__(self, args): + def __getitem__(self, args) -> _ParamData: try: return super().__getitem__(args) except: diff --git a/pyomo/core/base/set.py b/pyomo/core/base/set.py index 2dc14460911..b3277ab3260 100644 --- a/pyomo/core/base/set.py +++ b/pyomo/core/base/set.py @@ -9,6 +9,7 @@ # This software is distributed under the 3-clause BSD License. # ___________________________________________________________________________ +from __future__ import annotations import inspect import itertools import logging @@ -16,6 +17,8 @@ import sys import weakref from pyomo.common.pyomo_typing import overload +from typing import Union, Type, Any as typingAny +from collections.abc import Iterator from pyomo.common.collections import ComponentSet from pyomo.common.deprecation import deprecated, deprecation_warning, RenamedClass @@ -569,7 +572,7 @@ def isordered(self): def subsets(self, expand_all_set_operators=None): return iter((self,)) - def __iter__(self): + def __iter__(self) -> Iterator[typingAny]: """Iterate over the set members Raises AttributeError for non-finite sets. This must be @@ -1967,6 +1970,12 @@ class SortedOrder(object): _ValidOrderedAuguments = {True, False, InsertionOrder, SortedOrder} _UnorderedInitializers = {set} + @overload + def __new__(cls: Type[Set], *args, **kwds) -> Union[_SetData, IndexedSet]: ... + + @overload + def __new__(cls: Type[OrderedScalarSet], *args, **kwds) -> OrderedScalarSet: ... + def __new__(cls, *args, **kwds): if cls is not Set: return super(Set, cls).__new__(cls) @@ -2373,6 +2382,11 @@ def data(self): "Return a dict containing the data() of each Set in this IndexedSet" return {k: v.data() for k, v in self.items()} + @overload + def __getitem__(self, index) -> _SetData: ... + + __getitem__ = IndexedComponent.__getitem__ # type: ignore + class FiniteScalarSet(_FiniteSetData, Set): def __init__(self, **kwds): diff --git a/pyomo/core/base/var.py b/pyomo/core/base/var.py index f426c9c4f55..856a2dc0237 100644 --- a/pyomo/core/base/var.py +++ b/pyomo/core/base/var.py @@ -9,10 +9,12 @@ # This software is distributed under the 3-clause BSD License. # ___________________________________________________________________________ +from __future__ import annotations import logging import sys from pyomo.common.pyomo_typing import overload from weakref import ref as weakref_ref +from typing import Union, Type from pyomo.common.deprecation import RenamedClass from pyomo.common.log import is_debug_set @@ -668,6 +670,15 @@ class Var(IndexedComponent, IndexedComponent_NDArrayMixin): _ComponentDataClass = _GeneralVarData + @overload + def __new__(cls: Type[Var], *args, **kwargs) -> Union[ScalarVar, IndexedVar]: ... + + @overload + def __new__(cls: Type[ScalarVar], *args, **kwargs) -> ScalarVar: ... + + @overload + def __new__(cls: Type[IndexedVar], *args, **kwargs) -> IndexedVar: ... + def __new__(cls, *args, **kwargs): if cls is not Var: return super(Var, cls).__new__(cls) @@ -688,7 +699,7 @@ def __init__( dense=True, units=None, name=None, - doc=None + doc=None, ): ... def __init__(self, *args, **kwargs): @@ -1046,7 +1057,7 @@ def domain(self, domain): # between potentially variable GetItemExpression objects and # "constant" GetItemExpression objects. That will need to wait for # the expression rework [JDS; Nov 22]. - def __getitem__(self, args): + def __getitem__(self, args) -> _GeneralVarData: try: return super().__getitem__(args) except RuntimeError: