Skip to content

Commit

Permalink
Merge branch 'main' into scc-performance-2
Browse files Browse the repository at this point in the history
  • Loading branch information
Robbybp authored Mar 12, 2024
2 parents de294ee + b06ddea commit 4f1e20c
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 8 deletions.
23 changes: 21 additions & 2 deletions pyomo/core/base/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# This software is distributed under the 3-clause BSD License.
# ___________________________________________________________________________

from __future__ import annotations
import copy
import logging
import sys
Expand All @@ -21,6 +22,7 @@
from io import StringIO
from itertools import filterfalse, chain
from operator import itemgetter, attrgetter
from typing import Union, Any, Type

from pyomo.common.autoslots import AutoSlots
from pyomo.common.collections import Mapping
Expand All @@ -44,6 +46,7 @@
from pyomo.core.base.indexed_component import (
ActiveIndexedComponent,
UnindexedComponent_set,
IndexedComponent,
)

from pyomo.opt.base import ProblemFormat, guess_format
Expand Down Expand Up @@ -539,7 +542,7 @@ def __init__(self, component):
super(_BlockData, self).__setattr__('_decl_order', [])
self._private_data = None

def __getattr__(self, val):
def __getattr__(self, val) -> Union[Component, IndexedComponent, Any]:
if val in ModelComponentFactory:
return _component_decorator(self, ModelComponentFactory.get_class(val))
# Since the base classes don't support getattr, we can just
Expand All @@ -548,7 +551,7 @@ def __getattr__(self, val):
"'%s' object has no attribute '%s'" % (self.__class__.__name__, val)
)

def __setattr__(self, name, val):
def __setattr__(self, name: str, val: Union[Component, IndexedComponent, Any]):
"""
Set an attribute of a block data object.
"""
Expand Down Expand Up @@ -2007,6 +2010,17 @@ class Block(ActiveIndexedComponent):
_ComponentDataClass = _BlockData
_private_data_initializers = defaultdict(lambda: dict)

@overload
def __new__(
cls: Type[Block], *args, **kwds
) -> Union[ScalarBlock, IndexedBlock]: ...

@overload
def __new__(cls: Type[ScalarBlock], *args, **kwds) -> ScalarBlock: ...

@overload
def __new__(cls: Type[IndexedBlock], *args, **kwds) -> IndexedBlock: ...

def __new__(cls, *args, **kwds):
if cls != Block:
return super(Block, cls).__new__(cls)
Expand Down Expand Up @@ -2251,6 +2265,11 @@ class IndexedBlock(Block):
def __init__(self, *args, **kwds):
Block.__init__(self, *args, **kwds)

@overload
def __getitem__(self, index) -> _BlockData: ...

__getitem__ = IndexedComponent.__getitem__ # type: ignore


#
# Deprecated functions.
Expand Down
19 changes: 19 additions & 0 deletions pyomo/core/base/constraint.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@
# This software is distributed under the 3-clause BSD License.
# ___________________________________________________________________________

from __future__ import annotations
import sys
import logging
from weakref import ref as weakref_ref
from pyomo.common.pyomo_typing import overload
from typing import Union, Type

from pyomo.common.deprecation import RenamedClass
from pyomo.common.errors import DeveloperError
Expand Down Expand Up @@ -42,6 +44,7 @@
ActiveIndexedComponent,
UnindexedComponent_set,
rule_wrapper,
IndexedComponent,
)
from pyomo.core.base.set import Set
from pyomo.core.base.disable_methods import disable_methods
Expand Down Expand Up @@ -728,6 +731,17 @@ class Infeasible(object):
Violated = Infeasible
Satisfied = Feasible

@overload
def __new__(
cls: Type[Constraint], *args, **kwds
) -> Union[ScalarConstraint, IndexedConstraint]: ...

@overload
def __new__(cls: Type[ScalarConstraint], *args, **kwds) -> ScalarConstraint: ...

@overload
def __new__(cls: Type[IndexedConstraint], *args, **kwds) -> IndexedConstraint: ...

def __new__(cls, *args, **kwds):
if cls != Constraint:
return super(Constraint, cls).__new__(cls)
Expand Down Expand Up @@ -1020,6 +1034,11 @@ def add(self, index, expr):
"""Add a constraint with a given index."""
return self.__setitem__(index, expr)

@overload
def __getitem__(self, index) -> _GeneralConstraintData: ...

__getitem__ = IndexedComponent.__getitem__ # type: ignore


@ModelComponentFactory.register("A list of constraint expressions.")
class ConstraintList(IndexedConstraint):
Expand Down
4 changes: 2 additions & 2 deletions pyomo/core/base/indexed_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import pyomo.core.base as BASE
from pyomo.core.base.indexed_component_slice import IndexedComponent_slice
from pyomo.core.base.initializer import Initializer
from pyomo.core.base.component import Component, ActiveComponent
from pyomo.core.base.component import Component, ActiveComponent, ComponentData
from pyomo.core.base.config import PyomoOptions
from pyomo.core.base.enums import SortComponents
from pyomo.core.base.global_set import UnindexedComponent_set
Expand Down Expand Up @@ -606,7 +606,7 @@ def iteritems(self):
"""Return a list (index,data) tuples from the dictionary"""
return self.items()

