From b9a1e9020c9e833a36c42c23674d1a6a24dca59d Mon Sep 17 00:00:00 2001
From: Ben Bolte
Date: Wed, 15 Jan 2025 11:34:03 -0800
Subject: [PATCH] crud (#670)

* crud
* creation script
* fix tests
* robot classes crud
* robot router / crud implementation
* urdf uploading and downloading
* remove environment
* fix tests
* robot fixes
* nit fix
---
 .github/workflows/test.yml           |   1 -
 Makefile                             |  57 +---
 README.md                            |  16 +
 docker/docker-compose-localstack.yml |   9 +
 scripts/create_db.py                 |  47 +++
 tests/conftest.py                    |  13 +-
 tests/routers/test_auth.py           |   6 +-
 tests/routers/test_robot.py          | 103 +++++++
 tests/routers/test_robot_class.py    | 131 ++++++++
 www/auth.py                          |   4 +-
 www/crud/__init__.py                 |  38 +++
 www/crud/__main__.py                 |  12 +
 www/crud/base/db.py                  |  61 +++-
 www/crud/base/s3.py                  |  62 ++++
 www/crud/robot.py                    | 186 +++++++++
 www/crud/robot_class.py              | 182 +++++++++
 www/errors.py                        |  59 +++-
 www/main.py                          |  86 +-----
 www/middleware.py                    |  26 ++
 www/model.py                         | 445 ---------------------------
 www/routers/__init__.py              |  18 ++
 www/routers/auth.py                  |   2 +-
 www/routers/robot.py                 | 153 +++++++
 www/routers/robot_class.py           | 165 ++++++++++
 www/settings/configs/local.yaml      |   2 -
 www/settings/environment.py          |  14 +-
 www/utils/db.py                      |  34 ++
 27 files changed, 1340 insertions(+), 592 deletions(-)
 create mode 100644 docker/docker-compose-localstack.yml
 create mode 100644 scripts/create_db.py
 create mode 100644 tests/routers/test_robot.py
 create mode 100644 tests/routers/test_robot_class.py
 create mode 100644 www/crud/__main__.py
 create mode 100644 www/crud/robot.py
 create mode 100644 www/crud/robot_class.py
 create mode 100644 www/middleware.py
 delete mode 100644 www/model.py
 create mode 100644 www/routers/robot.py
 create mode 100644 www/routers/robot_class.py

diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 4e15a452..94dceebf 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -23,7 +23,6 @@ jobs:
   run-tests:
     timeout-minutes: 10
     runs-on: ubuntu-latest
-    environment: ${{ github.ref == 'refs/heads/master' && 'production' || 'staging' }}

     steps:
       - name: Check out repository
diff --git a/Makefile b/Makefile
index 7efb662b..ea6f3ef6 100644
--- a/Makefile
+++ b/Makefile
@@ -1,63 +1,38 @@
 # Makefile

-# ------------------------ #
-#          Serve           #
-# ------------------------ #
+ENV_VARS = \
+	ENVIRONMENT=local \
+	AWS_REGION=us-east-1 \
+	AWS_DEFAULT_REGION=us-east-1 \
+	AWS_ACCESS_KEY_ID=test \
+	AWS_SECRET_ACCESS_KEY=test \
+	AWS_ENDPOINT_URL=http://localhost:4566

 start:
-	@if [ -f env.sh ]; then source env.sh; fi; fastapi dev 'www/main.py' --host localhost --port 8080
+	$(ENV_VARS) fastapi dev 'www/main.py' --host localhost --port 8080
 .PHONY: start

-start-docker-dynamodb:
-	@docker kill www-db || true
-	@docker rm www-db || true
-	@docker run --name www-db -d -p 8000:8000 amazon/dynamodb-local
-.PHONY: start-docker-dynamodb
+start-localstack:
+	@docker compose --file docker/docker-compose-localstack.yml down --remove-orphans
+	@docker rm -f www-localstack 2>/dev/null || true
+	@docker compose --file docker/docker-compose-localstack.yml up -d localstack --force-recreate
+.PHONY: start-localstack

-start-docker-backend:
-	@docker kill www-backend || true
-	@docker rm www-backend || true
-	@docker build -t www-backend .
- @docker run --name www-backend -d -p 8080:8080 www-backend -.PHONY: start-docker-backend - -start-docker-localstack: - @docker kill www-localstack || true - @docker rm www-localstack || true - @docker run -d --name www-localstack -p 4566:4566 -p 4571:4571 localstack/localstack -.PHONY: start-docker-localstack - -# ------------------------ # -# Install # -# ------------------------ # - -install: - @pip install -e '[.dev]' -.PHONY: install - -# ------------------------ # -# Code Formatting # -# ------------------------ # +create-db: + $(ENV_VARS) python -m scripts.create_db --s3 --db +.PHONY: create-db format: @black www tests @ruff check --fix www tests .PHONY: format -# ------------------------ # -# Static Checks # -# ------------------------ # - lint: @black --diff --check www tests @ruff check www tests @mypy --install-types --non-interactive www tests .PHONY: lint -# ------------------------ # -# Unit tests # -# ------------------------ # - test: @python -m pytest .PHONY: test-backend diff --git a/README.md b/README.md index 3a0916cc..227bae90 100644 --- a/README.md +++ b/README.md @@ -15,3 +15,19 @@ # K-Scale Website This is the codebase for K-Scale's web infrastructure. + +## Getting Started + +First, pull the repository and install the project: + +```bash +git clone https://github.com/kscalelabs/www.git +cd www +pip install -e '.[dev]' +``` + +Next, start localstack: + +```bash +make start-localstack +``` diff --git a/docker/docker-compose-localstack.yml b/docker/docker-compose-localstack.yml new file mode 100644 index 00000000..2ee26ed3 --- /dev/null +++ b/docker/docker-compose-localstack.yml @@ -0,0 +1,9 @@ +services: + localstack: + image: localstack/localstack + container_name: www-localstack + ports: + - "4566:4566" + - "4571:4571" + restart: always + pull_policy: always diff --git a/scripts/create_db.py b/scripts/create_db.py new file mode 100644 index 00000000..21aa4dbe --- /dev/null +++ b/scripts/create_db.py @@ -0,0 +1,47 @@ +"""Defines the script to create the database. + +This script is meant to be run locally, to create the initial database tables +in localstack. 
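+It is idempotent: existing tables and buckets are detected and left in place.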
To run it, use: + +```bash +python -m scripts.create_db --s3 +``` +""" + +import argparse +import asyncio +import logging + +import colorlogging + +from www.crud.base.db import DBCrud +from www.crud.base.s3 import create_s3_bucket +from www.crud.robot import robot_crud +from www.crud.robot_class import robot_class_crud + +logger = logging.getLogger(__name__) + +CRUDS: list[DBCrud] = [robot_class_crud, robot_crud] + + +async def main() -> None: + colorlogging.configure() + + parser = argparse.ArgumentParser() + parser.add_argument("--s3", action="store_true", help="Create the S3 bucket.") + parser.add_argument("--db", action="store_true", help="Create the DynamoDB tables.") + args = parser.parse_args() + + if args.s3: + logger.info("Creating S3 bucket...") + await create_s3_bucket() + + if args.db: + for crud in CRUDS: + async with crud: + logger.info("Creating %s table...", crud.table_name) + await crud.create_table() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/tests/conftest.py b/tests/conftest.py index 54ab8b37..78ee9ca9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,6 +12,8 @@ from moto.server import ThreadedMotoServer from pytest_mock.plugin import AsyncMockType, MockerFixture, MockType +from www.crud import create + os.environ["ENVIRONMENT"] = "local" @@ -20,13 +22,19 @@ def pytest_collection_modifyitems(items: list[Function]) -> None: @pytest.fixture(autouse=True) -def mock_aws() -> Generator[None, None, None]: +async def mock_aws() -> AsyncGenerator[None, None]: server: ThreadedMotoServer | None = None # logging.getLogger("botocore").setLevel(logging.DEBUG) logging.getLogger("botocore").setLevel(logging.WARN) try: + # Sets required AWS environment variables. + os.environ["AWS_ACCESS_KEY_ID"] = "test" + os.environ["AWS_SECRET_ACCESS_KEY"] = "test" + os.environ["AWS_SESSION_TOKEN"] = "test" + os.environ["AWS_DEFAULT_REGION"] = os.environ["AWS_REGION"] = "us-east-1" + # Starts a local AWS server. server = ThreadedMotoServer(port=0) server.start() @@ -36,6 +44,9 @@ def mock_aws() -> Generator[None, None, None]: os.environ["AWS_ENDPOINT_URL_DYNAMODB"] = endpoint os.environ["AWS_ENDPOINT_URL_S3"] = endpoint + # Create the S3 bucket and DynamoDB tables. + await create() + yield finally: diff --git a/tests/routers/test_auth.py b/tests/routers/test_auth.py index aaa2cfdd..57c7720a 100644 --- a/tests/routers/test_auth.py +++ b/tests/routers/test_auth.py @@ -4,6 +4,8 @@ from fastapi import status from fastapi.testclient import TestClient +HEADERS = {"Authorization": "Bearer test"} + @pytest.mark.asyncio async def test_user_endpoint(test_client: TestClient) -> None: @@ -27,7 +29,7 @@ async def test_profile_endpoint(test_client: TestClient) -> None: response = test_client.get("/auth/profile") assert response.status_code == status.HTTP_401_UNAUTHORIZED, response.text - response = test_client.get("/auth/profile", headers={"Authorization": "Bearer test"}) + response = test_client.get("/auth/profile", headers=HEADERS) assert response.status_code == status.HTTP_200_OK, response.text # Matches test user data. 
@@ -45,5 +47,5 @@ async def test_profile_endpoint(test_client: TestClient) -> None: @pytest.mark.asyncio async def test_logout_endpoint(test_client: TestClient) -> None: - response = test_client.get("/auth/logout", headers={"Authorization": "Bearer test"}) + response = test_client.get("/auth/logout", headers=HEADERS) assert response.status_code == status.HTTP_200_OK, response.text diff --git a/tests/routers/test_robot.py b/tests/routers/test_robot.py new file mode 100644 index 00000000..35c0d525 --- /dev/null +++ b/tests/routers/test_robot.py @@ -0,0 +1,103 @@ +"""Unit tests for the robot router.""" + +import pytest +from fastapi import status +from fastapi.testclient import TestClient + +HEADERS = {"Authorization": "Bearer test"} + + +@pytest.mark.asyncio +async def test_robots(test_client: TestClient) -> None: + # First create a robot class that we'll use + response = test_client.put("/robots/add", params={"class_name": "test_class"}, headers=HEADERS) + assert response.status_code == status.HTTP_200_OK, response.text + robot_class_data = response.json() + assert robot_class_data["id"] is not None + + # Adds a robot + response = test_client.put( + "/robot/add", params={"robot_name": "test_robot", "class_name": "test_class"}, headers=HEADERS + ) + assert response.status_code == status.HTTP_200_OK, response.text + data = response.json() + robot_id = data["id"] + assert robot_id is not None + assert data["robot_name"] == "test_robot" + assert data["class_name"] == "test_class" + + # Attempts to add a second robot with the same name + response = test_client.put( + "/robot/add", params={"robot_name": "test_robot", "class_name": "test_class"}, headers=HEADERS + ) + assert response.status_code == status.HTTP_400_BAD_REQUEST, response.text + + # Gets all robots + response = test_client.get("/robot", headers=HEADERS) + assert response.status_code == status.HTTP_200_OK, response.text + data = response.json() + assert len(data) == 1 + assert data[0]["robot_name"] == "test_robot" + + # Gets the robot by name + response = test_client.get("/robot/name/test_robot", headers=HEADERS) + assert response.status_code == status.HTTP_200_OK, response.text + data = response.json() + assert data["robot_name"] == "test_robot" + + # Gets the robot by ID + response = test_client.get(f"/robot/id/{robot_id}", headers=HEADERS) + assert response.status_code == status.HTTP_200_OK, response.text + data = response.json() + assert data["robot_name"] == "test_robot" + + # Adds a second robot + response = test_client.put( + "/robot/add", params={"robot_name": "other_robot", "class_name": "test_class"}, headers=HEADERS + ) + assert response.status_code == status.HTTP_200_OK, response.text + + # Updates the first robot + response = test_client.put( + "/robot/update", + params={ + "robot_name": "test_robot", + "new_robot_name": "updated_robot", + "new_description": "new description", + }, + headers=HEADERS, + ) + assert response.status_code == status.HTTP_200_OK, response.text + data = response.json() + assert data["robot_name"] == "updated_robot" + assert data["description"] == "new description" + + # Lists my robots + response = test_client.get("/robot/user/me", headers=HEADERS) + assert response.status_code == status.HTTP_200_OK, response.text + data = response.json() + assert len(data) == 2 + assert all(robot["robot_name"] in ("updated_robot", "other_robot") for robot in data) + + # Deletes the robots + response = test_client.delete("/robot/delete", params={"robot_name": "updated_robot"}, headers=HEADERS) + assert 
response.status_code == status.HTTP_200_OK, response.text + + # Lists my robots again + response = test_client.get("/robot/user/me", headers=HEADERS) + assert response.status_code == status.HTTP_200_OK, response.text + data = response.json() + assert len(data) == 1 + assert data[0]["robot_name"] == "other_robot" + + # Clean up - delete remaining robot and robot class + response = test_client.get("/robot/name/other_robot", headers=HEADERS) + assert response.status_code == status.HTTP_200_OK, response.text + data = response.json() + assert data["id"] is not None + + response = test_client.delete("/robot/delete", params={"robot_name": "other_robot"}, headers=HEADERS) + assert response.status_code == status.HTTP_200_OK, response.text + + response = test_client.delete("/robots/delete", params={"class_name": "test_class"}, headers=HEADERS) + assert response.status_code == status.HTTP_200_OK, response.text diff --git a/tests/routers/test_robot_class.py b/tests/routers/test_robot_class.py new file mode 100644 index 00000000..fe778556 --- /dev/null +++ b/tests/routers/test_robot_class.py @@ -0,0 +1,131 @@ +"""Unit tests for the robot class router.""" + +import hashlib + +import httpx +import pytest +from fastapi import status +from fastapi.testclient import TestClient + +HEADERS = {"Authorization": "Bearer test"} + + +@pytest.mark.asyncio +async def test_robot_classes(test_client: TestClient) -> None: + # Adds a robot class. + response = test_client.put("/robots/add", params={"class_name": "test"}, headers=HEADERS) + assert response.status_code == status.HTTP_200_OK, response.text + + # Attempts to add a second robot class with the same name. + response = test_client.put("/robots/add", params={"class_name": "test"}, headers=HEADERS) + assert response.status_code == status.HTTP_400_BAD_REQUEST, response.text + + # Gets the added robot class. + response = test_client.get("/robots", headers=HEADERS) + assert response.status_code == status.HTTP_200_OK, response.text + data = response.json() + assert len(data) == 1 + assert data[0]["class_name"] == "test" + + # Gets the robot class by name. + response = test_client.get("/robots/name/test", headers=HEADERS) + assert response.status_code == status.HTTP_200_OK, response.text + data = response.json() + assert data["class_name"] == "test" + + # Adds a second robot class. + response = test_client.put("/robots/add", params={"class_name": "othertest"}, headers=HEADERS) + assert response.status_code == status.HTTP_200_OK, response.text + + # Updates the robot class. + response = test_client.put( + "/robots/update", + params={ + "class_name": "test", + "new_class_name": "othertest", + "new_description": "new description", + }, + headers=HEADERS, + ) + assert response.status_code == status.HTTP_400_BAD_REQUEST, response.text + + # Updates the robot class. + response = test_client.put( + "/robots/update", + params={ + "class_name": "test", + "new_class_name": "newtest", + "new_description": "new description", + }, + headers=HEADERS, + ) + assert response.status_code == status.HTTP_200_OK, response.text + + # Lists my robot classes. + response = test_client.get("/robots/user/me", headers=HEADERS) + assert response.status_code == status.HTTP_200_OK, response.text + data = response.json() + assert len(data) == 2 + assert all(robot_class["class_name"] in ("newtest", "othertest") for robot_class in data) + + # Deletes the robot classes. 
+ response = test_client.delete("/robots/delete", params={"class_name": "newtest"}, headers=HEADERS) + assert response.status_code == status.HTTP_200_OK, response.text + + response = test_client.delete("/robots/delete", params={"class_name": "othertest"}, headers=HEADERS) + assert response.status_code == status.HTTP_200_OK, response.text + + # Lists my robot classes. + response = test_client.get("/robots/user/me", headers=HEADERS) + assert response.status_code == status.HTTP_200_OK, response.text + assert response.json() == [] + + +@pytest.mark.asyncio +async def test_urdf(test_client: TestClient) -> None: + # Adds a robot class. + response = test_client.put("/robots/add", params={"class_name": "test"}, headers=HEADERS) + assert response.status_code == status.HTTP_200_OK, response.text + + # Uploads a URDF for the robot class. + response = test_client.put( + "/robots/urdf/test", + params={ + "filename": "robot.urdf", + "content_type": "application/gzip", + }, + headers=HEADERS, + ) + assert response.status_code == status.HTTP_200_OK, response.text + data = response.json() + assert data["url"] is not None + + # Uploads a URDF for the robot class. + async with httpx.AsyncClient() as client: + response = await client.put( + url=data["url"], + files={"file": ("robot.urdf", b"test", data["content_type"])}, + headers={"Content-Type": data["content_type"]}, + ) + assert response.status_code == status.HTTP_200_OK, response.text + + # Gets the URDF for the robot class. + response = test_client.get("/robots/urdf/test", headers=HEADERS) + assert response.status_code == status.HTTP_200_OK, response.text + data = response.json() + assert data["url"] is not None + + # Downloads the URDF from the presigned URL. + async with httpx.AsyncClient() as client: + response = await client.get(data["url"]) + assert response.status_code == status.HTTP_200_OK, response.text + content = await response.aread() + assert data["md5_hash"] == f'"{hashlib.md5(content).hexdigest()}"' + + # Deletes the robot classes. + response = test_client.delete("/robots/delete", params={"class_name": "test"}, headers=HEADERS) + assert response.status_code == status.HTTP_200_OK, response.text + + # Check that the URDF is deleted. 
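+    # (The robot class itself is gone, so its URDF download endpoint now returns 404.)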
+ response = test_client.get("/robots/urdf/test", headers=HEADERS) + assert response.status_code == status.HTTP_404_NOT_FOUND, response.text diff --git a/www/auth.py b/www/auth.py index c674e319..2e6e47a3 100644 --- a/www/auth.py +++ b/www/auth.py @@ -69,13 +69,13 @@ def _decode_user_from_token(token: str) -> User: except Exception as e: raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, + status_code=status.HTTP_401_UNAUTHORIZED, # Return 401 to indicate that the token is invalid detail="Failed to validate token", ) from e if env.site.is_test_environment and not user.can_test: raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, + status_code=status.HTTP_403_FORBIDDEN, # Return 403 to indicate that the user does not have permissions detail="User does not have test permissions", ) diff --git a/www/crud/__init__.py b/www/crud/__init__.py index e69de29b..ec5b428a 100644 --- a/www/crud/__init__.py +++ b/www/crud/__init__.py @@ -0,0 +1,38 @@ +"""Defines the common CRUD operations for the application.""" + +__all__ = ["robot_crud", "robot_class_crud", "s3_crud"] + +import asyncio +import logging + +from .base.s3 import s3_crud +from .robot import robot_crud +from .robot_class import robot_class_crud + +logger = logging.getLogger(__name__) + + +async def create_s3_bucket() -> None: + logger.info("Creating S3 bucket...") + async with s3_crud as s3: + await s3.create_bucket() + + +async def create_robot_table() -> None: + logger.info("Creating robot table...") + async with robot_crud as robot: + await robot.create_table() + + +async def create_robot_class_table() -> None: + logger.info("Creating robot class table...") + async with robot_class_crud as robot_class: + await robot_class.create_table() + + +async def create() -> None: + await asyncio.gather( + create_s3_bucket(), + create_robot_table(), + create_robot_class_table(), + ) diff --git a/www/crud/__main__.py b/www/crud/__main__.py new file mode 100644 index 00000000..780c484d --- /dev/null +++ b/www/crud/__main__.py @@ -0,0 +1,12 @@ +"""Runs the CRUD creation functions.""" + +import asyncio + +import colorlogging + +from . 
import create + +if __name__ == "__main__": + # python -m www.crud + colorlogging.configure() + asyncio.run(create()) diff --git a/www/crud/base/db.py b/www/crud/base/db.py index 2dff58d5..9d11f3ff 100644 --- a/www/crud/base/db.py +++ b/www/crud/base/db.py @@ -2,6 +2,7 @@ import asyncio import functools +import itertools import logging from abc import ABC, abstractmethod from typing import ( @@ -14,10 +15,10 @@ ) import aioboto3 +from botocore.exceptions import ClientError from pydantic import BaseModel from types_aiobotocore_dynamodb.service_resource import DynamoDBServiceResource, Table -from www.auth import User from www.settings import env logger = logging.getLogger(__name__) @@ -42,10 +43,6 @@ def __init__(self) -> None: def _get_table_name(self) -> str: """Returns the name of the table.""" - @abstractmethod - def delete_user_data(self, user: User) -> None: - """Deletes all data for a user.""" - @property def db(self) -> DynamoDBServiceResource: if self.__db is None: @@ -54,7 +51,7 @@ def db(self) -> DynamoDBServiceResource: @functools.cached_property def table_name(self) -> str: - return f"www-{self._get_table_name()}{env.aws.dynamodb.table_suffix}" + return f"{env.aws.dynamodb.table_prefix}-{self._get_table_name()}" @property async def table(self) -> Table: @@ -96,3 +93,55 @@ async def _get_by_known_id(self, record_id: str) -> dict[str, Any] | None: table = await self.table response = await table.get_item(Key={"id": record_id}) return response.get("Item") + + async def create_table(self) -> None: + try: + await self.db.meta.client.describe_table(TableName=self.table_name) + logger.info("Found existing table %s", self.table_name) + return + + except ClientError: + pass + + logger.info("Creating table %s", self.table_name) + gsis_set = self.get_gsis() + gsis: list[tuple[str, str, Literal["S", "N", "B"], Literal["HASH", "RANGE"]]] = [ + (f"{g}_index", g, "S", "HASH") for g in gsis_set + ] + keys = self.get_keys() + + if gsis: + table = await self.db.create_table( + TableName=self.table_name, + AttributeDefinitions=[ + {"AttributeName": n, "AttributeType": t} + for n, t in itertools.chain(((n, t) for (n, t, _) in keys), ((n, t) for _, n, t, _ in gsis)) + ], + KeySchema=[{"AttributeName": n, "KeyType": t} for n, _, t in keys], + GlobalSecondaryIndexes=( + [ + { + "IndexName": i, + "KeySchema": [{"AttributeName": n, "KeyType": t}], + "Projection": {"ProjectionType": "ALL"}, + } + for i, n, _, t in gsis + ] + ), + DeletionProtectionEnabled=env.aws.dynamodb.deletion_protection, + BillingMode="PAY_PER_REQUEST", + ) + + else: + table = await self.db.create_table( + AttributeDefinitions=[ + {"AttributeName": n, "AttributeType": t} for n, t in ((n, t) for (n, t, _) in keys) + ], + TableName=self.table_name, + KeySchema=[{"AttributeName": n, "KeyType": t} for n, _, t in keys], + DeletionProtectionEnabled=env.aws.dynamodb.deletion_protection, + BillingMode="PAY_PER_REQUEST", + ) + + # Wait for the table to be created. 
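+        # (wait_until_exists polls DynamoDB until the new table reports as ACTIVE.)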
+ await table.wait_until_exists() diff --git a/www/crud/base/s3.py b/www/crud/base/s3.py index f048a2d1..ecfb101b 100644 --- a/www/crud/base/s3.py +++ b/www/crud/base/s3.py @@ -31,6 +31,10 @@ def s3(self) -> S3ServiceResource: raise RuntimeError("Must call __aenter__ first!") return self.__s3 + @property + def prefix(self) -> str: + return "" + async def __aenter__(self) -> Self: await super().__aenter__() session = aioboto3.Session() @@ -47,6 +51,32 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: # *(resource.__aexit__(exc_type, exc_val, exc_tb) for resource in to_close), ) + async def create_bucket(self) -> None: + try: + await self.s3.meta.client.head_bucket(Bucket=env.aws.s3.bucket) + logger.info("Found existing bucket %s", env.aws.s3.bucket) + return + except ClientError: + pass + + logger.info("Creating bucket %s", env.aws.s3.bucket) + await self.s3.create_bucket(Bucket=env.aws.s3.bucket) + + logger.info("Updating %s CORS configuration", env.aws.s3.bucket) + s3_cors = await self.s3.BucketCors(env.aws.s3.bucket) + await s3_cors.put( + CORSConfiguration={ + "CORSRules": [ + { + "AllowedHeaders": ["*"], + "AllowedMethods": ["GET"], + "AllowedOrigins": ["*"], + "ExposeHeaders": ["ETag"], + } + ] + }, + ) + async def get_file_size(self, filename: str) -> int | None: """Gets the size of a file in S3. @@ -110,6 +140,28 @@ async def delete_from_s3(self, filename: str) -> None: bucket = await self.s3.Bucket(env.aws.s3.bucket) await bucket.delete_objects(Delete={"Objects": [{"Key": f"{env.aws.s3.prefix}{filename}"}]}) + async def get_file_hash(self, filename: str) -> str: + """Gets the hash of a file in S3.""" + bucket = await self.s3.Bucket(env.aws.s3.bucket) + obj = await bucket.Object(f"{env.aws.s3.prefix}{filename}") + data = await obj.get() + return data["ETag"] + + async def generate_presigned_download_url( + self, + s3_key: str, + expires_in: int = 3600, + ) -> str: + """Generates a presigned URL for downloading a file from S3.""" + return await self.s3.meta.client.generate_presigned_url( + ClientMethod="get_object", + Params={ + "Bucket": env.aws.s3.bucket, + "Key": f"{env.aws.s3.prefix}{s3_key}", + }, + ExpiresIn=expires_in, + ) + async def generate_presigned_upload_url( self, filename: str, @@ -152,3 +204,13 @@ async def __call__(self) -> AsyncGenerator[Self, None]: s3_crud = S3Crud() + + +async def create_s3_bucket() -> None: + async with s3_crud as crud: + await crud.create_bucket() + + +if __name__ == "__main__": + # python -m www.crud.s3 + asyncio.run(create_s3_bucket()) diff --git a/www/crud/robot.py b/www/crud/robot.py new file mode 100644 index 00000000..fbb40378 --- /dev/null +++ b/www/crud/robot.py @@ -0,0 +1,186 @@ +"""Defines the CRUD operations for the robot table.""" + +import asyncio +import re + +from boto3.dynamodb.conditions import Key +from pydantic import BaseModel + +from www.crud.base.db import DBCrud, TableKey +from www.errors import InvalidNameError +from www.utils.db import new_uuid + + +class Robot(BaseModel): + """Defines the data structure for a robot.""" + + id: str + robot_name: str + description: str + user_id: str + class_id: str + + +class RobotCrud(DBCrud): + """Defines the table holding information about individual robots.""" + + def _get_table_name(self) -> str: + return "robot" + + @classmethod + def get_keys(cls) -> list[TableKey]: + return [("id", "S", "HASH")] + + @classmethod + def get_gsis(cls) -> set[str]: + return {"robot_name", "user_id", "class_id"} + + def _is_valid_name(self, robot_name: str) -> bool: + 
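+        # Names are 3-63 characters of letters, digits, underscores, or hyphens.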
return len(robot_name) >= 3 and len(robot_name) < 64 and re.match(r"^[a-zA-Z0-9_-]+$", robot_name) is not None + + def _is_valid_description(self, description: str | None) -> bool: + return description is None or len(description) < 2048 + + async def add_robot( + self, + robot_name: str, + user_id: str, + class_id: str, + description: str | None = None, + ) -> Robot: + """Adds a robot to the database. + + Args: + robot_name: The name of the robot. + user_id: The ID of the user who owns the robot. + class_id: The ID of the robot class that the robot belongs to. + description: The description of the robot. + + Returns: + The robot that was added. + """ + if not self._is_valid_name(robot_name): + raise InvalidNameError(f"Invalid robot name: {robot_name}") + if not self._is_valid_description(description): + raise InvalidNameError("Invalid robot description") + + robot_id = new_uuid() + + robot = Robot( + id=robot_id, + robot_name=robot_name, + description="Empty description" if description is None else description, + user_id=user_id, + class_id=class_id, + ) + + # Check if the robot already exists. + existing_robot = await self.get_robot_by_name(robot_name) + if existing_robot is not None: + raise ValueError(f"Robot with name '{robot_name}' already exists") + + table = await self.table + try: + await table.put_item( + Item=robot.model_dump(), + ConditionExpression="attribute_not_exists(robot_name)", + ) + except table.meta.client.exceptions.ConditionalCheckFailedException: + raise ValueError(f"Robot with name '{robot_name}' already exists") + + return robot + + async def update_robot( + self, + robot: Robot, + new_robot_name: str | None = None, + new_description: str | None = None, + ) -> Robot: + """Updates a robot in the database. + + Args: + robot: The robot to update. + new_robot_name: The new name of the robot. + new_description: The new description of the robot. + + Returns: + The robot that was updated. + """ + if new_robot_name is not None and not self._is_valid_name(new_robot_name): + raise InvalidNameError(f"Invalid robot name: {new_robot_name}") + if new_description is not None and not self._is_valid_description(new_description): + raise InvalidNameError("Invalid robot description") + + table = await self.table + + # Populates values. 
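+        # A rename is only applied when the new name differs and is not already taken.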
+ old_robot_name = robot.robot_name + if new_robot_name is not None: + if old_robot_name != new_robot_name: + if (await self.get_robot_by_name(new_robot_name)) is not None: + raise ValueError(f"Robot with name '{new_robot_name}' already exists") + robot.robot_name = new_robot_name + + if new_description is not None: + robot.description = new_description + + try: + await table.update_item( + Key={"id": robot.id}, + UpdateExpression="SET robot_name = :new_robot_name, description = :new_description", + ConditionExpression="attribute_not_exists(robot_name) OR robot_name = :old_robot_name", + ExpressionAttributeValues={ + ":new_robot_name": robot.robot_name, + ":new_description": robot.description, + ":old_robot_name": old_robot_name, + }, + ) + except table.meta.client.exceptions.ConditionalCheckFailedException: + raise ValueError(f"Robot with name '{new_robot_name}' already exists") + + return robot + + async def delete_robot(self, robot: Robot) -> None: + """Deletes a robot from the database.""" + table = await self.table + await table.delete_item(Key={"id": robot.id}) + + async def get_robot_by_name(self, robot_name: str) -> Robot | None: + """Gets a robot by name.""" + table = await self.table + response = await table.query( + IndexName=self.get_gsi_index_name("robot_name"), + KeyConditionExpression=Key("robot_name").eq(robot_name), + ) + if (items := response.get("Items", [])) == []: + return None + if len(items) > 1: + raise ValueError(f"Multiple robots with name '{robot_name}' found") + return Robot.model_validate(items[0]) + + async def get_robot_by_id(self, id: str) -> Robot | None: + """Gets a robot by ID.""" + table = await self.table + response = await table.get_item(Key={"id": id}) + if (item := response.get("Item")) is None: + return None + return Robot.model_validate(item) + + async def list_robots(self, user_id: str | None = None) -> list[Robot]: + """Gets all robots.""" + table = await self.table + if user_id is not None: + response = await table.query( + IndexName=self.get_gsi_index_name("user_id"), + KeyConditionExpression=Key("user_id").eq(user_id), + ) + else: + response = await table.scan() + return [Robot.model_validate(item) for item in response.get("Items", [])] + + +robot_crud = RobotCrud() + +if __name__ == "__main__": + # python -m www.crud.robot + asyncio.run(robot_crud.create_table()) diff --git a/www/crud/robot_class.py b/www/crud/robot_class.py new file mode 100644 index 00000000..aaa08852 --- /dev/null +++ b/www/crud/robot_class.py @@ -0,0 +1,182 @@ +"""Defines the CRUD operations for the robot-class table.""" + +import asyncio +import re + +from boto3.dynamodb.conditions import Key +from pydantic import BaseModel + +from www.crud.base.db import DBCrud, TableKey +from www.errors import InvalidNameError +from www.utils.db import new_uuid + + +class RobotClass(BaseModel): + """Defines the data structure for a robot class.""" + + id: str + class_name: str + description: str + user_id: str + + +class RobotClassCrud(DBCrud): + """Defines the table holding information about classes of robots.""" + + def _get_table_name(self) -> str: + return "robot-class" + + @classmethod + def get_keys(cls) -> list[TableKey]: + return [("id", "S", "HASH")] + + @classmethod + def get_gsis(cls) -> set[str]: + return {"class_name", "user_id"} + + def _is_valid_name(self, class_name: str) -> bool: + return len(class_name) >= 3 and len(class_name) < 64 and re.match(r"^[a-zA-Z0-9_-]+$", class_name) is not None + + def _is_valid_description(self, description: str | None) -> bool: + return 
description is None or len(description) < 2048 + + async def add_robot_class( + self, + class_name: str, + user_id: str, + description: str | None = None, + ) -> RobotClass: + """Adds a robot class to the database. + + Args: + class_name: The unique robot class name. + user_id: The ID of the user who owns the robot class. + description: The description of the robot class. + + Returns: + The robot class that was added. + """ + if not self._is_valid_name(class_name): + raise InvalidNameError(f"Invalid robot class name: {class_name}") + if not self._is_valid_description(description): + raise InvalidNameError("Invalid robot class description") + + robot_class_id = new_uuid() + + robot_class = RobotClass( + id=robot_class_id, + class_name=class_name, + description="Empty description" if description is None else description, + user_id=user_id, + ) + + # Check if the robot class already exists. + existing_robot_class = await self.get_robot_class_by_name(class_name) + if existing_robot_class is not None: + raise ValueError(f"Robot class with name '{class_name}' already exists") + + table = await self.table + try: + await table.put_item( + Item=robot_class.model_dump(), + ConditionExpression="attribute_not_exists(class_name)", + ) + except table.meta.client.exceptions.ConditionalCheckFailedException: + raise ValueError(f"Robot class with name '{class_name}' already exists") + + return robot_class + + async def update_robot_class( + self, + robot_class: RobotClass, + new_class_name: str | None = None, + new_description: str | None = None, + ) -> RobotClass: + """Updates a robot class in the database. + + Args: + robot_class: The robot class to update. + new_class_name: The new name of the robot class. + new_description: The new description of the robot class. + + Returns: + The robot class that was updated. + """ + if new_class_name is not None and not self._is_valid_name(new_class_name): + raise InvalidNameError(f"Invalid robot class name: {new_class_name}") + if new_description is not None and not self._is_valid_description(new_description): + raise InvalidNameError("Invalid robot class description") + + table = await self.table + + # Populates values. 
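+        # The ConditionExpression in the update below guards against a concurrent rename of this record.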
+ old_class_name = robot_class.class_name + if new_class_name is not None: + if old_class_name != new_class_name: + if (await self.get_robot_class_by_name(new_class_name)) is not None: + raise ValueError(f"Robot class with name '{new_class_name}' already exists") + robot_class.class_name = new_class_name + + if new_description is not None: + robot_class.description = new_description + + try: + await table.update_item( + Key={"id": robot_class.id}, + UpdateExpression=("SET class_name = :new_class_name, description = :new_description"), + ConditionExpression="attribute_not_exists(class_name) OR class_name = :old_class_name", + ExpressionAttributeValues={ + ":new_class_name": robot_class.class_name, + ":new_description": robot_class.description, + ":old_class_name": old_class_name, + }, + ) + except table.meta.client.exceptions.ConditionalCheckFailedException: + raise ValueError(f"Robot class with name '{new_class_name}' already exists") + + return robot_class + + async def delete_robot_class(self, robot_class: RobotClass) -> None: + """Deletes a robot class from the database.""" + table = await self.table + await table.delete_item(Key={"id": robot_class.id}) + + async def get_robot_class_by_name(self, class_name: str) -> RobotClass | None: + """Gets a robot class by name.""" + table = await self.table + response = await table.query( + IndexName=self.get_gsi_index_name("class_name"), + KeyConditionExpression=Key("class_name").eq(class_name), + ) + if (items := response.get("Items", [])) == []: + return None + if len(items) > 1: + raise ValueError(f"Multiple robot classes with name '{class_name}' found") + return RobotClass.model_validate(items[0]) + + async def get_robot_class_by_id(self, id: str) -> RobotClass | None: + """Gets a robot class by ID.""" + table = await self.table + response = await table.get_item(Key={"id": id}) + if (item := response.get("Item")) is None: + return None + return RobotClass.model_validate(item) + + async def list_robot_classes(self, user_id: str | None = None) -> list[RobotClass]: + """Gets all robot classes.""" + table = await self.table + if user_id is not None: + response = await table.query( + IndexName=self.get_gsi_index_name("user_id"), + KeyConditionExpression=Key("user_id").eq(user_id), + ) + else: + response = await table.scan() + return [RobotClass.model_validate(item) for item in response.get("Items", [])] + + +robot_class_crud = RobotClassCrud() + +if __name__ == "__main__": + # python -m www.crud.robot_class + asyncio.run(robot_class_crud.create_table()) diff --git a/www/errors.py b/www/errors.py index 5d4a7d6d..75100d7d 100644 --- a/www/errors.py +++ b/www/errors.py @@ -1,16 +1,65 @@ """Defines common errors used by the application.""" +from fastapi import FastAPI, Request, status +from fastapi.responses import JSONResponse -class NotAuthenticatedError(Exception): ... +from www.settings import env -class NotAuthorizedError(Exception): ... +class ItemNotFoundError(ValueError): ... -class ItemNotFoundError(ValueError): ... +class ActionNotAllowedError(ValueError): ... + + +class InvalidNameError(ValueError): ... + + +def add_exception_handlers(app: FastAPI) -> None: + """Adds the handlers to the FastAPI app.""" + show_full_error = env.site.is_test_environment + + def protected_str(exc: Exception) -> str: + if show_full_error: + return str(exc) + return "The request was invalid." 
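+    # The handlers below map application exceptions to JSON error responses; for generic
+    # ValueError and RuntimeError, the raw message is only exposed in test environments.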
+ + async def value_error_exception_handler(request: Request, exc: Exception) -> JSONResponse: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"message": "The request was invalid.", "detail": protected_str(exc)}, + ) + + app.add_exception_handler(ValueError, value_error_exception_handler) + + async def runtime_error_exception_handler(request: Request, exc: Exception) -> JSONResponse: + return JSONResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + content={"message": "An internal error occurred.", "detail": protected_str(exc)}, + ) + + app.add_exception_handler(RuntimeError, runtime_error_exception_handler) + + async def item_not_found_exception_handler(request: Request, exc: Exception) -> JSONResponse: + return JSONResponse( + status_code=status.HTTP_404_NOT_FOUND, + content={"message": "Item not found.", "detail": str(exc)}, + ) + + app.add_exception_handler(ItemNotFoundError, item_not_found_exception_handler) + async def action_not_allowed_exception_handler(request: Request, exc: Exception) -> JSONResponse: + return JSONResponse( + status_code=status.HTTP_403_FORBIDDEN, + content={"message": "Action not allowed.", "detail": str(exc)}, + ) -class InternalError(RuntimeError): ... + app.add_exception_handler(ActionNotAllowedError, action_not_allowed_exception_handler) + async def invalid_name_exception_handler(request: Request, exc: Exception) -> JSONResponse: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"message": "Invalid name.", "detail": str(exc)}, + ) -class BadArtifactError(Exception): ... + app.add_exception_handler(InvalidNameError, invalid_name_exception_handler) diff --git a/www/main.py b/www/main.py index 2c43a910..931c1504 100644 --- a/www/main.py +++ b/www/main.py @@ -1,21 +1,12 @@ """Defines the main entrypoint for the FastAPI app.""" import uvicorn -from fastapi import FastAPI, Request, status -from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import JSONResponse -from starlette.middleware.sessions import SessionMiddleware +from fastapi import FastAPI from www.auth import COGNITO_CLIENT_ID -from www.errors import ( - BadArtifactError, - InternalError, - ItemNotFoundError, - NotAuthenticatedError, - NotAuthorizedError, -) -from www.routers.auth import router as auth_router -from www.settings import env +from www.errors import add_exception_handlers +from www.middleware import add_middleware +from www.routers import add_routers app = FastAPI( title="K-Scale", @@ -30,72 +21,9 @@ }, ) -# Adds CORS middleware. -app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - -# Add authentication middleware. 
-app.add_middleware( - SessionMiddleware, - secret_key=env.middleware.secret_key, - max_age=24 * 60 * 60, # 1 day -) - - -@app.exception_handler(ValueError) -async def value_error_exception_handler(request: Request, exc: ValueError) -> JSONResponse: - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"message": "The request was invalid.", "detail": str(exc)}, - ) - - -@app.exception_handler(ItemNotFoundError) -async def item_not_found_exception_handler(request: Request, exc: ItemNotFoundError) -> JSONResponse: - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"message": "Item not found.", "detail": str(exc)}, - ) - - -@app.exception_handler(InternalError) -async def internal_error_exception_handler(request: Request, exc: InternalError) -> JSONResponse: - return JSONResponse( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - content={"message": "Internal error.", "detail": str(exc)}, - ) - - -@app.exception_handler(NotAuthenticatedError) -async def not_authenticated_exception_handler(request: Request, exc: NotAuthenticatedError) -> JSONResponse: - return JSONResponse( - status_code=status.HTTP_401_UNAUTHORIZED, - content={"message": "Not authenticated.", "detail": str(exc)}, - ) - - -@app.exception_handler(NotAuthorizedError) -async def not_authorized_exception_handler(request: Request, exc: NotAuthorizedError) -> JSONResponse: - return JSONResponse( - status_code=status.HTTP_403_FORBIDDEN, - content={"message": "Not authorized.", "detail": str(exc)}, - ) - - -@app.exception_handler(BadArtifactError) -async def bad_artifact_exception_handler(request: Request, exc: BadArtifactError) -> JSONResponse: - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"message": f"Bad artifact: {exc}", "detail": str(exc)}, - ) - - -app.include_router(auth_router, prefix="/auth", tags=["auth"]) +add_middleware(app) +add_exception_handlers(app) +add_routers(app) # For running with debugger diff --git a/www/middleware.py b/www/middleware.py new file mode 100644 index 00000000..b030fe83 --- /dev/null +++ b/www/middleware.py @@ -0,0 +1,26 @@ +"""Defines the middleware for the FastAPI app.""" + +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +from starlette.middleware.sessions import SessionMiddleware + +from www.settings import env + + +def add_middleware(app: FastAPI) -> None: + """Adds the middleware to the FastAPI app.""" + # Adds CORS middleware. + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + # Add authentication middleware. + app.add_middleware( + SessionMiddleware, + secret_key=env.middleware.secret_key, + max_age=24 * 60 * 60, # 1 day + ) diff --git a/www/model.py b/www/model.py deleted file mode 100644 index d6b7927d..00000000 --- a/www/model.py +++ /dev/null @@ -1,445 +0,0 @@ -"""Defines the table models for the API. - -These correspond directly with the rows in our database, and provide helper -methods for converting from our input data into the format the database -expects (for example, converting a UUID into a string). -""" - -import time -from typing import Literal, Self, cast, get_args - -from pydantic import BaseModel - -from www.auth import User -from www.errors import InternalError -from www.settings import env -from www.utils.db import new_uuid - - -class StoreBaseModel(BaseModel): - """Defines the base model for store database rows. 
- - Our database architecture uses a single table with a single primary key - (the `id` field). This class provides a common interface for all models - that are stored in the database. - """ - - id: str - - -ArtifactSize = Literal["small", "large"] - -ImageArtifactType = Literal["image"] -XMLArtifactType = Literal["urdf", "mjcf"] -MeshArtifactType = Literal["stl", "obj", "dae", "ply"] -CompressedArtifactType = Literal["tgz", "zip"] -KernelArtifactType = Literal["kernel"] -ArtifactType = ImageArtifactType | KernelArtifactType | XMLArtifactType | MeshArtifactType | CompressedArtifactType - -UPLOAD_CONTENT_TYPE_OPTIONS: dict[ArtifactType, set[str]] = { - # Image - "image": {"image/png", "image/jpeg", "image/jpg", "image/gif", "image/webp"}, - # XML - "urdf": {"application/octet-stream", "text/xml", "application/xml"}, - "mjcf": {"application/octet-stream", "text/xml", "application/xml"}, - # Meshes - "stl": {"application/octet-stream", "text/plain"}, - "obj": {"application/octet-stream", "text/plain"}, - "dae": {"application/octet-stream", "text/plain"}, - "ply": {"application/octet-stream", "text/plain"}, - # Compressed - "tgz": { - "application/gzip", - "application/x-gzip", - "application/x-tar", - "application/x-compressed-tar", - }, - "zip": {"application/zip"}, -} - -DOWNLOAD_CONTENT_TYPE: dict[ArtifactType, str] = { - # Image - "image": "image/png", - # XML - "urdf": "application/octet-stream", - "mjcf": "application/octet-stream", - # Binary - "stl": "application/octet-stream", - "obj": "application/octet-stream", - "dae": "application/octet-stream", - "ply": "application/octet-stream", - # Compressed - "tgz": "application/gzip", - "zip": "application/zip", - "kernel": "application/octet-stream", -} - -SizeMapping: dict[ArtifactSize, tuple[int, int]] = { - "large": env.artifact.large_image_size, - "small": env.artifact.small_image_size, -} - - -def get_artifact_type(content_type: str | None, filename: str | None) -> ArtifactType: - if filename is not None: - extension = filename.split(".")[-1].lower() - if extension == "img": - return "kernel" - if extension in ("png", "jpeg", "jpg", "gif", "webp"): - return "image" - if extension in ("urdf",): - return "urdf" - if extension in ("mjcf", "xml"): - return "mjcf" - if extension in ("stl",): - return "stl" - if extension in ("obj",): - return "obj" - if extension in ("dae",): - return "dae" - if extension in ("ply",): - return "ply" - if extension in ("tgz", "tar.gz"): - return "tgz" - if extension in ("zip",): - return "zip" - - # Attempts to determine from content type. 
- if content_type is not None: - if content_type in UPLOAD_CONTENT_TYPE_OPTIONS["kernel"]: - return "kernel" - if content_type in UPLOAD_CONTENT_TYPE_OPTIONS["image"]: - return "image" - if content_type in UPLOAD_CONTENT_TYPE_OPTIONS["urdf"]: - return "urdf" - if content_type in UPLOAD_CONTENT_TYPE_OPTIONS["mjcf"]: - return "mjcf" - if content_type in UPLOAD_CONTENT_TYPE_OPTIONS["stl"]: - return "stl" - if content_type in UPLOAD_CONTENT_TYPE_OPTIONS["obj"]: - return "obj" - if content_type in UPLOAD_CONTENT_TYPE_OPTIONS["dae"]: - return "dae" - if content_type in UPLOAD_CONTENT_TYPE_OPTIONS["ply"]: - return "ply" - if content_type in UPLOAD_CONTENT_TYPE_OPTIONS["tgz"]: - return "tgz" - if content_type in UPLOAD_CONTENT_TYPE_OPTIONS["zip"]: - return "zip" - - raise ValueError(f"Unknown content type for file: {filename}") - - -def get_compression_type(content_type: str | None, filename: str | None) -> CompressedArtifactType: - if filename is None: - raise ValueError("Filename must be provided") - - artifact_type = get_artifact_type(content_type, filename) - if artifact_type not in (allowed_types := get_args(CompressedArtifactType)): - raise ValueError(f"Artifact type {artifact_type} is not compressed; expected one of {allowed_types}") - return cast(CompressedArtifactType, artifact_type) - - -def check_content_type(content_type: str | None, artifact_type: ArtifactType) -> None: - """Checks that the content type is valid for the artifact type. - - Args: - content_type: The content type of the artifact. - artifact_type: The type of the artifact. - - Raises: - ValueError: If the content type is not valid for the artifact type. - """ - if content_type is None: - raise ValueError("Artifact content type was not provided") - if content_type not in UPLOAD_CONTENT_TYPE_OPTIONS[artifact_type]: - content_type_options_string = ", ".join(UPLOAD_CONTENT_TYPE_OPTIONS[artifact_type]) - raise ValueError(f"Invalid content type for artifact; {content_type} not in [{content_type_options_string}]") - - -def get_content_type(artifact_type: ArtifactType) -> str: - return DOWNLOAD_CONTENT_TYPE[artifact_type] - - -class Artifact(StoreBaseModel): - """Defines an artifact that some user owns, like an image or uploaded file. - - Artifacts are stored in S3 and are accessible through CloudFront. - - Artifacts are associated to a given user and can come in different sizes; - for example, the same image may have multiple possible sizes available. - """ - - user_id: str - listing_id: str - name: str - artifact_type: ArtifactType - sizes: list[ArtifactSize] | None = None - description: str | None = None - timestamp: int - children: list[str] | None = None - is_main: bool = False - - @classmethod - def create( - cls, - user_id: str, - listing_id: str, - name: str, - artifact_type: ArtifactType, - sizes: list[ArtifactSize] | None = None, - description: str | None = None, - children: list[str] | None = None, - is_main: bool = False, - ) -> Self: - return cls( - id=new_uuid(), - user_id=user_id, - listing_id=listing_id, - name=name, - artifact_type=artifact_type, - sizes=sizes, - description=description, - timestamp=int(time.time()), - children=children, - is_main=is_main, - ) - - -class Listing(StoreBaseModel): - """Defines a recursively-defined listing. - - Listings can have sub-listings with their component parts. They can also - have associated user-uploaded artifacts like images and URDFs. 
- """ - - user_id: str - created_at: int - updated_at: int - name: str - slug: str - child_ids: list[str] - description: str | None = None - onshape_url: str | None = None - views: int = 0 - score: int = 0 - - @classmethod - def create( - cls, - user_id: str, - name: str, - slug: str, - child_ids: list[str], - description: str | None = None, - onshape_url: str | None = None, - ) -> Self: - return cls( - id=new_uuid(), - user_id=user_id, - created_at=int(time.time()), - updated_at=int(time.time()), - name=name, - slug=slug, - child_ids=child_ids, - description=description, - onshape_url=onshape_url, - views=0, - score=0, - ) - - -class ListingTag(StoreBaseModel): - """Marks a listing as having a given tag. - - This is useful for tagging listings with metadata, like "robot", "gripper", - or "actuator". Tags are used to categorize listings and make them easier to - search for. - """ - - listing_id: str - name: str - - @classmethod - def create(cls, listing_id: str, tag: str) -> Self: - return cls( - id=new_uuid(), - listing_id=listing_id, - name=tag, - ) - - -def get_artifact_name( - *, - artifact: Artifact | None = None, - artifact_id: str | None = None, - listing_id: str | None = None, - name: str | None = None, - artifact_type: ArtifactType | None = None, - size: ArtifactSize = "large", -) -> str: - if artifact: - listing_id = artifact.listing_id - name = artifact.name - artifact_type = artifact.artifact_type - artifact_id = artifact.id - elif not listing_id or not name or not artifact_type or not artifact_id: - raise InternalError("Must provide artifact or listing_id, name, and artifact_type") - - match artifact_type: - case "image": - height, width = SizeMapping[size] - return f"{listing_id}/{artifact_id}/{size}_{height}x{width}_{name}" - case "kernel" | "urdf" | "mjcf" | "stl" | "obj" | "ply" | "dae" | "zip" | "tgz": - return f"{listing_id}/{artifact_id}/{name}" - case _: - raise ValueError(f"Unknown artifact type: {artifact_type}") - - -def get_artifact_url( - *, - artifact: Artifact | None = None, - artifact_id: str | None = None, - artifact_type: ArtifactType | None = None, - listing_id: str | None = None, - name: str | None = None, - size: ArtifactSize = "large", -) -> str: - artifact_name = get_artifact_name( - artifact=artifact, - artifact_id=artifact_id, - listing_id=listing_id, - name=name, - artifact_type=artifact_type, - size=size, - ) - return f"{env.site.artifact_base_url}{artifact_name}" - - -def get_artifact_urls( - artifact: Artifact | None = None, - artifact_type: ArtifactType | None = None, - listing_id: str | None = None, - name: str | None = None, -) -> dict[ArtifactSize, str]: - return { - size: get_artifact_url( - artifact=artifact, - artifact_type=artifact_type, - listing_id=listing_id, - name=name, - size=size, - ) - for size in SizeMapping.keys() - } - - -async def can_write_artifact(user: User, artifact: Artifact) -> bool: - if user.is_admin: - return True - if user.id == artifact.user_id: - return True - return False - - -async def can_write_listing(user: User, listing: Listing) -> bool: - if user.is_admin: - return True - if user.id == listing.user_id: - return True - return False - - -async def can_read_artifact(user: User, artifact: Artifact) -> bool: - # For now, all users can read all artifacts. In the future we might change - # this so that users can hide their artifacts. - return True - - -async def can_read_listing(user: User, listing: Listing) -> bool: - # For now, all users can read all listings. 
In the future we might change - # this so that users can hide their listings. - return True - - -class ListingVote(StoreBaseModel): - """Tracks user votes on listings.""" - - user_id: str - listing_id: str - is_upvote: bool - created_at: int - - @classmethod - def create(cls, user_id: str, listing_id: str, is_upvote: bool) -> Self: - return cls( - id=new_uuid(), - user_id=user_id, - listing_id=listing_id, - is_upvote=is_upvote, - created_at=int(time.time()), - ) - - -class Robot(StoreBaseModel): - """User registered robots. Associated with a robot listing. - - Will eventually used for teleop and data collection/aggregation. - """ - - user_id: str - listing_id: str - name: str - description: str | None = None - created_at: int - updated_at: int - order_id: str | None = None - - @classmethod - def create( - cls, - user_id: str, - listing_id: str, - name: str, - description: str | None = None, - order_id: str | None = None, - ) -> Self: - now = int(time.time()) - return cls( - id=new_uuid(), - user_id=user_id, - listing_id=listing_id, - name=name, - description=description, - created_at=now, - updated_at=now, - order_id=order_id, - ) - - -class KRec(StoreBaseModel): - """Krec recorded from robot runtime.""" - - user_id: str - robot_id: str - created_at: int - name: str - description: str | None = None - - @classmethod - def create( - cls, - user_id: str, - robot_id: str, - name: str, - description: str | None = None, - ) -> Self: - now = int(time.time()) - return cls( - id=new_uuid(), - user_id=user_id, - robot_id=robot_id, - created_at=now, - name=name, - description=description, - ) diff --git a/www/routers/__init__.py b/www/routers/__init__.py index e69de29b..971db304 100644 --- a/www/routers/__init__.py +++ b/www/routers/__init__.py @@ -0,0 +1,18 @@ +"""Defines the routers for the FastAPI app.""" + +from fastapi import Depends, FastAPI + +from www.auth import require_user + +from .auth import router as auth_router +from .robot import router as robot_router +from .robot_class import router as robot_class_router + + +def add_routers(app: FastAPI) -> None: + """Adds the routers to the FastAPI app.""" + app.include_router(auth_router, prefix="/auth", tags=["auth"]) + + # Mark the non-auth routers as protected. 
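+    # require_user runs for every /robot and /robots endpoint, so handlers can assume an authenticated caller.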
+ app.include_router(robot_router, prefix="/robot", tags=["robot"], dependencies=[Depends(require_user)]) + app.include_router(robot_class_router, prefix="/robots", tags=["robots"], dependencies=[Depends(require_user)]) diff --git a/www/routers/auth.py b/www/routers/auth.py index 33a0ac74..c846cacb 100644 --- a/www/routers/auth.py +++ b/www/routers/auth.py @@ -56,7 +56,7 @@ async def profile( @router.get("/logout") -async def logout(request: Request, user: Annotated[UserResponse, Depends(require_user)]) -> bool: +async def logout(request: Request, user: Annotated[User, Depends(require_user)]) -> bool: request.session.clear() return True diff --git a/www/routers/robot.py b/www/routers/robot.py new file mode 100644 index 00000000..0e1027f6 --- /dev/null +++ b/www/routers/robot.py @@ -0,0 +1,153 @@ +"""Defines the API endpoint for managing robots.""" + +from typing import Annotated, Self + +from fastapi import APIRouter, Depends, Query +from pydantic import BaseModel + +from www.auth import User, require_permissions, require_user +from www.crud.robot import Robot, RobotCrud, robot_crud +from www.crud.robot_class import RobotClass, RobotClassCrud, robot_class_crud +from www.errors import ActionNotAllowedError, ItemNotFoundError +from www.routers.robot_class import get_robot_class_by_name + +router = APIRouter() + + +class RobotResponse(BaseModel): + id: str + robot_name: str + description: str + user_id: str + class_name: str + + @classmethod + def from_robot(cls, robot: Robot, robot_class: RobotClass) -> Self: + return cls( + id=robot.id, + robot_name=robot.robot_name, + description=robot.description, + user_id=robot.user_id, + class_name=robot_class.class_name, + ) + + +@router.get("/") +async def get_robots( + crud: RobotCrud = Depends(robot_crud), +) -> list[Robot]: + return await crud.list_robots() + + +async def _get_robot_and_class_by_id( + id: str, + crud: Annotated[RobotCrud, Depends(robot_crud)], + cls_crud: Annotated[RobotClassCrud, Depends(robot_class_crud)], +) -> tuple[Robot, RobotClass]: + robot = await crud.get_robot_by_id(id) + if robot is None: + raise ItemNotFoundError(f"Robot '{id}' not found") + robot_class = await cls_crud.get_robot_class_by_id(robot.class_id) + if robot_class is None: + raise ItemNotFoundError(f"Robot class '{robot.class_id}' not found") + return robot, robot_class + + +async def _get_base_robot_by_name( + robot_name: str, + crud: Annotated[RobotCrud, Depends(robot_crud)], +) -> Robot: + robot = await crud.get_robot_by_name(robot_name) + if robot is None: + raise ItemNotFoundError(f"Robot '{robot_name}' not found") + return robot + + +async def _get_robot_and_class_by_name( + robot_name: str, + crud: Annotated[RobotCrud, Depends(robot_crud)], + cls_crud: Annotated[RobotClassCrud, Depends(robot_class_crud)], +) -> tuple[Robot, RobotClass]: + robot = await crud.get_robot_by_name(robot_name) + if robot is None: + raise ItemNotFoundError(f"Robot '{robot_name}' not found") + robot_class = await cls_crud.get_robot_class_by_id(robot.class_id) + if robot_class is None: + raise ItemNotFoundError(f"Robot class '{robot.class_id}' not found") + return robot, robot_class + + +@router.get("/name/{robot_name}") +async def get_robot_by_name( + robot_name: str, + crud: Annotated[RobotCrud, Depends(robot_crud)], + cls_crud: Annotated[RobotClassCrud, Depends(robot_class_crud)], +) -> RobotResponse: + robot, robot_class = await _get_robot_and_class_by_name(robot_name, crud, cls_crud) + return RobotResponse.from_robot(robot, robot_class) + + +@router.get("/id/{id}") +async 
def get_robot_by_id( + id: str, + crud: Annotated[RobotCrud, Depends(robot_crud)], + cls_crud: Annotated[RobotClassCrud, Depends(robot_class_crud)], +) -> RobotResponse: + robot, robot_class = await _get_robot_and_class_by_id(id, crud, cls_crud) + return RobotResponse.from_robot(robot, robot_class) + + +@router.get("/user/{user_id}") +async def get_robots_for_user( + user_id: str, + user: Annotated[User, Depends(require_user)], + crud: Annotated[RobotCrud, Depends(robot_crud)], +) -> list[Robot]: + if user_id.lower() == "me": + return await crud.list_robots(user.id) + else: + return await crud.list_robots(user_id) + + +@router.put("/add") +async def add_robot( + robot_name: str, + user: Annotated[User, Depends(require_permissions({"upload"}))], + robot_class: Annotated[RobotClass, Depends(get_robot_class_by_name)], + crud: Annotated[RobotCrud, Depends(robot_crud)], +) -> RobotResponse: + robot = await crud.add_robot(robot_name, user.id, robot_class.id) + return RobotResponse.from_robot(robot, robot_class) + + +@router.put("/update") +async def update_robot( + user: Annotated[User, Depends(require_permissions({"upload"}))], + existing_robot_tuple: Annotated[tuple[Robot, RobotClass], Depends(_get_robot_and_class_by_name)], + crud: Annotated[RobotCrud, Depends(robot_crud)], + new_robot_name: str | None = Query( + default=None, + description="The new name of the robot", + ), + new_description: str | None = Query( + default=None, + description="The new description of the robot", + ), +) -> RobotResponse: + existing_robot, existing_robot_class = existing_robot_tuple + if existing_robot.user_id != user.id: + raise ActionNotAllowedError("You are not the owner of this robot") + robot = await crud.update_robot(existing_robot, new_robot_name, new_description) + return RobotResponse.from_robot(robot, existing_robot_class) + + +@router.delete("/delete") +async def delete_robot( + user: Annotated[User, Depends(require_user)], + robot: Annotated[Robot, Depends(_get_base_robot_by_name)], + crud: Annotated[RobotCrud, Depends(robot_crud)], +) -> bool: + if robot.user_id != user.id: + raise ActionNotAllowedError("You are not the owner of this robot") + await crud.delete_robot(robot) + return True diff --git a/www/routers/robot_class.py b/www/routers/robot_class.py new file mode 100644 index 00000000..c358b17d --- /dev/null +++ b/www/routers/robot_class.py @@ -0,0 +1,165 @@ +"""Defines the API endpoint for managing robot classes.""" + +import asyncio +from typing import Annotated + +from fastapi import APIRouter, Depends, Query +from pydantic import BaseModel + +from www.auth import User, require_permissions, require_user +from www.crud.base.s3 import S3Crud, s3_crud +from www.crud.robot_class import RobotClass, RobotClassCrud, robot_class_crud +from www.errors import ActionNotAllowedError, ItemNotFoundError + +router = APIRouter() + + +def urdf_s3_key(robot_class: RobotClass) -> str: + return f"urdfs/{robot_class.id}/robot.urdf" + + +@router.get("/") +async def get_robot_classes( + crud: Annotated[RobotClassCrud, Depends(robot_class_crud)], +) -> list[RobotClass]: + """Gets all robot classes.""" + return await crud.list_robot_classes() + + +@router.get("/name/{class_name}") +async def get_robot_class_by_name( + class_name: str, + crud: Annotated[RobotClassCrud, Depends(robot_class_crud)], +) -> RobotClass: + """Gets a robot class by name.""" + robot_class = await crud.get_robot_class_by_name(class_name) + if robot_class is None: + raise ItemNotFoundError(f"Robot class '{class_name}' not found") + return 
robot_class
+
+
+@router.get("/user/{user_id}")
+async def get_robot_classes_for_user(
+    user_id: str,
+    user: Annotated[User, Depends(require_user)],
+    crud: Annotated[RobotClassCrud, Depends(robot_class_crud)],
+) -> list[RobotClass]:
+    """Gets all robot classes for a user."""
+    if user_id.lower() == "me":
+        return await crud.list_robot_classes(user.id)
+    else:
+        return await crud.list_robot_classes(user_id)
+
+
+@router.put("/add")
+async def add_robot_class(
+    class_name: str,
+    user: Annotated[User, Depends(require_permissions({"upload"}))],
+    crud: Annotated[RobotClassCrud, Depends(robot_class_crud)],
+) -> RobotClass:
+    """Adds a robot class."""
+    return await crud.add_robot_class(class_name, user.id)
+
+
+@router.put("/update")
+async def update_robot_class(
+    user: Annotated[User, Depends(require_permissions({"upload"}))],
+    existing_robot_class: Annotated[RobotClass, Depends(get_robot_class_by_name)],
+    crud: Annotated[RobotClassCrud, Depends(robot_class_crud)],
+    new_class_name: str | None = Query(
+        default=None,
+        description="The new name of the robot class",
+    ),
+    new_description: str | None = Query(
+        default=None,
+        description="The new description of the robot class",
+    ),
+) -> RobotClass:
+    """Updates a robot class."""
+    if existing_robot_class.user_id != user.id:
+        raise ActionNotAllowedError("You are not the owner of this robot class")
+
+    return await crud.update_robot_class(
+        robot_class=existing_robot_class,
+        new_class_name=new_class_name,
+        new_description=new_description,
+    )
+
+
+@router.delete("/delete")
+async def delete_robot_class(
+    user: Annotated[User, Depends(require_user)],
+    robot_class: Annotated[RobotClass, Depends(get_robot_class_by_name)],
+    crud: Annotated[RobotClassCrud, Depends(robot_class_crud)],
+) -> bool:
+    """Deletes a robot class."""
+    if robot_class.user_id != user.id:
+        raise ActionNotAllowedError("You are not the owner of this robot class")
+    s3_key = urdf_s3_key(robot_class)
+    await s3_crud.delete_from_s3(s3_key)
+    await crud.delete_robot_class(robot_class)
+    return True
+
+
+urdf_router = APIRouter()
+
+
+class RobotDownloadURDFResponse(BaseModel):
+    url: str
+    md5_hash: str
+
+
+@urdf_router.get("/{class_name}")
+async def get_urdf_for_robot(
+    robot_class: Annotated[RobotClass, Depends(get_robot_class_by_name)],
+    s3_crud: Annotated[S3Crud, Depends(s3_crud)],
+) -> RobotDownloadURDFResponse:
+    s3_key = urdf_s3_key(robot_class)
+    url, md5_hash = await asyncio.gather(
+        s3_crud.generate_presigned_download_url(s3_key),
+        s3_crud.get_file_hash(s3_key),
+    )
+    return RobotDownloadURDFResponse(url=url, md5_hash=md5_hash)
+
+
+class RobotUploadURDFResponse(BaseModel):
+    url: str
+    filename: str
+    content_type: str
+
+
+@urdf_router.put("/{class_name}")
+async def upload_urdf_for_robot(
+    filename: str,
+    content_type: str,
+    robot_class: Annotated[RobotClass, Depends(get_robot_class_by_name)],
+    s3_crud: Annotated[S3Crud, Depends(s3_crud)],
+) -> RobotUploadURDFResponse:
+    # Checks that the content type is valid.
+ if content_type not in { + "application/octet-stream", + "application/xml", + "application/gzip", + "application/x-gzip", + "application/x-tar", + "application/x-compressed-tar", + "application/zip", + }: + raise ValueError(f"Invalid content type: {content_type}") + + s3_key = urdf_s3_key(robot_class) + url = await s3_crud.generate_presigned_upload_url(filename, s3_key, content_type) + return RobotUploadURDFResponse(url=url, filename=filename, content_type=content_type) + + +@urdf_router.delete("/{class_name}") +async def delete_urdf_for_robot( + robot_class: Annotated[RobotClass, Depends(get_robot_class_by_name)], + s3_crud: Annotated[S3Crud, Depends(s3_crud)], +) -> bool: + s3_key = urdf_s3_key(robot_class) + await s3_crud.delete_from_s3(s3_key) + return True + + +router.include_router(urdf_router, prefix="/urdf") diff --git a/www/settings/configs/local.yaml b/www/settings/configs/local.yaml index 3a5abd77..c4979c80 100644 --- a/www/settings/configs/local.yaml +++ b/www/settings/configs/local.yaml @@ -4,8 +4,6 @@ aws: s3: bucket: artifacts prefix: media/ - dynamodb: - table_suffix: local cloudfront: domain: ${site.artifact_base_url} site: diff --git a/www/settings/environment.py b/www/settings/environment.py index 717c8f50..98b05d8e 100644 --- a/www/settings/environment.py +++ b/www/settings/environment.py @@ -44,15 +44,15 @@ class ArtifactSettings: @dataclass -class S3Settings: - bucket: str = field(default=SI("www-${environment}")) - prefix: str = field(default="media") +class DynamoSettings: + table_prefix: str = field(default=SI("www-${environment}")) + deletion_protection: bool = field(default=False) @dataclass -class DynamoSettings: - table_suffix: str = field(default=SI("www-${environment}")) - deletion_protection: bool = field(default=False) +class S3Settings: + bucket: str = field(default=SI("kscale-www-${environment}")) + prefix: str = field(default="") @dataclass @@ -64,8 +64,8 @@ class CloudFrontSettings: @dataclass class AwsSettings: - s3: S3Settings = field(default_factory=S3Settings) dynamodb: DynamoSettings = field(default_factory=DynamoSettings) + s3: S3Settings = field(default_factory=S3Settings) cloudfront: CloudFrontSettings = field(default_factory=CloudFrontSettings) diff --git a/www/utils/db.py b/www/utils/db.py index a7317669..90f86dba 100644 --- a/www/utils/db.py +++ b/www/utils/db.py @@ -2,7 +2,9 @@ import datetime import hashlib +import re import uuid +from dataclasses import dataclass def server_time() -> datetime.datetime: @@ -17,3 +19,35 @@ def new_uuid() -> str: SHA-256 hash of a UUID4 value. 
""" return hashlib.sha256(str(uuid.uuid4()).encode()).hexdigest()[:16] + + +@dataclass +class VersionNumber: + major: int + minor: int + patch: int + + def __str__(self) -> str: + return f"{self.major}.{self.minor}.{self.patch}" + + def __lt__(self, other: "VersionNumber") -> bool: + return (self.major, self.minor, self.patch) < (other.major, other.minor, other.patch) + + def __le__(self, other: "VersionNumber") -> bool: + return (self.major, self.minor, self.patch) <= (other.major, other.minor, other.patch) + + def __gt__(self, other: "VersionNumber") -> bool: + return (self.major, self.minor, self.patch) > (other.major, other.minor, other.patch) + + def __ge__(self, other: "VersionNumber") -> bool: + return (self.major, self.minor, self.patch) >= (other.major, other.minor, other.patch) + + def __repr__(self) -> str: + return f"VersionNumber(major={self.major}, minor={self.minor}, patch={self.patch})" + + @classmethod + def from_str(cls, version: str) -> "VersionNumber": + match = re.match(r"(\d+)\.(\d+)\.(\d+)", version) + if match is None: + raise ValueError(f"Invalid version number: {version}") + return cls(major=int(match.group(1)), minor=int(match.group(2)), patch=int(match.group(3)))