Skip to content

Commit

Permalink
mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
lloesche committed Sep 19, 2024
1 parent 859ef6a commit 0ba4460
Show file tree
Hide file tree
Showing 7 changed files with 101 additions and 67 deletions.
1 change: 0 additions & 1 deletion fixattiosync/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from .fixdata import FixData, add_args as fixdata_add_args
from .attiodata import AttioData, add_args as attio_add_args
from .sync import sync_fix_to_attio
from pprint import pprint


def main() -> None:
Expand Down
53 changes: 29 additions & 24 deletions fixattiosync/attiodata.py
Original file line number Diff line number Diff line change
@@ -1,62 +1,67 @@
import os
import requests
from uuid import UUID
from typing import Union
from typing import Union, Any, Optional
from argparse import ArgumentParser
from .logger import log
from .attioresources import AttioWorkspace, AttioPerson, AttioUser


class AttioData:
def __init__(self, api_key, default_limit=500):
def __init__(self, api_key: str, default_limit: int = 500):
self.api_key = api_key
self.base_url = "https://api.attio.com/v2/"
self.default_limit = default_limit
self.hydrated = False
self.__workspaces = {}
self.__people = {}
self.__users = {}
self.__workspaces: dict[UUID, AttioWorkspace] = {}
self.__people: dict[UUID, AttioPerson] = {}
self.__users: dict[UUID, AttioUser] = {}

def _headers(self):
def _headers(self) -> dict[str, str]:
return {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
"Accept": "application/json",
}

def _post_data(self, endpoint, json=None, params=None):
def _post_data(
self, endpoint: str, json: Optional[dict[str, Any]] = None, params: Optional[dict[str, str]] = None
) -> dict[str, Any]:
log.debug(f"Fetching data from {endpoint}")
url = self.base_url + endpoint
response = requests.post(url, headers=self._headers(), json=json, params=params)
if response.status_code == 200:
return response.json()
return response.json() # type: ignore
else:
raise Exception(f"Error fetching data from {url}: {response.status_code} {response.text}")

def _put_data(self, endpoint: str, json: dict = None, params: dict = None):
def _put_data(
self, endpoint: str, json: Optional[dict[str, Any]] = None, params: Optional[dict[str, str]] = None
) -> dict[str, Any]:
log.debug(f"Putting data to {endpoint}")
url = self.base_url + endpoint
response = requests.put(url, headers=self._headers(), json=json, params=params)
if response.status_code == 200:
return response.json()
return response.json() # type: ignore
else:
raise Exception(f"Error putting data to {url}: {response.status_code} {response.text}")

def assert_record(
self, object_id: str, matching_attribute: str, data: dict
self, object_id: str, matching_attribute: str, data: dict[str, Any]
) -> Union[AttioPerson, AttioUser, AttioWorkspace]:
endpoint = f"objects/{object_id}/records"
params = {"matching_attribute": matching_attribute}
attio_cls: Union[type[AttioPerson], type[AttioUser], type[AttioWorkspace]]
match object_id:
case "users":
attio_cls = AttioUser
self_store = self.__users
case "people":
attio_cls = AttioPerson
self_store = self.__people
self_store = self.__people # type: ignore
case "workspaces":
attio_cls = AttioWorkspace
self_store = self.__workspaces
self_store = self.__workspaces # type: ignore
case _:
raise ValueError(f"Unknown object_id: {object_id}")

