From c461e17cf6f6c9684b5d0e3bf0d29f987ebdb62f Mon Sep 17 00:00:00 2001 From: Ultraproduct <4291996-LightlessNight@users.noreply.gitlab.com> Date: Sat, 1 Aug 2020 18:45:58 -0500 Subject: [PATCH] Implement multi-column constraints. --- orm/models.py | 2 ++ tests/test_models.py | 23 +++++++++++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/orm/models.py b/orm/models.py index e61d159..21ded05 100644 --- a/orm/models.py +++ b/orm/models.py @@ -34,6 +34,7 @@ def __new__( tablename = attrs["__tablename__"] metadata = attrs["__metadata__"] + table_args = attrs.get("__table_args__", []) pkname = None columns = [] @@ -41,6 +42,7 @@ def __new__( if field.primary_key: pkname = name columns.append(field.get_column(name)) + columns.extend(table_args) new_model.__table__ = sqlalchemy.Table(tablename, metadata, *columns) new_model.__pkname__ = pkname diff --git a/tests/test_models.py b/tests/test_models.py index ddd9fde..6e4ecc7 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -33,6 +33,17 @@ class Product(orm.Model): in_stock = orm.Boolean(default=False) +class Channel(orm.Model): + __tablename__ = "channels" + __metadata__ = metadata + __database__ = database + __table_args__ = (sqlalchemy.schema.UniqueConstraint('name', 'number'),) + + id = orm.Integer(primary_key=True) + name = orm.String(max_length=100) + number = orm.Integer() + + @pytest.fixture(autouse=True, scope="module") def create_test_database(): engine = sqlalchemy.create_engine(DATABASE_URL) @@ -216,3 +227,15 @@ async def test_model_first(): assert await User.objects.first(name="Jane") == jane assert await User.objects.filter(name="Jane").first() == jane assert await User.objects.filter(name="Lucy").first() is None + + +@async_adapter +async def test_constraints_multi_column(): + async with database: + await Channel.objects.create(name='PBS', number=1) + await Channel.objects.create(name='PBS', number=2) + await Channel.objects.create(name='CSPAN', number=2) + + from sqlite3 import IntegrityError + with pytest.raises(IntegrityError): + await Channel.objects.create(name='PBS', number=1)