Skip to content

Commit

Permalink
refactor unique constraint bad request logic
Browse files Browse the repository at this point in the history
  • Loading branch information
cc-jj committed Feb 10, 2022
1 parent f4c8beb commit 4702563
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 48 deletions.
55 changes: 9 additions & 46 deletions src/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from src import models, schemas


def _create_unique_constrain_error_msg(exc: IntegrityError) -> Optional[str]:
def create_unique_constrain_error_msg(exc: IntegrityError) -> Optional[str]:
pattern = r"^UNIQUE constraint failed: ([a-z_]+)\.([a-z_]+)$"
for arg in exc.orig.args:
assert isinstance(arg, str)
Expand All @@ -22,11 +22,6 @@ def _create_unique_constrain_error_msg(exc: IntegrityError) -> Optional[str]:
return None


def _handle_integrity_error(exc: IntegrityError):
if error_msg := _create_unique_constrain_error_msg(exc):
raise HTTPException(400, error_msg)


# User


Expand All @@ -40,11 +35,7 @@ def read_user(db: Session, username: str) -> Optional[models.User]:
def create_customer(db: Session, customer: schemas.CustomerCreate) -> models.Customer:
db_customer = models.Customer(**customer.dict())
db.add(db_customer)
try:
db.commit()
except IntegrityError as exc:
_handle_integrity_error(exc)
raise
db.commit()
return db_customer


Expand Down Expand Up @@ -78,11 +69,7 @@ def update_customer(
for attr, value in customer.dict(exclude={"id"}, exclude_unset=True).items():
setattr(db_customer, attr, value)
db.add(db_customer)
try:
db.commit()
except IntegrityError as exc:
_handle_integrity_error(exc)
raise
db.commit()
return db_customer


Expand All @@ -92,11 +79,7 @@ def update_customer(
def create_menu_category(db: Session, category: schemas.MenuCategoryCreate) -> models.MenuCategory:
db_category = models.MenuCategory(**category.dict())
db.add(db_category)
try:
db.commit()
except IntegrityError as exc:
_handle_integrity_error(exc)
raise
db.commit()
return db_category


Expand All @@ -119,11 +102,7 @@ def update_menu_category(
for attr, value in category.dict(exclude={"id"}, exclude_unset=True).items():
setattr(db_category, attr, value)
db.add(db_category)
try:
db.commit()
except IntegrityError as exc:
_handle_integrity_error(exc)
raise
db.commit()
return db_category


Expand All @@ -133,11 +112,7 @@ def update_menu_category(
def create_menu_item(db: Session, menu_item: schemas.MenuItemCreate) -> models.MenuItem:
db_menu_item = models.MenuItem(**menu_item.dict())
db.add(db_menu_item)
try:
db.commit()
except IntegrityError as exc:
_handle_integrity_error(exc)
raise
db.commit()
return db_menu_item


Expand All @@ -161,11 +136,7 @@ def update_menu_item(
for attr, value in menu_item.dict(exclude={"id"}, exclude_unset=True).items():
setattr(db_menu_item, attr, value)
db.add(db_menu_item)
try:
db.commit()
except IntegrityError as exc:
_handle_integrity_error(exc)
raise
db.commit()
return db_menu_item


Expand All @@ -175,11 +146,7 @@ def update_menu_item(
def create_campaign(db: Session, campaign: schemas.CampaignCreate) -> models.Campaign:
db_campaign = models.Campaign(**campaign.dict())
db.add(db_campaign)
try:
db.commit()
except IntegrityError as exc:
_handle_integrity_error(exc)
raise
db.commit()
return db_campaign


Expand All @@ -200,11 +167,7 @@ def update_campaign(
for attr, value in campaign.dict(exclude={"id"}, exclude_unset=True).items():
setattr(db_campaign, attr, value)
db.add(db_campaign)
try:
db.commit()
except IntegrityError as exc:
_handle_integrity_error(exc)
raise
db.commit()
return db_campaign


Expand Down
12 changes: 10 additions & 2 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
import sys
import traceback

from fastapi import FastAPI, Request
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi_pagination import add_pagination
from sqlalchemy.exc import IntegrityError

from src import routes, settings
from src import crud, routes, settings

logger = logging.getLogger("bakery")

Expand All @@ -22,6 +23,13 @@ def startup():
logger.info("Settings\n%s", "\n".join(pretty_settings))


@app.exception_handler(IntegrityError)
def handle_integrity_error(request: Request, exc: IntegrityError):
"""Convert DB UniqueConstraint failures to 400 BadRequest failures"""
if error_msg := crud.create_unique_constrain_error_msg(exc):
return JSONResponse({"detail": error_msg}, 400)


async def catch_exceptions_middleware(request: Request, call_next):
"""
Ensure CORS headers are added to response when an unhandled exception occurs:
Expand Down

0 comments on commit 4702563

Please sign in to comment.