Expand All @@ -65,12 +70,12 @@ def assert_record(
if response.get("data", []):
attio_obj = attio_cls.make(response["data"])
log.debug(f"Asserted {object_id} {attio_obj} in Attio, updating locally")
self_store[attio_obj.record_id] = attio_obj
self_store[attio_obj.record_id] = attio_obj # type: ignore
return attio_obj
else:
raise RuntimeError(f"Error asserting {object_id} in Attio: {response}")

def _records(self, object_id: str):
def _records(self, object_id: str) -> list[dict[str, Any]]:
log.debug(f"Fetching {object_id}")
endpoint = f"objects/{object_id}/records/query"
all_data = []
Expand All @@ -90,32 +95,32 @@ def _records(self, object_id: str):
return all_data

@property
def workspaces(self):
def workspaces(self) -> list[AttioWorkspace]:
if not self.hydrated:
self.hydrate()
return list(self.__workspaces.values())

@property
def people(self):
def people(self) -> list[AttioPerson]:
if not self.hydrated:
self.hydrate()
return list(self.__people.values())

@property
def users(self):
def users(self) -> list[AttioUser]:
if not self.hydrated:
self.hydrate()
return list(self.__users.values())

def hydrate(self):
def hydrate(self) -> None:
log.debug("Hydrating Attio data")
self.__workspaces = self.__marshal(self._records("workspaces"), AttioWorkspace)
self.__people = self.__marshal(self._records("people"), AttioPerson)
self.__users = self.__marshal(self._records("users"), AttioUser)
self.__workspaces = self.__marshal(self._records("workspaces"), AttioWorkspace) # type: ignore
self.__people = self.__marshal(self._records("people"), AttioPerson) # type: ignore
self.__users = self.__marshal(self._records("users"), AttioUser) # type: ignore
self.__connect()
self.hydrated = True

def __connect(self):
def __connect(self) -> None:
for user in self.__users.values():
if user.person_id in self.__people:
person = self.__people[user.person_id]
Expand All @@ -129,7 +134,7 @@ def __connect(self):
user.workspaces.append(workspace)

def __marshal(
self, data: dict, cls: Union[AttioWorkspace, AttioPerson, AttioUser]
self, data: list[dict[str, Any]], cls: Union[type[AttioWorkspace], type[AttioPerson], type[AttioUser]]
) -> dict[UUID, Union[AttioWorkspace, AttioPerson, AttioUser]]:
ret = {}
for item in data:
Expand Down
36 changes: 20 additions & 16 deletions fixattiosync/attioresources.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
from dataclasses import dataclass, field
from datetime import datetime
from uuid import UUID
from typing import Optional, Self, Type, ClassVar
from typing import Optional, Self, Type, ClassVar, Any
from .logger import log


def get_latest_value(value: list[dict]) -> dict:
def get_latest_value(value: list[dict[str, Any]]) -> dict[str, Any]:
if value and len(value) > 0:
return value[0]
return {}
Expand All @@ -33,7 +33,7 @@ class AttioResource(ABC):

@classmethod
@abstractmethod
def make(cls: Type[Self], data: dict) -> Self:
def make(cls: Type[Self], data: dict[str, Any]) -> Self:
pass


Expand All @@ -48,11 +48,13 @@ class AttioWorkspace(AttioResource):
fix_workspace_id: Optional[UUID]
users: list[AttioUser] = field(default_factory=list)

def __eq__(self, other):
return self.id == other.id and self.tier == other.tier
def __eq__(self: Self, other: Any) -> bool:
if not hasattr(other, "id") or not hasattr(other, "tier"):
return False
return bool(self.id == other.id and self.tier == other.tier)

@classmethod
def make(cls: Type[Self], data: dict) -> Self:
def make(cls: Type[Self], data: dict[str, Any]) -> Self:
object_id = UUID(data["id"]["object_id"])
record_id = UUID(data["id"]["record_id"])
workspace_id = UUID(data["id"]["workspace_id"])
Expand All @@ -70,7 +72,7 @@ def make(cls: Type[Self], data: dict) -> Self:
status = status_info.get("status", {}).get("title")

fix_workspace_id_info = get_latest_value(values.get("workspace_id", [{}]))
fix_workspace_id = optional_uuid(fix_workspace_id_info.get("value"))
fix_workspace_id = optional_uuid(str(fix_workspace_id_info.get("value")))
if fix_workspace_id is None:
log.error(f"Fix workspace ID not found for {record_id}: {data}")

Expand Down Expand Up @@ -102,7 +104,7 @@ class AttioPerson(AttioResource):
users: list[AttioUser] = field(default_factory=list)

@classmethod
def make(cls: Type[Self], data: dict) -> Self:
def make(cls: Type[Self], data: dict[str, Any]) -> Self:
object_id = UUID(data["id"]["object_id"])
record_id = UUID(data["id"]["record_id"])
workspace_id = UUID(data["id"]["workspace_id"])
Expand Down Expand Up @@ -137,7 +139,7 @@ def make(cls: Type[Self], data: dict) -> Self:
"linkedin": linkedin,
}

return cls(**cls_data)
return cls(**cls_data) # type: ignore


@dataclass
Expand All @@ -155,11 +157,13 @@ class AttioUser(AttioResource):
person: Optional[AttioPerson] = None
workspaces: list[AttioWorkspace] = field(default_factory=list)

def __eq__(self, other):
return self.id == other.id and self.email.lower() == other.email.lower()
def __eq__(self: Self, other: Any) -> bool:
if not hasattr(other, "id") or not hasattr(other, "email"):
return False
return bool(self.id == other.id and str(self.email).lower() == str(other.email).lower())

@classmethod
def make(cls: Type[Self], data: dict) -> Self:
def make(cls: Type[Self], data: dict[str, Any]) -> Self:
object_id = UUID(data["id"]["object_id"])
record_id = UUID(data["id"]["record_id"])
workspace_id = UUID(data["id"]["workspace_id"])
Expand All @@ -179,21 +183,21 @@ def make(cls: Type[Self], data: dict) -> Self:
log.error(f"Fix user ID not found for {record_id}: {data}")

person_info = get_latest_value(values.get("person", [{}]))
person_id = optional_uuid(person_info.get("target_record_id"))
person_id = optional_uuid(str(person_info.get("target_record_id")))

