diff --git a/orm/models.py b/orm/models.py index b402814..c49b4ff 100644 --- a/orm/models.py +++ b/orm/models.py @@ -88,6 +88,9 @@ def __new__(cls, name, bases, attrs): if field.primary_key: model_class.pkname = name + unique_together = attrs.get("unique_together", ()) + setattr(model_class, "unique_together", unique_together) + return model_class @property @@ -486,6 +489,7 @@ def _prepare_order_by(self, order_by: str): class Model(metaclass=ModelMeta): objects = QuerySet() + unique_together: typing.Sequence[typing.Union[typing.Sequence[str], str]] def __init__(self, **kwargs): if "pk" in kwargs: @@ -515,10 +519,21 @@ def __str__(self): def build_table(cls): tablename = cls.tablename metadata = cls.registry._metadata + unique_together = cls.unique_together + columns = [] for name, field in cls.fields.items(): columns.append(field.get_column(name)) - return sqlalchemy.Table(tablename, metadata, *columns, extend_existing=True) + + uniques = [] + for fields_set in unique_together: + unique_constraint = cls.__get_unique_constraint(fields_set) + if unique_constraint is not None: + uniques.append(unique_constraint) + + return sqlalchemy.Table( + tablename, metadata, *columns, *uniques, extend_existing=True + ) @property def table(self) -> sqlalchemy.Table: @@ -580,6 +595,24 @@ def _from_row(cls, row, select_related=[]): return cls(**item) + @classmethod + def __get_unique_constraint( + cls, + columns: typing.Union[typing.Sequence[str], str], + ) -> typing.Optional[sqlalchemy.UniqueConstraint]: + """ + Returned the Unique Constraint of SQLAlchemy. + + :columns: Must be `str` or `Sequence[List or Tupe]` of Strings + + If Type of 'columns' Didn't Match Above Nothing to Return Output + """ + + if isinstance(columns, str): + return sqlalchemy.UniqueConstraint(columns) + elif isinstance(columns, (tuple, list)): + return sqlalchemy.UniqueConstraint(*columns) + def __setattr__(self, key, value): if key in self.fields: # Setting a relationship to a raw pk value should set a diff --git a/tests/test_columns.py b/tests/test_columns.py index 278aecd..8389ebe 100644 --- a/tests/test_columns.py +++ b/tests/test_columns.py @@ -57,6 +57,16 @@ class User(orm.Model): } +class Customer(orm.Model): + registry = models + fields = { + "id": orm.Integer(primary_key=True), + "fname": orm.String(max_length=100), + "lname": orm.String(max_length=100), + } + unique_together = (("fname", "lname"),) + + @pytest.fixture(autouse=True, scope="module") async def create_test_database(): await models.create_all() @@ -159,3 +169,22 @@ async def test_bulk_create(): assert products[1].data == {"foo": 456} assert products[1].value == 456.789 assert products[1].status == StatusEnum.DRAFT + + +async def test_unique_together_fname_lname(): + sepehr = await Customer.objects.create(fname="Sepehr", lname="Bazyar") + sepehr: Customer = await Customer.objects.get(pk=sepehr.pk) + + farzane = await Customer.objects.create(fname="Farzane", lname="Bazyar") + farzane: Customer = await Customer.objects.get(pk=farzane.pk) + + assert sepehr.lname == farzane.lname + assert sepehr.fname == "Sepehr" + assert farzane.fname == "Farzane" + + +async def test_unique_together_fname_lname_raise_error(): + with pytest.raises(Exception): + + await Customer.objects.create(fname="Sepehr", lname="Bazyar") + await Customer.objects.create(fname="Sepehr", lname="Bazyar")