Skip to content

Commit

Permalink
Implement ConfigurationHelper classes to simplify and deduplicate logic
Browse files Browse the repository at this point in the history
  • Loading branch information
jacklinke committed Jun 4, 2024
1 parent 1a82e5e commit 4483f8b
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 132 deletions.
5 changes: 3 additions & 2 deletions src/django_segments/helpers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@ def __init__(self, obj: Union[BaseSpan, BaseSegment]):
self.obj = obj

# Get the field_type, which tells us the type of range field used in the model
segment_range = self.obj.segment_range
self.field_type = segment_range.get_internal_type()
segment_range = getattr(self.obj, 'segment_range', None)
if segment_range:
self.field_type = segment_range.get_internal_type()

def validate_value_type(self, value):
"""Validate the type of the provided value against the model's field_type."""
Expand Down
5 changes: 3 additions & 2 deletions src/django_segments/helpers/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from django.db import transaction

from django_segments.helpers.base import BaseHelper
from django_segments.models.base import SegmentConfigurationHelper


logger = logging.getLogger(__name__)
Expand All @@ -25,7 +26,7 @@ class SegmentHelperBase(BaseHelper):

def __init__(self, obj: BaseSegment):
super().__init__(obj)
self.config_dict = self.obj.span.get_config_dict()
self.config_dict = SegmentConfigurationHelper(obj).get_config_dict()

def validate_segment_range(self, segment_range):
"""Validate the segment range based on the span and any adjacent segments."""
Expand Down Expand Up @@ -79,7 +80,7 @@ def create(self, segment_range, *args, **kwargs):

# Adjust adjacent segments if not allowing segment gaps
if not self.config_dict["allow_segment_gaps"]:
self.adjust_adjacent_segments(segment_instance)
self.adjust_adjacent_segments()

return segment_instance

Expand Down
5 changes: 3 additions & 2 deletions src/django_segments/helpers/span.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from psycopg2.extras import Range # psycopg2's base range class

from django_segments.helpers.base import BaseHelper
from django_segments.models.base import SpanConfigurationHelper
from django_segments.signals import segment_post_create
from django_segments.signals import segment_post_delete
from django_segments.signals import segment_post_delete_or_soft_delete
Expand Down Expand Up @@ -47,7 +48,7 @@ class SpanHelperBase(BaseHelper): # pylint: disable=R0903

def __init__(self, obj: BaseSpan):
super().__init__(obj)
self.config_dict = self.obj.get_config_dict()
self.config_dict = SpanConfigurationHelper.get_config_dict(obj)


class CreateSpanHelper:
Expand All @@ -58,7 +59,7 @@ class CreateSpanHelper:

def __init__(self, model_class: type[BaseSpan]):
self.model_class = model_class
self.config_dict = self.model_class.get_config_dict()
self.config_dict = SpanConfigurationHelper.get_config_dict(model_class)

@transaction.atomic
def create(self, *args, range_value: Range = None, **kwargs):
Expand Down
171 changes: 93 additions & 78 deletions src/django_segments/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,16 +75,86 @@ def set_upper_boundary(self, value):
class ConcreteModelValidationHelper: # pylint: disable=R0903
"""Helper class for validating that models are concrete."""

def __init__(self, model: type[models.Model]) -> None:
"""Initialize the helper with the model and error class."""
self.model = model

def check_model_is_concrete(self) -> None:
"""Ensure that the model is not abstract."""
if self.model._meta.abstract: # pylint: disable=W0212
@staticmethod
def check_model_is_concrete(model) -> None:
"""Check that the model is not abstract."""
if model._meta.abstract: # pylint: disable=W0212
raise IncorrectSubclassError("Concrete subclasses must not be abstract")


class SpanConfigurationHelper:
"""Helper class for retrieving Span model configurations."""

@staticmethod
def get_config_attr(model, attr_name: str, default):
"""Given an attribute name and default value, returns the attribute value from the SpanConfig class."""
return getattr(model.SpanConfig, attr_name, default)