workspace_refs = None
workspace_info = values.get("workspace", [])
for workspace in workspace_info:
workspace_ref = optional_uuid(workspace.get("target_record_id"))
workspace_ref = optional_uuid(str(workspace.get("target_record_id")))
if workspace_refs is None:
workspace_refs = []
workspace_refs.append(workspace_ref)

cls_data = {
"id": user_id,
"object_id": object_id,
"record_id": record_id,
"workspace_id": workspace_id,
"id": user_id,
"created_at": created_at,
"demo_workspace_viewed": None,
"email": primary_email_address,
Expand All @@ -206,7 +210,7 @@ def make(cls: Type[Self], data: dict) -> Self:

return cls(**cls_data)

def create_or_update(self) -> tuple[str, dict]:
def create_or_update(self) -> tuple[str, dict[str, Any]]:
data = {
"values": {
"primary_email_address": [{"email_address": self.email}],
Expand Down
16 changes: 9 additions & 7 deletions fixattiosync/fixdata.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,24 @@
import os
import psycopg
from psycopg.rows import dict_row
from uuid import UUID
from argparse import ArgumentParser
from .logger import log
from .fixresources import FixUser, FixWorkspace
from typing import Optional


class FixData:
def __init__(self, db, user, password, host="localhost", port=5432):
def __init__(self, db: str, user: str, password: str, host: str = "localhost", port: int = 5432) -> None:
self.db = db
self.user = user
self.password = password
self.host = host
self.port = port
self.conn = None
self.conn: Optional[psycopg.Connection] = None
self.hydrated = False
self.__workspaces = {}
self.__users = {}
self.__workspaces: dict[UUID, FixWorkspace] = {}
self.__users: dict[UUID, FixUser] = {}

@property
def users(self) -> list[FixUser]:
Expand All @@ -30,7 +32,7 @@ def workspaces(self) -> list[FixWorkspace]:
self.hydrate()
return list(self.__workspaces.values())

def connect(self):
def connect(self) -> None:
log.debug("Connecting to the database")
if self.conn is None:
try:
Expand All @@ -42,7 +44,7 @@ def connect(self):
log.error(f"Error connecting to the database: {e}")
self.conn = None

def hydrate(self):
def hydrate(self) -> None:
if self.conn is None:
self.connect()

Expand Down Expand Up @@ -81,7 +83,7 @@ def hydrate(self):
log.debug(f"Found {len(self.__users)} users in database")
self.hydrated = True

def close(self):
def close(self) -> None:
if self.conn is not None:
log.debug("Closing database connection")
self.conn.close()
Expand Down
26 changes: 15 additions & 11 deletions fixattiosync/fixresources.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from dataclasses import dataclass, field
from datetime import datetime
from uuid import UUID
from typing import Optional
from typing import Optional, Self, Any
from .attioresources import AttioPerson, AttioWorkspace


Expand All @@ -20,15 +20,17 @@ class FixUser:
updated_at: datetime
workspaces: list[FixWorkspace] = field(default_factory=list)

def __eq__(self, other):
return self.id == other.id and self.email.lower() == other.email.lower()
def __eq__(self: Self, other: Any) -> bool:
if not hasattr(other, "id") or not hasattr(other, "email"):
return False
return bool(self.id == other.id and str(self.email).lower() == str(other.email).lower())

def attio_data(
self, person: Optional[AttioPerson] = None, workspaces: Optional[list[AttioWorkspace]] = None
) -> dict:
) -> dict[str, Any]:
object_id = "users"
matching_attribute = "user_id"
data = {
data: dict[str, Any] = {
"data": {
"values": {
"user_id": str(self.id),
Expand Down Expand Up @@ -58,10 +60,10 @@ def attio_data(
"data": data,
}

def attio_person(self) -> dict:
def attio_person(self) -> dict[str, Any]:
object_id = "people"
matching_attribute = "email_addresses"
data = {"data": {"values": {"email_addresses": [{"email_address": self.email}]}}}
data: dict[str, Any] = {"data": {"values": {"email_addresses": [{"email_address": self.email}]}}}
return {
"object_id": object_id,
"matching_attribute": matching_attribute,
Expand All @@ -87,13 +89,15 @@ class FixWorkspace:
owner: Optional[FixUser] = None
users: list[FixUser] = field(default_factory=list)

def __eq__(self, other):
return self.id == other.id and self.name == other.name and self.tier == other.tier
def __eq__(self: Self, other: Any) -> bool:
if not hasattr(other, "id") or not hasattr(other, "name") or not hasattr(other, "tier"):
return False
return bool(self.id == other.id and self.name == other.name and self.tier == other.tier)

def attio_data(self) -> dict:
def attio_data(self) -> dict[str, Any]:
object_id = "workspaces"
matching_attribute = "workspace_id"
data = {
data: dict[str, Any] = {
"data": {
"values": {
"workspace_id": str(self.id),
Expand Down
Loading

0 comments on commit 0ba4460

Please sign in to comment.