def __getitem__(self, index):
def __getitem__(self, index) -> ComponentData:
"""
This method returns the data corresponding to the given index.
"""
Expand Down
15 changes: 14 additions & 1 deletion pyomo/core/base/param.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@
# This software is distributed under the 3-clause BSD License.
# ___________________________________________________________________________

from __future__ import annotations
import sys
import types
import logging
from weakref import ref as weakref_ref
from pyomo.common.pyomo_typing import overload
from typing import Union, Type

from pyomo.common.autoslots import AutoSlots
from pyomo.common.deprecation import deprecation_warning, RenamedClass
Expand Down Expand Up @@ -291,6 +293,17 @@ class NoValue(object):

pass

@overload
def __new__(
cls: Type[Param], *args, **kwds
) -> Union[ScalarParam, IndexedParam]: ...

@overload
def __new__(cls: Type[ScalarParam], *args, **kwds) -> ScalarParam: ...

@overload
def __new__(cls: Type[IndexedParam], *args, **kwds) -> IndexedParam: ...

def __new__(cls, *args, **kwds):
if cls != Param:
return super(Param, cls).__new__(cls)
Expand Down Expand Up @@ -983,7 +996,7 @@ def _create_objects_for_deepcopy(self, memo, component_list):
# between potentially variable GetItemExpression objects and
# "constant" GetItemExpression objects. That will need to wait for
# the expression rework [JDS; Nov 22].
def __getitem__(self, args):
def __getitem__(self, args) -> _ParamData:
try:
return super().__getitem__(args)
except:
Expand Down
16 changes: 15 additions & 1 deletion pyomo/core/base/set.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,16 @@
# This software is distributed under the 3-clause BSD License.
# ___________________________________________________________________________

from __future__ import annotations
import inspect
import itertools
import logging
import math
import sys
import weakref
from pyomo.common.pyomo_typing import overload
from typing import Union, Type, Any as typingAny
from collections.abc import Iterator

from pyomo.common.collections import ComponentSet
from pyomo.common.deprecation import deprecated, deprecation_warning, RenamedClass
Expand Down Expand Up @@ -569,7 +572,7 @@ def isordered(self):
def subsets(self, expand_all_set_operators=None):
return iter((self,))

def __iter__(self):
def __iter__(self) -> Iterator[typingAny]:
"""Iterate over the set members
Raises AttributeError for non-finite sets. This must be
Expand Down Expand Up @@ -1967,6 +1970,12 @@ class SortedOrder(object):
_ValidOrderedAuguments = {True, False, InsertionOrder, SortedOrder}
_UnorderedInitializers = {set}

@overload
def __new__(cls: Type[Set], *args, **kwds) -> Union[_SetData, IndexedSet]: ...

@overload
def __new__(cls: Type[OrderedScalarSet], *args, **kwds) -> OrderedScalarSet: ...

def __new__(cls, *args, **kwds):
if cls is not Set:
return super(Set, cls).__new__(cls)
Expand Down Expand Up @@ -2373,6 +2382,11 @@ def data(self):
"Return a dict containing the data() of each Set in this IndexedSet"
return {k: v.data() for k, v in self.items()}

@overload
def __getitem__(self, index) -> _SetData: ...

__getitem__ = IndexedComponent.__getitem__ # type: ignore


class FiniteScalarSet(_FiniteSetData, Set):
def __init__(self, **kwds):
Expand Down
15 changes: 13 additions & 2 deletions pyomo/core/base/var.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@
# This software is distributed under the 3-clause BSD License.
# ___________________________________________________________________________

from __future__ import annotations
import logging
import sys
from pyomo.common.pyomo_typing import overload
from weakref import ref as weakref_ref
from typing import Union, Type

from pyomo.common.deprecation import RenamedClass
from pyomo.common.log import is_debug_set
Expand Down Expand Up @@ -668,6 +670,15 @@ class Var(IndexedComponent, IndexedComponent_NDArrayMixin):

_ComponentDataClass = _GeneralVarData

@overload
def __new__(cls: Type[Var], *args, **kwargs) -> Union[ScalarVar, IndexedVar]: ...

@overload
def __new__(cls: Type[ScalarVar], *args, **kwargs) -> ScalarVar: ...

@overload
def __new__(cls: Type[IndexedVar], *args, **kwargs) -> IndexedVar: ...

def __new__(cls, *args, **kwargs):
if cls is not Var:
return super(Var, cls).__new__(cls)
Expand All @@ -688,7 +699,7 @@ def __init__(
dense=True,
units=None,
name=None,
doc=None
doc=None,
): ...

def __init__(self, *args, **kwargs):
Expand Down Expand Up @@ -1046,7 +1057,7 @@ def domain(self, domain):
# between potentially variable GetItemExpression objects and
# "constant" GetItemExpression objects. That will need to wait for
# the expression rework [JDS; Nov 22].
def __getitem__(self, args):
def __getitem__(self, args) -> _GeneralVarData:
try:
return super().__getitem__(args)
except RuntimeError:
Expand Down

0 comments on commit 4f1e20c

Please sign in to comment.