Skip to content

Commit

Permalink
Fixed #35638 -- Updated validate_constraints to consider db_default.
Browse files Browse the repository at this point in the history
  • Loading branch information
shangxiao authored and sarahboyce committed Aug 5, 2024
1 parent 91a0387 commit 509763c
Show file tree
Hide file tree
Showing 10 changed files with 130 additions and 13 deletions.
34 changes: 33 additions & 1 deletion django/db/models/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1250,9 +1250,41 @@ def as_sql(self, compiler, connection):


class DatabaseDefault(Expression):
"""Placeholder expression for the database default in an insert query."""
"""
Expression to use DEFAULT keyword during insert otherwise the underlying expression.
"""

def __init__(self, expression, output_field=None):
super().__init__(output_field)
self.expression = expression

def get_source_expressions(self):
return [self.expression]

def set_source_expressions(self, exprs):
(self.expression,) = exprs

def resolve_expression(
self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
):
resolved_expression = self.expression.resolve_expression(
query=query,
allow_joins=allow_joins,
reuse=reuse,
summarize=summarize,
for_save=for_save,
)
# Defaults used outside an INSERT context should resolve to their
# underlying expression.
if not for_save:
return resolved_expression
return DatabaseDefault(
resolved_expression, output_field=self._output_field_or_none
)

def as_sql(self, compiler, connection):
if not connection.features.supports_default_keyword_in_insert:
return compiler.compile(self.expression)
return "DEFAULT", []


Expand Down
12 changes: 4 additions & 8 deletions django/db/models/fields/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -983,13 +983,7 @@ def get_internal_type(self):

def pre_save(self, model_instance, add):
"""Return field's value just before saving."""
value = getattr(model_instance, self.attname)
if not connection.features.supports_default_keyword_in_insert:
from django.db.models.expressions import DatabaseDefault

if isinstance(value, DatabaseDefault):
return self._db_default_expression
return value
return getattr(model_instance, self.attname)

def get_prep_value(self, value):
"""Perform preliminary non-db specific value checks and conversions."""
Expand Down Expand Up @@ -1031,7 +1025,9 @@ def _get_default(self):
if self.db_default is not NOT_PROVIDED:
from django.db.models.expressions import DatabaseDefault

return DatabaseDefault
return lambda: DatabaseDefault(
self._db_default_expression, output_field=self
)