@staticmethod
def get_range_type(model):
"""Return the range type for the span model."""
range_type = SpanConfigurationHelper.get_config_attr(model, "range_type", None)
if range_type is None:
raise IncorrectRangeTypeError(f"Range type not defined for {model.__class__.__name__}")
if range_type.__name__ not in POSTGRES_RANGE_FIELDS:
raise IncorrectRangeTypeError(
f"Unsupported range type: {range_type} not in {POSTGRES_RANGE_FIELDS.keys()} for "
f"{model.__class__.__name__}"
)
return range_type

@staticmethod
def get_config_dict(model) -> dict:
"""Return the configuration options for the span as a dictionary."""
return {
"allow_span_gaps": SpanConfigurationHelper.get_config_attr(model, "allow_span_gaps", ALLOW_SPAN_GAPS),
"allow_segment_gaps": SpanConfigurationHelper.get_config_attr(
model, "allow_segment_gaps", ALLOW_SEGMENT_GAPS
),
"soft_delete": SpanConfigurationHelper.get_config_attr(model, "soft_delete", SOFT_DELETE),
"range_type": SpanConfigurationHelper.get_range_type(model),
}


class SegmentConfigurationHelper:
"""Helper class for retrieving Segment model configurations."""

@staticmethod
def get_config_attr(model, attr_name: str, default):
"""Given an attribute name and default value, returns the attribute value from the SegmentConfig class."""
return getattr(model.SegmentConfig, attr_name, default)

@staticmethod
def get_span_model(model):
"""Return the span model for the segment model."""
span_model = SegmentConfigurationHelper.get_config_attr(model, "span_model", None)
if span_model is None:
raise IncorrectSubclassError(_(f"Span model not defined for {model.__class__.__name__}"))
if "AbstractSpan" not in [base.__name__ for base in span_model.__bases__]:
raise IncorrectSubclassError(_(f"Span model ({span_model}) must be a subclass of AbstractSpan for {model}"))
return span_model

@staticmethod
def get_config_dict(model) -> dict:
"""Return a dictionary of configuration options."""
return {
"span_model": SegmentConfigurationHelper.get_span_model(model),
"soft_delete": getattr(
SegmentConfigurationHelper.get_span_model(model).SpanConfig, "soft_delete", SOFT_DELETE
),
"previous_field_on_delete": SegmentConfigurationHelper.get_config_attr(
model, "previous_field_on_delete", PREVIOUS_FIELD_ON_DELETE
),
"span_on_delete": SegmentConfigurationHelper.get_config_attr(model, "span_on_delete", SPAN_ON_DELETE),
"span_related_name": SegmentConfigurationHelper.get_config_attr(
model, "span_related_name", DEFAULT_RELATED_NAME
),
"span_related_query_name": SegmentConfigurationHelper.get_config_attr(
model, "span_related_query_name", DEFAULT_RELATED_QUERY_NAME
),
}


class AbstractSpanMetaclass(ModelBase): # pylint: disable=R0903
"""Metaclass for AbstractSpan."""

Expand All @@ -97,46 +167,20 @@ def __new__(cls, name, bases, attrs, **kwargs):

model = super().__new__(cls, name, bases, attrs, **kwargs) # pylint: disable=E1121

def get_config_attr(attr_name: str, default):
"""Given an attribute name and default value, returns the attribute value from the SpanConfig class."""
return getattr(model.SpanConfig, attr_name, None) if hasattr(model.SpanConfig, attr_name) else default

def get_range_type():
"""Return the range type for the span model."""
range_type = get_config_attr("range_type", None)
if range_type is None:
raise IncorrectRangeTypeError(f"Range type not defined for {model.__class__.__name__}")
if range_type.__name__ not in POSTGRES_RANGE_FIELDS.keys():
raise IncorrectRangeTypeError(
f"Unsupported range type: {range_type} not in {POSTGRES_RANGE_FIELDS.keys()=} for "
f"{model.__class__.__name__}"
)

