Skip to content

Commit

Permalink
feat: allow STI child entities to have non-nullable relationships
Browse files Browse the repository at this point in the history
  • Loading branch information
simPod committed Aug 1, 2023
1 parent 710dde8 commit 056224e
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 16 deletions.
17 changes: 12 additions & 5 deletions lib/Doctrine/ORM/Tools/SchemaTool.php
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,7 @@ private function gatherRelationsSql(
$this->gatherRelationJoinColumns(
$mapping['joinColumns'],
$table,
$class,
$foreignClass,
$mapping,
$primaryKeyColumns,
Expand Down Expand Up @@ -608,6 +609,7 @@ private function gatherRelationsSql(
$joinTable['joinColumns'],
$theJoinTable,
$class,
$class,
$mapping,
$primaryKeyColumns,
$addedFks,
Expand All @@ -618,6 +620,7 @@ private function gatherRelationsSql(
$this->gatherRelationJoinColumns(
$joinTable['inverseJoinColumns'],
$theJoinTable,
$class,
$foreignClass,
$mapping,
$primaryKeyColumns,
Expand Down Expand Up @@ -685,6 +688,7 @@ private function gatherRelationJoinColumns(
array $joinColumns,
Table $theJoinTable,
ClassMetadata $class,
ClassMetadata $foreignClass,
array $mapping,
array &$primaryKeyColumns,
array &$addedFks,
Expand All @@ -693,12 +697,12 @@ private function gatherRelationJoinColumns(
$localColumns = [];
$foreignColumns = [];
$fkOptions = [];
$foreignTableName = $this->quoteStrategy->getTableName($class, $this->platform);
$foreignTableName = $this->quoteStrategy->getTableName($foreignClass, $this->platform);
$uniqueConstraints = [];

foreach ($joinColumns as $joinColumn) {
[$definingClass, $referencedFieldName] = $this->getDefiningClass(
$class,
$foreignClass,
$joinColumn['referencedColumnName']
);

Expand All @@ -710,10 +714,10 @@ private function gatherRelationJoinColumns(
);
}

$quotedColumnName = $this->quoteStrategy->getJoinColumnName($joinColumn, $class, $this->platform);
$quotedColumnName = $this->quoteStrategy->getJoinColumnName($joinColumn, $foreignClass, $this->platform);
$quotedRefColumnName = $this->quoteStrategy->getReferencedJoinColumnName(
$joinColumn,
$class,
$foreignClass,
$this->platform
);

Expand All @@ -736,7 +740,10 @@ private function gatherRelationJoinColumns(
$columnOptions['columnDefinition'] = $fieldMapping['columnDefinition'];
}

if (isset($joinColumn['nullable'])) {
if (
isset($joinColumn['nullable'])
&& ! ($class->isInheritanceTypeSingleTable() && $class->parentClasses)
) {
$columnOptions['notnull'] = ! $joinColumn['nullable'];
}

Expand Down
30 changes: 30 additions & 0 deletions tests/Doctrine/Tests/Models/Company/CompanyCarContract.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
<?php

declare(strict_types=1);

namespace Doctrine\Tests\Models\Company;

use Doctrine\ORM\Mapping as ORM;
use Doctrine\ORM\Mapping\Entity;

/** @Entity */
class CompanyCarContract extends CompanyContract
{
/**
* @ORM\ManyToOne(targetEntity="CompanyCar")
* @ORM\JoinColumn(nullable=false, onDelete="CASCADE")
*
* @var CompanyCar
*/
private $companyCar;

public function calculatePrice(): int
{
return 0;
}

public function setCompanyCar(CompanyCar $companyCar): void
{
$this->companyCar = $companyCar;
}
}
5 changes: 3 additions & 2 deletions tests/Doctrine/Tests/Models/Company/CompanyContract.php
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@
* @DiscriminatorMap({
* "fix" = "CompanyFixContract",
* "flexible" = "CompanyFlexContract",
* "flexultra" = "CompanyFlexUltraContract"
* "flexultra" = "CompanyFlexUltraContract",
* "car" = "CompanyCarContract"
* })
* @NamedNativeQueries({
* @NamedNativeQuery(
Expand Down Expand Up @@ -85,7 +86,7 @@
#[ORM\Table(name: 'company_contracts')]
#[ORM\InheritanceType('SINGLE_TABLE')]
#[ORM\DiscriminatorColumn(name: 'discr', type: 'string')]
#[ORM\DiscriminatorMap(['fix' => 'CompanyFixContract', 'flexible' => 'CompanyFlexContract', 'flexultra' => 'CompanyFlexUltraContract'])]
#[ORM\DiscriminatorMap(['fix' => 'CompanyFixContract', 'flexible' => 'CompanyFlexContract', 'flexultra' => 'CompanyFlexUltraContract', 'car' => 'CompanyCarContract'])]
#[ORM\EntityListeners(['CompanyContractListener'])]
abstract class CompanyContract
{
Expand Down
21 changes: 21 additions & 0 deletions tests/Doctrine/Tests/ORM/Functional/SingleTableInheritanceTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
use Doctrine\Common\Collections\Criteria;
use Doctrine\ORM\Mapping\ClassMetadata;
use Doctrine\ORM\Persisters\MatchingAssociationFieldRequiresObject;
use Doctrine\Tests\Models\Company\CompanyCar;
use Doctrine\Tests\Models\Company\CompanyCarContract;
use Doctrine\Tests\Models\Company\CompanyContract;
use Doctrine\Tests\Models\Company\CompanyEmployee;
use Doctrine\Tests\Models\Company\CompanyFixContract;
Expand Down Expand Up @@ -419,4 +421,23 @@ public function testEagerLoadInheritanceHierarchy(): void

self::assertFalse($this->isUninitializedObject($contract->getSalesPerson()));
}

public function testChildCanHaveNonNullableRelation(): void
{
$companyCar = new CompanyCar('BMW');
$fixContract = new CompanyFixContract();
$carContract = new CompanyCarContract();
$carContract->setCompanyCar($companyCar);

$this->_em->persist($fixContract);
$this->_em->persist($companyCar);
$this->_em->persist($carContract);
$this->_em->flush();
$this->_em->clear();

$repo = $this->_em->getRepository(CompanyCarContract::class);
$carContracts = $repo->findAll();

self::assertCount(1, $carContracts);
}
}
18 changes: 9 additions & 9 deletions tests/Doctrine/Tests/ORM/Query/SelectSqlGenerationTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ public function testSupportsJoinOnMultipleComponentsWithJoinedInheritanceType():

$this->assertSqlGeneration(
'SELECT c FROM Doctrine\Tests\Models\Company\CompanyContract c JOIN c.salesPerson s LEFT JOIN Doctrine\Tests\Models\Company\CompanyEvent e WITH s.id = e.id',
"SELECT c0_.id AS id_0, c0_.completed AS completed_1, c0_.fixPrice AS fixPrice_2, c0_.hoursWorked AS hoursWorked_3, c0_.pricePerHour AS pricePerHour_4, c0_.maxPrice AS maxPrice_5, c0_.discr AS discr_6, c0_.salesPerson_id AS salesPerson_id_7 FROM company_contracts c0_ INNER JOIN company_employees c1_ ON c0_.salesPerson_id = c1_.id LEFT JOIN company_persons c2_ ON c1_.id = c2_.id LEFT JOIN company_managers c3_ ON c1_.id = c3_.id LEFT JOIN (company_events c4_ LEFT JOIN company_auctions c5_ ON c4_.id = c5_.id LEFT JOIN company_raffles c6_ ON c4_.id = c6_.id) ON (c2_.id = c4_.id) WHERE c0_.discr IN ('fix', 'flexible', 'flexultra')"
"SELECT c0_.id AS id_0, c0_.completed AS completed_1, c0_.fixPrice AS fixPrice_2, c0_.hoursWorked AS hoursWorked_3, c0_.pricePerHour AS pricePerHour_4, c0_.maxPrice AS maxPrice_5, c0_.discr AS discr_6, c0_.salesPerson_id AS salesPerson_id_7, c0_.companyCar_id AS companyCar_id_8 FROM company_contracts c0_ INNER JOIN company_employees c1_ ON c0_.salesPerson_id = c1_.id LEFT JOIN company_persons c2_ ON c1_.id = c2_.id LEFT JOIN company_managers c3_ ON c1_.id = c3_.id LEFT JOIN (company_events c4_ LEFT JOIN company_auctions c5_ ON c4_.id = c5_.id LEFT JOIN company_raffles c6_ ON c4_.id = c6_.id) ON (c2_.id = c4_.id) WHERE c0_.discr IN ('fix', 'flexible', 'flexultra', 'car')"
);
}

Expand Down Expand Up @@ -1474,7 +1474,7 @@ public function testInheritanceTypeSingleTableInRootClassWithDisabledForcePartia
{
$this->assertSqlGeneration(
'SELECT c FROM Doctrine\Tests\Models\Company\CompanyContract c',
"SELECT c0_.id AS id_0, c0_.completed AS completed_1, c0_.fixPrice AS fixPrice_2, c0_.hoursWorked AS hoursWorked_3, c0_.pricePerHour AS pricePerHour_4, c0_.maxPrice AS maxPrice_5, c0_.discr AS discr_6, c0_.salesPerson_id AS salesPerson_id_7 FROM company_contracts c0_ WHERE c0_.discr IN ('fix', 'flexible', 'flexultra')",
"SELECT c0_.id AS id_0, c0_.completed AS completed_1, c0_.fixPrice AS fixPrice_2, c0_.hoursWorked AS hoursWorked_3, c0_.pricePerHour AS pricePerHour_4, c0_.maxPrice AS maxPrice_5, c0_.discr AS discr_6, c0_.salesPerson_id AS salesPerson_id_7, c0_.companyCar_id AS companyCar_id_8 FROM company_contracts c0_ WHERE c0_.discr IN ('fix', 'flexible', 'flexultra', 'car')",
[ORMQuery::HINT_FORCE_PARTIAL_LOAD => false]
);
}
Expand All @@ -1484,7 +1484,7 @@ public function testInheritanceTypeSingleTableInRootClassWithEnabledForcePartial
{
$this->assertSqlGeneration(
'SELECT c FROM Doctrine\Tests\Models\Company\CompanyContract c',
"SELECT c0_.id AS id_0, c0_.completed AS completed_1, c0_.fixPrice AS fixPrice_2, c0_.hoursWorked AS hoursWorked_3, c0_.pricePerHour AS pricePerHour_4, c0_.maxPrice AS maxPrice_5, c0_.discr AS discr_6 FROM company_contracts c0_ WHERE c0_.discr IN ('fix', 'flexible', 'flexultra')",
"SELECT c0_.id AS id_0, c0_.completed AS completed_1, c0_.fixPrice AS fixPrice_2, c0_.hoursWorked AS hoursWorked_3, c0_.pricePerHour AS pricePerHour_4, c0_.maxPrice AS maxPrice_5, c0_.discr AS discr_6 FROM company_contracts c0_ WHERE c0_.discr IN ('fix', 'flexible', 'flexultra', 'car')",
[ORMQuery::HINT_FORCE_PARTIAL_LOAD => true]
);
}
Expand Down Expand Up @@ -2054,7 +2054,7 @@ public function testSingleTableInheritanceLeftJoinWithCondition(): void
// Regression test for the bug
$this->assertSqlGeneration(
'SELECT c FROM Doctrine\Tests\Models\Company\CompanyEmployee e LEFT JOIN Doctrine\Tests\Models\Company\CompanyContract c WITH c.salesPerson = e.id',
"SELECT c0_.id AS id_0, c0_.completed AS completed_1, c0_.fixPrice AS fixPrice_2, c0_.hoursWorked AS hoursWorked_3, c0_.pricePerHour AS pricePerHour_4, c0_.maxPrice AS maxPrice_5, c0_.discr AS discr_6, c0_.salesPerson_id AS salesPerson_id_7 FROM company_employees c1_ INNER JOIN company_persons c2_ ON c1_.id = c2_.id LEFT JOIN company_managers c3_ ON c1_.id = c3_.id LEFT JOIN company_contracts c0_ ON (c0_.salesPerson_id = c2_.id) AND c0_.discr IN ('fix', 'flexible', 'flexultra')"
"SELECT c0_.id AS id_0, c0_.completed AS completed_1, c0_.fixPrice AS fixPrice_2, c0_.hoursWorked AS hoursWorked_3, c0_.pricePerHour AS pricePerHour_4, c0_.maxPrice AS maxPrice_5, c0_.discr AS discr_6, c0_.salesPerson_id AS salesPerson_id_7, c0_.companyCar_id AS companyCar_id_8 FROM company_employees c1_ INNER JOIN company_persons c2_ ON c1_.id = c2_.id LEFT JOIN company_managers c3_ ON c1_.id = c3_.id LEFT JOIN company_contracts c0_ ON (c0_.salesPerson_id = c2_.id) AND c0_.discr IN ('fix', 'flexible', 'flexultra', 'car')"
);
}

Expand All @@ -2064,7 +2064,7 @@ public function testSingleTableInheritanceLeftJoinWithConditionAndWhere(): void
// Ensure other WHERE predicates are passed through to the main WHERE clause
$this->assertSqlGeneration(
'SELECT c FROM Doctrine\Tests\Models\Company\CompanyEmployee e LEFT JOIN Doctrine\Tests\Models\Company\CompanyContract c WITH c.salesPerson = e.id WHERE e.salary > 1000',
"SELECT c0_.id AS id_0, c0_.completed AS completed_1, c0_.fixPrice AS fixPrice_2, c0_.hoursWorked AS hoursWorked_3, c0_.pricePerHour AS pricePerHour_4, c0_.maxPrice AS maxPrice_5, c0_.discr AS discr_6, c0_.salesPerson_id AS salesPerson_id_7 FROM company_employees c1_ INNER JOIN company_persons c2_ ON c1_.id = c2_.id LEFT JOIN company_managers c3_ ON c1_.id = c3_.id LEFT JOIN company_contracts c0_ ON (c0_.salesPerson_id = c2_.id) AND c0_.discr IN ('fix', 'flexible', 'flexultra') WHERE c1_.salary > 1000"
"SELECT c0_.id AS id_0, c0_.completed AS completed_1, c0_.fixPrice AS fixPrice_2, c0_.hoursWorked AS hoursWorked_3, c0_.pricePerHour AS pricePerHour_4, c0_.maxPrice AS maxPrice_5, c0_.discr AS discr_6, c0_.salesPerson_id AS salesPerson_id_7, c0_.companyCar_id AS companyCar_id_8 FROM company_employees c1_ INNER JOIN company_persons c2_ ON c1_.id = c2_.id LEFT JOIN company_managers c3_ ON c1_.id = c3_.id LEFT JOIN company_contracts c0_ ON (c0_.salesPerson_id = c2_.id) AND c0_.discr IN ('fix', 'flexible', 'flexultra', 'car') WHERE c1_.salary > 1000"
);
}

Expand All @@ -2074,7 +2074,7 @@ public function testSingleTableInheritanceInnerJoinWithCondition(): void
// Test inner joins too
$this->assertSqlGeneration(
'SELECT c FROM Doctrine\Tests\Models\Company\CompanyEmployee e INNER JOIN Doctrine\Tests\Models\Company\CompanyContract c WITH c.salesPerson = e.id',
"SELECT c0_.id AS id_0, c0_.completed AS completed_1, c0_.fixPrice AS fixPrice_2, c0_.hoursWorked AS hoursWorked_3, c0_.pricePerHour AS pricePerHour_4, c0_.maxPrice AS maxPrice_5, c0_.discr AS discr_6, c0_.salesPerson_id AS salesPerson_id_7 FROM company_employees c1_ INNER JOIN company_persons c2_ ON c1_.id = c2_.id LEFT JOIN company_managers c3_ ON c1_.id = c3_.id INNER JOIN company_contracts c0_ ON (c0_.salesPerson_id = c2_.id) AND c0_.discr IN ('fix', 'flexible', 'flexultra')"
"SELECT c0_.id AS id_0, c0_.completed AS completed_1, c0_.fixPrice AS fixPrice_2, c0_.hoursWorked AS hoursWorked_3, c0_.pricePerHour AS pricePerHour_4, c0_.maxPrice AS maxPrice_5, c0_.discr AS discr_6, c0_.salesPerson_id AS salesPerson_id_7, c0_.companyCar_id AS companyCar_id_8 FROM company_employees c1_ INNER JOIN company_persons c2_ ON c1_.id = c2_.id LEFT JOIN company_managers c3_ ON c1_.id = c3_.id INNER JOIN company_contracts c0_ ON (c0_.salesPerson_id = c2_.id) AND c0_.discr IN ('fix', 'flexible', 'flexultra', 'car')"
);
}

Expand All @@ -2085,7 +2085,7 @@ public function testSingleTableInheritanceLeftJoinNonAssociationWithConditionAnd
// the where clause when not joining onto that table
$this->assertSqlGeneration(
'SELECT c FROM Doctrine\Tests\Models\Company\CompanyContract c LEFT JOIN Doctrine\Tests\Models\Company\CompanyEmployee e WITH e.id = c.salesPerson WHERE c.completed = true',
"SELECT c0_.id AS id_0, c0_.completed AS completed_1, c0_.fixPrice AS fixPrice_2, c0_.hoursWorked AS hoursWorked_3, c0_.pricePerHour AS pricePerHour_4, c0_.maxPrice AS maxPrice_5, c0_.discr AS discr_6, c0_.salesPerson_id AS salesPerson_id_7 FROM company_contracts c0_ LEFT JOIN (company_employees c1_ INNER JOIN company_persons c2_ ON c1_.id = c2_.id LEFT JOIN company_managers c3_ ON c1_.id = c3_.id) ON (c2_.id = c0_.salesPerson_id) WHERE (c0_.completed = 1) AND c0_.discr IN ('fix', 'flexible', 'flexultra')"
"SELECT c0_.id AS id_0, c0_.completed AS completed_1, c0_.fixPrice AS fixPrice_2, c0_.hoursWorked AS hoursWorked_3, c0_.pricePerHour AS pricePerHour_4, c0_.maxPrice AS maxPrice_5, c0_.discr AS discr_6, c0_.salesPerson_id AS salesPerson_id_7, c0_.companyCar_id AS companyCar_id_8 FROM company_contracts c0_ LEFT JOIN (company_employees c1_ INNER JOIN company_persons c2_ ON c1_.id = c2_.id LEFT JOIN company_managers c3_ ON c1_.id = c3_.id) ON (c2_.id = c0_.salesPerson_id) WHERE (c0_.completed = 1) AND c0_.discr IN ('fix', 'flexible', 'flexultra', 'car')"
);
}

Expand All @@ -2097,7 +2097,7 @@ public function testSingleTableInheritanceJoinCreatesOnCondition(): void
// via a join association
$this->assertSqlGeneration(
'SELECT c FROM Doctrine\Tests\Models\Company\CompanyContract c JOIN c.salesPerson s WHERE c.completed = true',
"SELECT c0_.id AS id_0, c0_.completed AS completed_1, c0_.fixPrice AS fixPrice_2, c0_.hoursWorked AS hoursWorked_3, c0_.pricePerHour AS pricePerHour_4, c0_.maxPrice AS maxPrice_5, c0_.discr AS discr_6, c0_.salesPerson_id AS salesPerson_id_7 FROM company_contracts c0_ INNER JOIN company_employees c1_ ON c0_.salesPerson_id = c1_.id LEFT JOIN company_persons c2_ ON c1_.id = c2_.id LEFT JOIN company_managers c3_ ON c1_.id = c3_.id WHERE (c0_.completed = 1) AND c0_.discr IN ('fix', 'flexible', 'flexultra')"
"SELECT c0_.id AS id_0, c0_.completed AS completed_1, c0_.fixPrice AS fixPrice_2, c0_.hoursWorked AS hoursWorked_3, c0_.pricePerHour AS pricePerHour_4, c0_.maxPrice AS maxPrice_5, c0_.discr AS discr_6, c0_.salesPerson_id AS salesPerson_id_7, c0_.companyCar_id AS companyCar_id_8 FROM company_contracts c0_ INNER JOIN company_employees c1_ ON c0_.salesPerson_id = c1_.id LEFT JOIN company_persons c2_ ON c1_.id = c2_.id LEFT JOIN company_managers c3_ ON c1_.id = c3_.id WHERE (c0_.completed = 1) AND c0_.discr IN ('fix', 'flexible', 'flexultra', 'car')"
);
}

Expand All @@ -2109,7 +2109,7 @@ public function testSingleTableInheritanceCreatesOnConditionAndWhere(): void
// into the ON clause of the join
$this->assertSqlGeneration(
'SELECT e, COUNT(c) FROM Doctrine\Tests\Models\Company\CompanyEmployee e JOIN e.contracts c WHERE e.department = :department',
"SELECT c0_.id AS id_0, c0_.name AS name_1, c1_.salary AS salary_2, c1_.department AS department_3, c1_.startDate AS startDate_4, c2_.title AS title_5, COUNT(c3_.id) AS sclr_6, c0_.discr AS discr_7, c0_.spouse_id AS spouse_id_8, c2_.car_id AS car_id_9 FROM company_employees c1_ INNER JOIN company_persons c0_ ON c1_.id = c0_.id LEFT JOIN company_managers c2_ ON c1_.id = c2_.id INNER JOIN company_contract_employees c4_ ON c1_.id = c4_.employee_id INNER JOIN company_contracts c3_ ON c3_.id = c4_.contract_id AND c3_.discr IN ('fix', 'flexible', 'flexultra') WHERE c1_.department = ?",
"SELECT c0_.id AS id_0, c0_.name AS name_1, c1_.salary AS salary_2, c1_.department AS department_3, c1_.startDate AS startDate_4, c2_.title AS title_5, COUNT(c3_.id) AS sclr_6, c0_.discr AS discr_7, c0_.spouse_id AS spouse_id_8, c2_.car_id AS car_id_9 FROM company_employees c1_ INNER JOIN company_persons c0_ ON c1_.id = c0_.id LEFT JOIN company_managers c2_ ON c1_.id = c2_.id INNER JOIN company_contract_employees c4_ ON c1_.id = c4_.employee_id INNER JOIN company_contracts c3_ ON c3_.id = c4_.contract_id AND c3_.discr IN ('fix', 'flexible', 'flexultra', 'car') WHERE c1_.department = ?",
[],
['department' => 'foobar']
);
Expand Down

0 comments on commit 056224e

Please sign in to comment.