if (
not self.empty_strings_allowed
Expand Down
4 changes: 4 additions & 0 deletions docs/releases/5.0.8.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,7 @@ Bugfixes
* Fixed a bug in Django 5.0 that caused a system check crash when
``ModelAdmin.date_hierarchy`` was a ``GeneratedField`` with an
``output_field`` of ``DateField`` or ``DateTimeField`` (:ticket:`35628`).

* Fixed a bug in Django 5.0 which caused constraint validation to either crash
or incorrectly raise validation errors for constraints referring to fields
using ``Field.db_default`` (:ticket:`35638`).
7 changes: 7 additions & 0 deletions tests/constraints/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,10 @@ class JSONFieldModel(models.Model):

class Meta:
required_db_features = {"supports_json_field"}


class ModelWithDatabaseDefault(models.Model):
field = models.CharField(max_length=255)
field_with_db_default = models.CharField(
max_length=255, db_default=models.Value("field_with_db_default")
)
57 changes: 56 additions & 1 deletion tests/constraints/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from django.db import IntegrityError, connection, models
from django.db.models import F
from django.db.models.constraints import BaseConstraint, UniqueConstraint
from django.db.models.functions import Abs, Lower
from django.db.models.functions import Abs, Lower, Upper
from django.db.transaction import atomic
from django.test import SimpleTestCase, TestCase, skipIfDBFeature, skipUnlessDBFeature
from django.test.utils import ignore_warnings
Expand All @@ -14,6 +14,7 @@
ChildModel,
ChildUniqueConstraintProduct,
JSONFieldModel,
ModelWithDatabaseDefault,
Product,
UniqueConstraintConditionProduct,
UniqueConstraintDeferrable,
Expand Down Expand Up @@ -396,6 +397,33 @@ def test_check_deprecation(self):
with self.assertWarnsRegex(RemovedInDjango60Warning, msg):
self.assertIs(constraint.check, other_condition)

def test_database_default(self):
models.CheckConstraint(
condition=models.Q(field_with_db_default="field_with_db_default"),
name="check_field_with_db_default",
).validate(ModelWithDatabaseDefault, ModelWithDatabaseDefault())

# Ensure that a check also does not silently pass with either
# FieldError or DatabaseError when checking with a db_default.
with self.assertRaises(ValidationError):
models.CheckConstraint(
condition=models.Q(
field_with_db_default="field_with_db_default", field="field"
),
name="check_field_with_db_default_2",
).validate(
ModelWithDatabaseDefault, ModelWithDatabaseDefault(field="not-field")
)

with self.assertRaises(ValidationError):
models.CheckConstraint(
condition=models.Q(field_with_db_default="field_with_db_default"),
name="check_field_with_db_default",
).validate(
ModelWithDatabaseDefault,
ModelWithDatabaseDefault(field_with_db_default="other value"),
)


class UniqueConstraintTests(TestCase):
@classmethod
Expand Down Expand Up @@ -1265,3 +1293,30 @@ def test_requires_name(self):
msg = "A unique constraint must be named."
with self.assertRaisesMessage(ValueError, msg):
models.UniqueConstraint(fields=["field"])

def test_database_default(self):
models.UniqueConstraint(
fields=["field_with_db_default"], name="unique_field_with_db_default"
).validate(ModelWithDatabaseDefault, ModelWithDatabaseDefault())
models.UniqueConstraint(
Upper("field_with_db_default"),
name="unique_field_with_db_default_expression",
).validate(ModelWithDatabaseDefault, ModelWithDatabaseDefault())

ModelWithDatabaseDefault.objects.create()

msg = (
"Model with database default with this Field with db default already "
"exists."
)
with self.assertRaisesMessage(ValidationError, msg):
models.UniqueConstraint(
fields=["field_with_db_default"], name="unique_field_with_db_default"
).validate(ModelWithDatabaseDefault, ModelWithDatabaseDefault())

msg = "Constraint “unique_field_with_db_default_expression” is violated."
with self.assertRaisesMessage(ValidationError, msg):
models.UniqueConstraint(
Upper("field_with_db_default"),
name="unique_field_with_db_default_expression",
).validate(ModelWithDatabaseDefault, ModelWithDatabaseDefault())
2 changes: 1 addition & 1 deletion tests/postgres_tests/migrations/0002_create_test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ class Migration(migrations.Migration):
primary_key=True,
),
),
("ints", IntegerRangeField(null=True, blank=True)),
("ints", IntegerRangeField(null=True, blank=True, db_default=(5, 10))),
("bigints", BigIntegerRangeField(null=True, blank=True)),
("decimals", DecimalRangeField(null=True, blank=True)),
("timestamps", DateTimeRangeField(null=True, blank=True)),
Expand Down
2 changes: 1 addition & 1 deletion tests/postgres_tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ class LineSavedSearch(PostgreSQLModel):


class RangesModel(PostgreSQLModel):
ints = IntegerRangeField(blank=True, null=True)
ints = IntegerRangeField(blank=True, null=True, db_default=(5, 10))
bigints = BigIntegerRangeField(blank=True, null=True)
decimals = DecimalRangeField(blank=True, null=True)
timestamps = DateTimeRangeField(blank=True, null=True)
Expand Down
9 changes: 9 additions & 0 deletions tests/postgres_tests/test_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -1213,3 +1213,12 @@ class Meta:
constraint_name,
self.get_constraints(ModelWithExclusionConstraint._meta.db_table),
)

def test_database_default(self):
constraint = ExclusionConstraint(
name="ints_equal", expressions=[("ints", RangeOperators.EQUAL)]
)
RangesModel.objects.create()
msg = "Constraint “ints_equal” is violated."
with self.assertRaisesMessage(ValidationError, msg):
constraint.validate(RangesModel, RangesModel())
2 changes: 1 addition & 1 deletion tests/validation/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def clean(self):

class UniqueFieldsModel(models.Model):
unique_charfield = models.CharField(max_length=100, unique=True)
unique_integerfield = models.IntegerField(unique=True)
unique_integerfield = models.IntegerField(unique=True, db_default=42)
non_unique_field = models.IntegerField()


Expand Down
14 changes: 14 additions & 0 deletions tests/validation/test_unique.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,20 @@ def test_primary_key_unique_check_not_performed_when_not_adding(self):
mtv = ModelToValidate(number=10, name="Some Name")
mtv.full_clean()

def test_unique_db_default(self):
UniqueFieldsModel.objects.create(unique_charfield="foo", non_unique_field=42)
um = UniqueFieldsModel(unique_charfield="bar", non_unique_field=42)
with self.assertRaises(ValidationError) as cm:
um.full_clean()
self.assertEqual(
cm.exception.message_dict,
{
"unique_integerfield": [
"Unique fields model with this Unique integerfield already exists."
]
},
)

def test_unique_for_date(self):
Post.objects.create(
title="Django 1.0 is released",
Expand Down

0 comments on commit 509763c

Please sign in to comment.