return range_type

def get_config_dict() -> dict[str, bool]:
"""Return the configuration options for the span as a dictionary."""

return {
"allow_span_gaps": get_config_attr("allow_span_gaps", ALLOW_SPAN_GAPS),
"allow_segment_gaps": get_config_attr("allow_segment_gaps", ALLOW_SEGMENT_GAPS),
"soft_delete": get_config_attr("soft_delete", SOFT_DELETE),
"range_type": get_range_type(),
}

for base in bases:
if base.__name__ == "AbstractSpan":
# Call get_range_type to ensure that the range type is defined
get_range_type()
SpanConfigurationHelper.get_range_type(model)

# Ensure that the model is not abstract
concrete_validation_helper = ConcreteModelValidationHelper(model)
concrete_validation_helper.check_model_is_concrete()
ConcreteModelValidationHelper.check_model_is_concrete(model)

config_dict = SpanConfigurationHelper.get_config_dict(model)

# Add the initial_range and current_range fields to the model
model.add_to_class(
"initial_range",
get_config_dict().get("range_type")(
config_dict["range_type"](
_("Initial Range"),
blank=True,
null=True,
Expand All @@ -145,7 +189,7 @@ def get_config_dict() -> dict[str, bool]:

model.add_to_class(
"current_range",
get_config_dict().get("range_type")(
config_dict["range_type"](
_("Current Range"),
blank=True,
null=True,
Expand All @@ -160,7 +204,7 @@ def get_config_dict() -> dict[str, bool]:
]

# If we are using soft delete, add a deleted_at field to the model
if get_config_dict().get("soft_delete"):
if config_dict["soft_delete"]:
model.add_to_class(
"deleted_at",
models.DateTimeField(
Expand Down Expand Up @@ -189,44 +233,15 @@ def __new__(cls, name, bases, attrs, **kwargs):

model = super().__new__(cls, name, bases, attrs, **kwargs) # pylint: disable=E1121

def get_config_attr(attr_name: str, default):
"""Given an attribute name and default value, returns the attribute value from the SegmentConfig class."""
return getattr(model.SegmentConfig, attr_name, None) if hasattr(model.SegmentConfig, attr_name) else default

def get_span_model():
"""Return the span model for the segment model."""
span_model = get_config_attr("span_model", None)
if span_model is None:
raise IncorrectSubclassError(_(f"Span model not defined for {model.__class__.__name__}"))

# Check if "AbstractSpan" is in the base names for span_model (i.e. if span_model is a subclass of Abstract
if "AbstractSpan" not in [base.__name__ for base in span_model.__bases__]:
raise IncorrectSubclassError(
_(f"Span model ({span_model}) must be a subclass of AbstractSpan for {model}")
)

return span_model

def get_config_dict() -> dict[str, bool]:
"""Return a dictionary of configuration options."""

return {
"span_model": get_span_model(),
"soft_delete": getattr(get_span_model().SpanConfig, "soft_delete", SOFT_DELETE),
"previous_field_on_delete": get_config_attr("previous_field_on_delete", PREVIOUS_FIELD_ON_DELETE),
"span_on_delete": get_config_attr("span_on_delete", SPAN_ON_DELETE),
"span_related_name": get_config_attr("span_related_name", DEFAULT_RELATED_NAME),
"span_related_query_name": get_config_attr("span_related_query_name", DEFAULT_RELATED_QUERY_NAME),
}

for base in bases:
if base.__name__ == "AbstractSegment":
# Call get_span_model to ensure that the span model is defined
get_span_model()
SegmentConfigurationHelper.get_span_model(model)

# Ensure that the model is not abstract
concrete_validation_helper = ConcreteModelValidationHelper(model)
concrete_validation_helper.check_model_is_concrete()
ConcreteModelValidationHelper.check_model_is_concrete(model)

config_dict = SegmentConfigurationHelper.get_config_dict(model)

# Add the segment_range, span, and previous_segment fields to the model
model.add_to_class(
Expand All @@ -244,9 +259,9 @@ def get_config_dict() -> dict[str, bool]:
model.SegmentConfig.span_model,
null=True,
blank=True,
on_delete=get_config_dict().get("span_on_delete"),
related_name=get_config_dict().get("span_related_name"),
related_query_name=get_config_dict().get("span_related_query_name"),
on_delete=config_dict["span_on_delete"],
related_name=config_dict["span_related_name"],
related_query_name=config_dict["span_related_query_name"],
),
)

Expand All @@ -256,7 +271,7 @@ def get_config_dict() -> dict[str, bool]:
model,
null=True,
blank=True,
on_delete=get_config_dict().get("previous_field_on_delete"),
on_delete=config_dict["previous_field_on_delete"],
related_name="next_segment",
),
)
Expand All @@ -270,7 +285,7 @@ def get_config_dict() -> dict[str, bool]:
model.Meta.indexes.append(models.Index(fields=["segment_range"]))

# If we are using soft delete, add a deleted_at field to the model
if get_config_dict().get("soft_delete"):
if config_dict["soft_delete"]:
model.add_to_class(
"deleted_at",
models.DateTimeField(
Expand Down
35 changes: 2 additions & 33 deletions src/django_segments/models/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from django_segments.app_settings import SOFT_DELETE
from django_segments.app_settings import SPAN_ON_DELETE
from django_segments.models.base import AbstractSegmentMetaclass
from django_segments.models.base import SegmentConfigurationHelper
from django_segments.models.base import boundary_helper_factory


Expand Down Expand Up @@ -37,43 +38,16 @@ class MyOtherSegment(AbstractSegment):

_set_lower_boundary, _set_upper_boundary = boundary_helper_factory("segment_range")

# previous_segment = models.OneToOneField( # Should use the value from config
# "self",
# null=True,
# blank=True,
# related_name="next_segment",
# on_delete=PREVIOUS_FIELD_ON_DELETE,
# )

class Meta: # pylint: disable=C0115 disable=R0903
abstract = True

class SegmentConfig: # pylint: disable=R0903
"""Configuration options for the segment."""

span_model = None

def get_config_attr(self, attr_name: str, default):
"""Given an attribute name and default value, returns the attribute value from the SegmentConfig class."""
return getattr(self.SegmentConfig, attr_name, None) if hasattr(self.SegmentConfig, attr_name) else default

def get_config_dict(self) -> dict[str, bool]:
"""Return a dictionary of configuration options."""

span_model = self.get_config_attr("span_model", None) # Previously verified in the metaclass

return {
"span_model": span_model,
"soft_delete": self.span_model.get_config_attr("soft_delete", SOFT_DELETE),
"previous_field_on_delete": self.get_config_attr("previous_field_on_delete", PREVIOUS_FIELD_ON_DELETE),
"previous_field_related_name": self.get_config_attr("previous_field_related_name", DEFAULT_RELATED_NAME),
"previous_field_related_query_name": self.get_config_attr(
"previous_field_related_query_name", DEFAULT_RELATED_QUERY_NAME
),
"span_on_delete": self.get_config_attr("span_on_delete", SPAN_ON_DELETE),
"span_related_name": self.get_config_attr("span_related_name", DEFAULT_RELATED_NAME),
"span_related_query_name": self.get_config_attr("span_related_query_name", DEFAULT_RELATED_QUERY_NAME),
}
return SegmentConfigurationHelper.get_config_dict(self)

def set_lower_boundary(self, value) -> None:
"""Set the lower boundary of the segment range field."""
Expand Down Expand Up @@ -122,8 +96,3 @@ def is_first_and_last(self):
def is_internal(self):
"""Return True if the segment is not the first or last segment in the span."""
return not self.is_first_and_last

@property
def span(self):
"""Return the span associated with the segment."""
return getattr(self, self.span.name)
Loading

0 comments on commit 4483f8b

Please sign in to comment.