Skip to content

Commit

Permalink
Simplify Enum (De)serialization, add get_or_404 methods to the db, en…
Browse files Browse the repository at this point in the history
…able generative tests
  • Loading branch information
StarKhan6368 committed Jun 7, 2024
1 parent e6d58ac commit 252c318
Show file tree
Hide file tree
Showing 11 changed files with 122 additions and 114 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,5 +94,5 @@ The open api contract for the services is defined in the [Specmatic Central Cont
- This should print the following output:

```cmd
Tests run: 19, Successes: 19, Failures: 0, Errors: 0
Tests run: 162, Successes: 162, Failures: 0, Errors: 0
```
32 changes: 24 additions & 8 deletions api/db.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
from itertools import count
from typing import ClassVar

from flask import abort

from api.models import Order, OrderStatus, Product, ProductType


class Database:
_products: ClassVar[dict[int, Product]] = {
10: Product(name="XYZ Phone", product_type=ProductType.GADGET, inventory=10, id=10),
20: Product(name="Gemini", product_type=ProductType.OTHER, inventory=10, id=20),
10: Product(name="XYZ Phone", type=ProductType.GADGET, inventory=10, id=10),
20: Product(name="Gemini", type=ProductType.OTHER, inventory=10, id=20),
}

_orders: ClassVar[dict[int, Order]] = {
10: Order(product_id=10, count=2, status=OrderStatus.PENDING, id=10),
20: Order(product_id=10, count=1, status=OrderStatus.PENDING, id=20),
10: Order(productid=10, count=2, status=OrderStatus.PENDING, id=10),
20: Order(productid=10, count=1, status=OrderStatus.PENDING, id=20),
}

product_iter = count((max(_products) + 1) if _products else 1)
Expand All @@ -27,13 +29,20 @@ def find_products(name: str, product_type: ProductType | None):
return [
product
for product in Database._products.values()
if product["name"].lower() == name.lower() or product["product_type"] == product_type
if product["name"].lower() == name.lower() or product["type"] == product_type
]

@staticmethod
def find_product_by_id(product_id: int):
return Database._products.get(product_id)

@staticmethod
def find_product_by_id_or_404(product_id: int):
product = Database.find_product_by_id(product_id)
if not product:
abort(404, f"Product with {product_id} was not found")
return product

@staticmethod
def delete_product(product_id: int):
if Database._products.get(product_id):
Expand All @@ -47,7 +56,7 @@ def add_product(product: Product):
@staticmethod
def update_product(product: Product, new_data: Product):
product["name"] = new_data["name"]
product["product_type"] = new_data["product_type"]
product["type"] = new_data["type"]
product["inventory"] = new_data["inventory"]

@staticmethod
Expand All @@ -59,13 +68,20 @@ def find_orders(product_id: int, status: OrderStatus | None):
return [
order
for order in Database._orders.values()
if order["product_id"] == product_id or order["status"] == status
if order["productid"] == product_id or order["status"] == status
]

@staticmethod
def find_order_by_id(order_id: int):
return Database._orders.get(order_id)

@staticmethod
def find_order_by_id_or_404(order_id: int):
order = Database.find_order_by_id(order_id)
if not order:
abort(404, f"Order with {order_id} was not found")
return order

@staticmethod
def delete_order(order_id: int):
if Database._orders.get(order_id):
Expand All @@ -78,6 +94,6 @@ def add_order(order: Order):

@staticmethod
def update_order(order: Order, new_data: Order):
order["product_id"] = new_data["product_id"]
order["productid"] = new_data["productid"]
order["count"] = new_data["count"]
order["status"] = new_data["status"]
18 changes: 7 additions & 11 deletions api/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,28 +8,24 @@ class ProductType(str, enum.Enum):
BOOK = "book"
OTHER = "other"

def __str__(self):
return self.value


class OrderStatus(str, enum.Enum):
FULFILLED = "fulfilled"
PENDING = "pending"
CANCELLED = "cancelled"

def __str__(self):
return self.value

class Id(TypedDict):
id: int


class Product(TypedDict):
class Product(Id):
name: str
product_type: ProductType
type: ProductType
inventory: int
id: int


class Order(TypedDict):
product_id: int
class Order(Id):
productid: int
count: int
status: OrderStatus
id: int
50 changes: 28 additions & 22 deletions api/orders/routes.py
Original file line number Diff line number Diff line change
@@ -1,59 +1,65 @@
from typing import TYPE_CHECKING

from flask import Blueprint, Response, abort, jsonify, request
from flask import Blueprint, Response, jsonify, request

from api.db import Database
from api.schemas import OrderSchema
from api.schemas import IdSchema, NewOrderSchema, OrderSchema

if TYPE_CHECKING:
from api.models import Order
from api.models import Id, Order

orders = Blueprint("orders", __name__, url_prefix="/orders")
order_schema = OrderSchema()
orders_schema = OrderSchema(many=True)
new_order_schema = NewOrderSchema()
id_schema = IdSchema()


@orders.route("/", methods=["GET"])
def get_orders():
args: Order = order_schema.load(request.args, partial=True) # type: ignore[return-value]
productid = request.args.get("productid")
if productid and productid.isdigit():
productid = int(productid)

if not args.get("product_id") and not args.get("status"):
return orders_schema.dump(Database.all_orders())
data = request.args | {"productid": productid} if productid else {}
args: Order = new_order_schema.load(data, partial=True) # type: ignore[reportAssignmentType]
if args.get("productid") is None and args.get("status") is None:
return order_schema.dump(Database.all_orders(), many=True)

return orders_schema.dump(Database.find_orders(args.get("product_id", 0), args.get("status")))
return order_schema.dump(Database.find_orders(args.get("productid"), args.get("status")), many=True)


@orders.route("/", methods=["POST"])
def add_order():
order: Order = order_schema.load(request.json) # type: ignore[return-value]
order: Order = new_order_schema.load(request.json) # type: ignore[reportAssignmentType]
Database.add_order(order)
return jsonify(id=order["id"])


@orders.route("/<int:id>", methods=["GET"])
def get_order(id: int):
order = Database.find_order_by_id(id)
if not order:
return abort(404, f"Order with {id} was not found")
@orders.route("/<id>", methods=["GET"])
def get_order(id: str): # noqa: A002
params: Id = id_schema.load({"id": id}) # type: ignore[reportAssignmentType]
order = Database.find_order_by_id_or_404(params["id"])
return order_schema.dump(order)


@orders.route("/<int:id>", methods=["POST"])
def update_order(id: int):
order = Database.find_order_by_id(id)
@orders.route("/<id>", methods=["POST"])
def update_order(id: str): # noqa: A002
params: Id = id_schema.load({"id": id}) # type: ignore[reportAssignmentType]
order = Database.find_order_by_id(params["id"])
new_data: Order = order_schema.load(request.json) # type: ignore[reportAssignmentType]
if not order:
# TODO: Temporary 200 Response AS per v3_SPEC, Needs fixing across Node and Python
return Response("success", 200, mimetype="text/plain")
new_data: Order = order_schema.load(request.json) # type: ignore[return-value]
Database.update_order(order, new_data)
return Response("success", 200, mimetype="text/plain")


@orders.route("/<int:id>", methods=["DELETE"])
def delete_order(id: int):
order = Database.find_order_by_id(id)
@orders.route("/<id>", methods=["DELETE"])
def delete_order(id: str): # noqa: A002
params: Id = id_schema.load({"id": id}) # type: ignore[reportAssignmentType]
order = Database.find_order_by_id(params["id"])
if not order:
# TODO: Temporary 200 Response AS per v3_SPEC, Needs fixing across Node and Python
return Response("success", 200, mimetype="text/plain")
Database.delete_order(id)
Database.delete_order(params["id"])
return Response("success", 200, mimetype="text/plain")
49 changes: 26 additions & 23 deletions api/products/routes.py
Original file line number Diff line number Diff line change
@@ -1,58 +1,61 @@
from typing import TYPE_CHECKING

from flask import Blueprint, Response, abort, jsonify, request
from flask import Blueprint, Response, jsonify, request

from api.db import Database
from api.schemas import ProductSchema
from api.schemas import IdSchema, NewProductSchema, ProductSchema

if TYPE_CHECKING:
from api.models import Product
from api.models import Id, Product

products = Blueprint("products", __name__, url_prefix="/products")
product_schema = ProductSchema()
products_schema = ProductSchema(many=True)
prod_schema = ProductSchema()
new_prod_schema = NewProductSchema()
id_schema = IdSchema()


@products.route("/", methods=["GET"])
def get_products():
args: Product = product_schema.load(request.args, partial=True) # type: ignore[return-value]
args: Product = prod_schema.load(request.args, partial=True) # type: ignore[reportAssignmentType]

if not args.get("name") and not args.get("product_type"):
return products_schema.dump(Database.all_products())
return prod_schema.dump(Database.all_products(), many=True)

return products_schema.dump(Database.find_products(args.get("name", ""), args.get("product_type")))
return prod_schema.dump(Database.find_products(args.get("name", ""), args.get("product_type")), many=True)


@products.route("/", methods=["POST"])
def add_product():
product: Product = product_schema.load(request.json) # type: ignore[return-value]
product: Product = new_prod_schema.load(request.json) # type: ignore[reportAssignmentType]
Database.add_product(product)
return jsonify(id=product["id"])


@products.route("/<int:id>", methods=["GET"])
def get_product(id: int):
product = Database.find_product_by_id(id)
if not product:
return abort(404, f"Product with {id} was not found")
return product_schema.dump(product)
@products.route("<id>", methods=["GET"])
def get_product(id: str): # noqa: A002
params: Id = id_schema.load({"id": id}) # type: ignore[reportAssignmentType]
product = Database.find_product_by_id_or_404(params["id"])
return prod_schema.dump(product)


@products.route("/<int:id>", methods=["POST"])
def update_product(id: int):
product = Database.find_product_by_id(id)
@products.route("<id>", methods=["POST"])
def update_product(id: str): # noqa: A002
params: Id = id_schema.load({"id": id}) # type: ignore[reportAssignmentType]
new_data: Product = prod_schema.load(request.json) # type: ignore[reportAssignmentType]
product = Database.find_product_by_id(params["id"])
if not product:
# TODO: Temporary 200 Response AS per v3_SPEC, Needs fixing across Node and Python
return Response("success", 200, mimetype="text/plain")
new_data: Product = product_schema.load(request.json) # type: ignore[return-value]
Database.update_product(product, new_data)
return Response("success", 200, mimetype="text/plain")


@products.route("/<int:id>", methods=["DELETE"])
def delete_product(id: int):
product = Database.find_product_by_id(id)
@products.route("<id>", methods=["DELETE"])
def delete_product(id: str): # noqa: A002F
params: Id = id_schema.load({"id": id}) # type: ignore[reportAssignmentType]
product = Database.find_product_by_id(params["id"])
if not product:
# TODO: Temporary 200 Response AS per v3_SPEC, Needs fixing across Node and Python
return Response("success", 200, mimetype="text/plain")
Database.delete_product(id)
Database.delete_product(params["id"])
return Response("success", 200, mimetype="text/plain")
51 changes: 23 additions & 28 deletions api/schemas.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,27 @@
from marshmallow import Schema, fields, post_load, validate
from marshmallow import Schema, fields

from api.models import Order, OrderStatus, Product, ProductType
from api.models import OrderStatus, ProductType

VALID_PRODUCT_TYPES = [t.value for t in ProductType]
VALID_ORDER_STATUS = [s.value for s in OrderStatus]


class ProductSchema(Schema):
id = fields.Integer(required=False, load_default=None)
class NewProductSchema(Schema):
name = fields.String(required=True)
product_type = fields.String(required=True, validate=validate.OneOf(VALID_PRODUCT_TYPES), data_key="type")
inventory = fields.Integer(required=True)

@post_load
def serialize_enum(self, data: Product, **_):
if data.get("product_type"):
data["product_type"] = ProductType(data["product_type"])
return data


class OrderSchema(Schema):
id = fields.Integer(required=False, load_default=None)
product_id = fields.Integer(required=True, data_key="productid")
count = fields.Integer(required=True)
status = fields.String(required=True, validate=validate.OneOf(VALID_ORDER_STATUS))

@post_load
def serialize_enum(self, data: Order, **_):
if data.get("status"):
data["status"] = OrderStatus(data["status"])
return data
type = fields.Enum(ProductType, required=True, by_value=True)
inventory = fields.Integer(required=True, strict=True)


class NewOrderSchema(Schema):
productid = fields.Integer(required=True, strict=True)
count = fields.Integer(required=True, strict=True)
status = fields.Enum(OrderStatus, required=True, by_value=True)


class IdSchema(Schema):
id = fields.Integer(required=True, strict=False)


class ProductSchema(NewProductSchema):
id = fields.Integer(required=True, strict=True)


class OrderSchema(NewOrderSchema):
id = fields.Integer(required=True, strict=True)
4 changes: 2 additions & 2 deletions app.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from api import app

if __name__ == '__main__':
app.run(debug=True)
if __name__ == "__main__":
app.run()
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
coverage==7.5.3
Flask==3.0.3
marshmallow==3.21.2
marshmallow==3.21.3
pytest==8.2.2
specmatic==1.3.22
specmatic==1.3.23
9 changes: 2 additions & 7 deletions ruff.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,6 @@ select = ["ALL"]
ignore = [
#### modules
"ANN", # flake8-annotations
"COM", # flake8-commas
"C90", # mccabe complexity
"DJ", # django
"EXE", # flake8-executable
"T10", # debugger
"TID", # flake8-tidy-imports

#### specific rules
"D100", # missing docstring in public module
Expand All @@ -33,5 +27,6 @@ ignore = [
"EM101", # Exception must not use a string literal,
"FBT001", # boolean default value in function signature
"PLR0913", # Too many arguments
"FBT002", # boolean positional only argument
"FBT002", # boolean positional only argument,
"PLR2004", # Magic value used in comparison,
]
Loading

0 comments on commit 252c318

Please sign in to comment.