diff --git a/inventory_foundation_sdk/custom_datasets.py b/inventory_foundation_sdk/custom_datasets.py index 1c369df..1c39816 100644 --- a/inventory_foundation_sdk/custom_datasets.py +++ b/inventory_foundation_sdk/custom_datasets.py @@ -25,13 +25,14 @@ # %% ../nbs/50_custom_datasets.ipynb 5 import logging + logger = logging.getLogger(__name__) + class AddRowDataset(AbstractDataset): - """ Adds or update one row to a SQL table, if it does not exist. - + """ def __init__( @@ -40,8 +41,8 @@ def __init__( column_names: t.List, credentials: str, unique_columns: t.List, - load_args = None, - save_args = None + load_args=None, + save_args=None, ): self.unique_columns = unique_columns @@ -50,7 +51,7 @@ def __init__( self.db_credentials = credentials self.save_args = save_args or {} self.load_args = load_args or {} - + def _describe(self) -> t.Dict[str, t.Any]: """Returns a dict that describes the attributes of the dataset.""" return dict( @@ -66,15 +67,13 @@ def _load(self) -> pd.DataFrame: return_all_columns = self.load_args.get("return_all_columns", False) try: - with psycopg2.connect(self.db_credentials['con']) as conn: + with psycopg2.connect(self.db_credentials["con"]) as conn: with conn.cursor() as cursor: - + if return_all_columns: # Fetch all rows - cursor.execute( - f"SELECT * FROM {self.table}" - ) + cursor.execute(f"SELECT * FROM {self.table}") data = cursor.fetchall() # Fetch column names in the correct order from the database @@ -87,7 +86,7 @@ def _load(self) -> pd.DataFrame: WHERE c.relname = %s AND a.attnum > 0 AND NOT a.attisdropped ORDER BY a.attnum """, - (self.table,) + (self.table,), ) columns = [row[0] for row in cursor.fetchall()] @@ -116,22 +115,28 @@ def _save(self, data: pd.DataFrame) -> None: """ verbose = self.save_args.get("verbose", 1) - + try: # Connect to the database - with psycopg2.connect(self.db_credentials['con']) as conn: + with psycopg2.connect(self.db_credentials["con"]) as conn: with conn.cursor() as cursor: # Prepare data insertion for _, row in data.iterrows(): # Ensure all data is properly converted to standard Python types row_data = tuple( - row[col].item() if isinstance(row[col], (np.generic, np.ndarray)) else row[col] + ( + row[col].item() + if isinstance(row[col], (np.generic, np.ndarray)) + else row[col] + ) for col in self.column_names ) # Determine the update clause (exclude unique columns) updatable_columns = [ - col for col in self.column_names if col not in self.unique_columns + col + for col in self.column_names + if col not in self.unique_columns ] # Only create an update clause if there are columns to update @@ -149,30 +154,42 @@ def _save(self, data: pd.DataFrame) -> None: ) # Build the SQL query dynamically - query = sql.SQL(""" + query = sql.SQL( + """ INSERT INTO {table} ({columns}) VALUES ({values}) ON CONFLICT ({conflict_clause}) DO UPDATE SET {update_clause} RETURNING xmax = 0 AS is_inserted - """).format( + """ + ).format( table=sql.Identifier(self.table), - columns=sql.SQL(", ").join(sql.Identifier(col) for col in self.column_names), - values=sql.SQL(", ").join(sql.Placeholder() for _ in self.column_names), + columns=sql.SQL(", ").join( + sql.Identifier(col) for col in self.column_names + ), + values=sql.SQL(", ").join( + sql.Placeholder() for _ in self.column_names + ), conflict_clause=conflict_clause, - update_clause=update_clause + update_clause=update_clause, ) else: # Build the SQL query for insertion without an update clause - query = sql.SQL(""" + query = sql.SQL( + """ INSERT INTO {table} ({columns}) VALUES ({values}) ON CONFLICT DO NOTHING RETURNING xmax = 0 AS is_inserted - """).format( + """ + ).format( table=sql.Identifier(self.table), - columns=sql.SQL(", ").join(sql.Identifier(col) for col in self.column_names), - values=sql.SQL(", ").join(sql.Placeholder() for _ in self.column_names) + columns=sql.SQL(", ").join( + sql.Identifier(col) for col in self.column_names + ), + values=sql.SQL(", ").join( + sql.Placeholder() for _ in self.column_names + ), ) # Execute the query with properly cast values @@ -183,9 +200,13 @@ def _save(self, data: pd.DataFrame) -> None: if verbose > 0: if is_inserted: - logger.info(f"Inserted new row: {dict(zip(self.column_names, row_data))}") + logger.info( + f"Inserted new row: {dict(zip(self.column_names, row_data))}" + ) else: - logger.info(f"Updated row (or skipped due to conflict): {dict(zip(self.column_names, row_data))}") + logger.info( + f"Updated row (or skipped due to conflict): {dict(zip(self.column_names, row_data))}" + ) # Commit the transaction conn.commit() @@ -196,16 +217,14 @@ def _save(self, data: pd.DataFrame) -> None: # %% ../nbs/50_custom_datasets.ipynb 6 class DynamicPathJSONDataset(AbstractDataset): - """ Custom dataset to dynamically resolve a JSON file path from parameters. """ def __init__(self, path_param: str): - """ Initializes the ConditionedJSONDataset. - + Args: path_param (str): The parameter key that contains the file path. """ @@ -220,10 +239,12 @@ def _load(self) -> dict: """ # Load parameters params_path = self.config_loader["parameters"][self.path_param] - + # Resolve the file path from parameters if not params_path: - raise ValueError(f"Path parameter '{self.path_param}' not found in parameters.") + raise ValueError( + f"Path parameter '{self.path_param}' not found in parameters." + ) # Load and return JSON data full_path = Path(params_path) diff --git a/inventory_foundation_sdk/db_mgmt.py b/inventory_foundation_sdk/db_mgmt.py index 5ae99e2..fae0f37 100644 --- a/inventory_foundation_sdk/db_mgmt.py +++ b/inventory_foundation_sdk/db_mgmt.py @@ -14,10 +14,8 @@ import numpy as np from tqdm import tqdm - # %% ../nbs/10_db_mgmt.ipynb 5 def get_db_credentials(): - """ Fetch PostgreSQL database credentials from the configuration file of the kedro project. @@ -35,8 +33,10 @@ def get_db_credentials(): # %% ../nbs/10_db_mgmt.ipynb 6 import logging + logger = logging.getLogger(__name__) + def insert_multi_rows( data_to_insert: pd.DataFrame, table_name: str, @@ -47,7 +47,6 @@ def insert_multi_rows( return_with_ids: bool = False, unique_columns: list = None, # mandatory if return_with_ids is True ) -> pd.DataFrame | None: - """ Inserts data into the specified database table, with an optional return of database-assigned IDs. @@ -68,31 +67,40 @@ def insert_multi_rows( # Check for NaN values and log a warning if any are found if data_to_insert.isnull().values.any(): logger.warning("There are NaNs in the data") - + # Ensure the DataFrame has the correct number of columns if len(column_names) != data_to_insert.shape[1]: - raise ValueError("Number of column names does not match the number of columns in the DataFrame.") + raise ValueError( + "Number of column names does not match the number of columns in the DataFrame." + ) if len(types) != data_to_insert.shape[1]: - raise ValueError("Number of types does not match the number of columns in the DataFrame.") - + raise ValueError( + "Number of types does not match the number of columns in the DataFrame." + ) + logger.info("-- in insert multi rows -- converting data to list of tuples") # Convert to list of tuples and apply type casting data_values = data_to_insert.values.tolist() - data_values = [tuple(typ(val) for typ, val in zip(types, row)) for row in data_values] - + data_values = [ + tuple(typ(val) for typ, val in zip(types, row)) for row in data_values + ] + logger.info("-- in insert multi rows -- preparing SQL") # Create SQL placeholders and query placeholders = ", ".join(["%s"] * len(column_names)) column_names_str = ", ".join(f'"{col}"' for col in column_names) - - batch_size_for_commit = 1_000_000 # Adjust this based on your dataset size and transaction tolerance + batch_size_for_commit = ( + 1_000_000 # Adjust this based on your dataset size and transaction tolerance + ) row_count = 0 if return_with_ids: if not unique_columns: - raise ValueError("unique_columns must be provided when return_with_ids is True") + raise ValueError( + "unique_columns must be provided when return_with_ids is True" + ) unique_columns_str = ", ".join(f'"{col}"' for col in unique_columns) insert_query = f""" @@ -104,8 +112,6 @@ def insert_multi_rows( """ ids = [] - - # Insert row by row and collect IDs with tqdm(total=len(data_values), desc="Inserting rows") as pbar: for row in data_values: @@ -115,12 +121,12 @@ def insert_multi_rows( ids.append(row_id[0]) row_count += 1 pbar.update(1) # Update progress bar for each row - + # Commit every batch_size_for_commit rows if row_count % batch_size_for_commit == 0: conn.commit() # Commit the transaction - conn.commit() - + conn.commit() + # Add IDs back to the original DataFrame data_with_ids = data_to_insert.copy() data_with_ids["ID"] = ids @@ -132,7 +138,7 @@ def insert_multi_rows( VALUES ({placeholders}) ON CONFLICT DO NOTHING; """ - + # Insert row by row without returning IDs with tqdm(total=len(data_values), desc="Inserting rows") as pbar: for row in data_values: @@ -141,7 +147,7 @@ def insert_multi_rows( pbar.update(1) # Update progress bar for each row if row_count % batch_size_for_commit == 0: conn.commit() # Commit the transaction - + conn.commit() # Commit all changes after processing return None diff --git a/inventory_foundation_sdk/etl_db_writers.py b/inventory_foundation_sdk/etl_db_writers.py index 0251427..ea89037 100644 --- a/inventory_foundation_sdk/etl_db_writers.py +++ b/inventory_foundation_sdk/etl_db_writers.py @@ -24,18 +24,21 @@ # %% ../nbs/30_ETL_db_writers.ipynb 5 import logging + logger = logging.getLogger(__name__) -def write_company_name(name: str, additional_info: t.Dict = None, ignore_company_if_exist: bool = True) -> int: - + +def write_company_name( + name: str, additional_info: t.Dict = None, ignore_company_if_exist: bool = True +) -> int: """ This function writes the company name to the database and any additional info. Each key in `additional_info` becomes a column in the database table if it doesn't exist, and the associated value is written to that column. - + If `ignore_company_if_exist` is False and the company name already exists, an error is raised. If `ignore_company_if_exist` is True, a warning is logged and the existing record is updated if additional info differs. - + Returns the ID that the database has assigned to the company name. """ @@ -52,23 +55,23 @@ def write_company_name(name: str, additional_info: t.Dict = None, ignore_company ON CONFLICT (name) DO NOTHING RETURNING "ID"; """, - (name,) + (name,), ) result = cur.fetchone() - + if result is None: # Company exists, handle based on ignore_company_if_exist flag cur.execute( """ SELECT "ID" FROM companies WHERE name = %s; """, - (name,) + (name,), ) company_id = cur.fetchone()[0] - + if not ignore_company_if_exist: raise ValueError(f"Company '{name}' already exists.") - + logger.warning("Company already exists, ignoring new entry") # Check if additional info needs to be updated @@ -81,33 +84,35 @@ def write_company_name(name: str, additional_info: t.Dict = None, ignore_company ADD COLUMN IF NOT EXISTS {key} TEXT; """ ) - + # Check current value before updating cur.execute( f""" SELECT {key} FROM companies WHERE "ID" = %s; """, - (company_id,) + (company_id,), ) current_value = cur.fetchone()[0] - + # Only update if the value is different if current_value != value: - logger.warning(f"Overwriting '{key}' for company '{name}' from '{current_value}' to '{value}'.") + logger.warning( + f"Overwriting '{key}' for company '{name}' from '{current_value}' to '{value}'." + ) cur.execute( f""" UPDATE companies SET {key} = %s WHERE "ID" = %s; """, - (value, company_id) + (value, company_id), ) else: company_id = result[0] - + # Insert additional information for new entry if additional_info is not None: - for key, value in additional_info.to_dict().items(): + for key, value in additional_info.items(): # Add column if it doesn't exist cur.execute( f""" @@ -122,9 +127,9 @@ def write_company_name(name: str, additional_info: t.Dict = None, ignore_company SET {key} = %s WHERE "ID" = %s; """, - (value, company_id) + (value, company_id), ) - + except Exception as e: logger.error(f"Got error while writing company name to database: {e}") raise e @@ -132,8 +137,9 @@ def write_company_name(name: str, additional_info: t.Dict = None, ignore_company return company_id # %% ../nbs/30_ETL_db_writers.ipynb 7 -def write_categories(categories: dict, company_id: int, category_level_names: list = None) -> t.List[int]: - +def write_categories( + categories: dict, company_id: int, category_level_names: list = None +) -> t.List[int]: """ This function writes the categories to the database. @@ -162,11 +168,9 @@ def write_categories(categories: dict, company_id: int, category_level_names: li write_category_level(level, company_id) return True - # %% ../nbs/30_ETL_db_writers.ipynb 9 def write_category_level_descriptions(category_level_names, company_id): - """ This function writes the names of the category levels to the database. """ @@ -184,14 +188,14 @@ def write_category_level_descriptions(category_level_names, company_id): ON CONFLICT ("companyID", level) DO UPDATE SET name = EXCLUDED.name; """, - (company_id, i+1, name) + (company_id, i + 1, name), ) except Exception as e: logger.error(f"Got error while writing category level names to database: {e}") raise e + def write_category_level(categories: list, company_id: int) -> t.Dict[str, int]: - """ This function writes one level of categories into the database and returns a list of the IDs that the database has assigned. The purpose is to call this function repeatedly for each level of categories. @@ -223,19 +227,19 @@ def write_category_level(categories: list, company_id: int) -> t.Dict[str, int]: WHERE "companyID" = %s AND name = %s LIMIT 1; """, - (company_id, key, company_id, key) + (company_id, key, company_id, key), ) category_id = cur.fetchone()[0] - + if parents is not None: - + for parent in parents: cur.execute( """ SELECT "ID" FROM categories WHERE "companyID" = %s AND name = %s; """, - (company_id, parent) + (company_id, parent), ) parent_id = cur.fetchone()[0] cur.execute( @@ -245,7 +249,7 @@ def write_category_level(categories: list, company_id: int) -> t.Dict[str, int]: ON CONFLICT ("subID", "parentID") DO NOTHING; """, - (category_id, parent_id) + (category_id, parent_id), ) return @@ -255,7 +259,6 @@ def write_category_level(categories: list, company_id: int) -> t.Dict[str, int]: # %% ../nbs/30_ETL_db_writers.ipynb 11 def write_products(products: pd.DataFrame, company_id: int) -> None: - """ This function writes the products to the database. @@ -263,7 +266,7 @@ def write_products(products: pd.DataFrame, company_id: int) -> None: First column: product name (column name is irrelevant) Second column: category name (column name is irrelevant) - Note that each product may have more than one category. + Note that each product may have more than one category. """ @@ -281,17 +284,17 @@ def write_products(products: pd.DataFrame, company_id: int) -> None: JOIN products ON product_categories."productID" = products."ID" ; """, - (company_id,) + (company_id,), ) names = cur.fetchall() names = [name[0] for name in names] - - products_filtered = products[~products.iloc[:,0].isin(names)] - products_filtered_list = products_filtered.iloc[:,0].tolist() + products_filtered = products[~products.iloc[:, 0].isin(names)] + + products_filtered_list = products_filtered.iloc[:, 0].tolist() products_filtered_list = list(set(products_filtered_list)) - + inserted_ids = [] for product in products_filtered_list: cur.execute( @@ -300,7 +303,7 @@ def write_products(products: pd.DataFrame, company_id: int) -> None: VALUES (%s) RETURNING "ID"; """, - (product,) + (product,), ) inserted_id = cur.fetchone()[0] # Fetch the generated ID inserted_ids.append(inserted_id) @@ -312,33 +315,42 @@ def write_products(products: pd.DataFrame, company_id: int) -> None: JOIN categories ON companies."ID" = categories."companyID" WHERE companies."ID" = %s; """, - (company_id,) + (company_id,), ) category_names = cur.fetchall() category_names_df = pd.DataFrame(category_names, columns=["ID", "name"]) - products_filtered = products_filtered.merge(category_names_df, left_on=products_filtered.columns[1], right_on="name", how="left") + products_filtered = products_filtered.merge( + category_names_df, + left_on=products_filtered.columns[1], + right_on="name", + how="left", + ) products_filtered["product_id"] = inserted_ids products_filtered = products_filtered[["product_id", "ID"]] - products_filtered["product_id"] = products_filtered["product_id"].astype(int) + products_filtered["product_id"] = products_filtered[ + "product_id" + ].astype(int) products_filtered["ID"] = products_filtered["ID"].astype(int) - values_to_insert = [tuple(row) for row in products_filtered.itertuples(index=False)] + values_to_insert = [ + tuple(row) for row in products_filtered.itertuples(index=False) + ] cur.executemany( """ INSERT INTO product_categories ("productID", "categoryID") VALUES (%s, %s); """, - values_to_insert # Use the converted list of tuples + values_to_insert, # Use the converted list of tuples ) - + except Exception as e: logger.error(f"Got error while writing products to database: {e}") - raise e + raise e # %% ../nbs/30_ETL_db_writers.ipynb 13 def write_stores(store_regions: pd.DataFrame, company_id) -> None: @@ -348,7 +360,6 @@ def write_stores(store_regions: pd.DataFrame, company_id) -> None: db_credentials = get_db_credentials()["con"] - try: with psycopg2.connect(db_credentials) as conn: with conn.cursor() as cur: @@ -356,21 +367,22 @@ def write_stores(store_regions: pd.DataFrame, company_id) -> None: store_regions = add_region_ids(store_regions, cur) cur.executemany( - """ INSERT INTO stores ("name", "regionID", "companyID") VALUES (%s, %s, %s) ON CONFLICT ("name", "companyID") DO NOTHING; """, - (store_regions[["name", "region_id"]].assign(companyID=company_id).values.tolist()) + ( + store_regions[["name", "region_id"]] + .assign(companyID=company_id) + .values.tolist() + ), ) - except Exception as e: logger.error(f"Got error while writing stores to database: {e}") raise e - # %% ../nbs/30_ETL_db_writers.ipynb 15 def get_region_ids(cur, country, abbreviation, type_): @@ -394,12 +406,13 @@ def get_region_ids(cur, country, abbreviation, type_): FROM RegionHierarchy WHERE "abbreviation" = %s AND "type" = %s; """, - (country, 'country', abbreviation, type_) + (country, "country", abbreviation, type_), ) region_id = cur.fetchone() return region_id + def add_region_ids(data, cur): """ Adds region IDs to the given DataFrame by mapping region, type, and country. @@ -430,12 +443,13 @@ def add_region_ids(data, cur): # Convert mapping to a DataFrame region_id_mapping_df = pd.DataFrame( - region_id_mapping, - columns=["region", "type", "country", "region_id"] + region_id_mapping, columns=["region", "type", "country", "region_id"] ) # Merge the region ID mapping back into the original data - data = data.merge(region_id_mapping_df, on=["region", "type", "country"], how="left") + data = data.merge( + region_id_mapping_df, on=["region", "type", "country"], how="left" + ) # Check for any unmatched rows (this should not happen due to the error raised earlier) if data["region_id"].isnull().any(): @@ -467,35 +481,45 @@ def write_skus(store_item_combinations: pd.DataFrame, company_id: int) -> None: with conn.cursor() as cur: # Fetch product IDs product_mapping = get_product_ids( - cur=cur, - company_id=company_id, - item_name_list=store_item_combinations["item_name"].unique().tolist() + cur=cur, + company_id=company_id, + item_name_list=store_item_combinations["item_name"] + .unique() + .tolist(), ) # Fetch store IDs store_mapping = get_store_ids( - cur=cur, - company_id=company_id, - store_name_list=store_item_combinations["store_name"].unique().tolist() + cur=cur, + company_id=company_id, + store_name_list=store_item_combinations["store_name"] + .unique() + .tolist(), ) # Merge product and store IDs with the input DataFrame merged_data = store_item_combinations.merge( product_mapping, on="item_name", how="left" - ).merge( - store_mapping, on="store_name", how="left" - ) + ).merge(store_mapping, on="store_name", how="left") # Check for unmatched rows if merged_data["productID"].isnull().any(): - unmatched_items = merged_data.loc[merged_data["productID"].isnull(), "item_name"].unique() + unmatched_items = merged_data.loc[ + merged_data["productID"].isnull(), "item_name" + ].unique() raise ValueError(f"Unmatched item_names: {unmatched_items}") if merged_data["storeID"].isnull().any(): - unmatched_stores = merged_data.loc[merged_data["storeID"].isnull(), "store_name"].unique() + unmatched_stores = merged_data.loc[ + merged_data["storeID"].isnull(), "store_name" + ].unique() raise ValueError(f"Unmatched store_names: {unmatched_stores}") # Prepare data for insertion - sku_data = merged_data[["productID", "storeID"]].drop_duplicates().values.tolist() + sku_data = ( + merged_data[["productID", "storeID"]] + .drop_duplicates() + .values.tolist() + ) # Insert data into the SKU table cur.executemany( @@ -504,7 +528,7 @@ def write_skus(store_item_combinations: pd.DataFrame, company_id: int) -> None: VALUES (%s, %s) ON CONFLICT ("productID", "storeID") DO NOTHING; """, - sku_data + sku_data, ) conn.commit() @@ -533,11 +557,12 @@ def get_product_ids(cur, company_id, item_name_list): JOIN categories ON product_categories."categoryID" = categories."ID" WHERE categories."companyID" = %s AND products.name = ANY(%s); """, - (company_id, item_name_list) + (company_id, item_name_list), ) product_mapping = cur.fetchall() return pd.DataFrame(product_mapping, columns=["productID", "item_name"]) + def get_store_ids(cur, company_id, store_name_list): """ Fetch store IDs for a given company and a list of store names. @@ -556,13 +581,11 @@ def get_store_ids(cur, company_id, store_name_list): FROM stores WHERE "companyID" = %s AND name = ANY(%s); """, - (company_id, list(store_name_list)) + (company_id, list(store_name_list)), ) store_id_mapping = cur.fetchall() return pd.DataFrame(store_id_mapping, columns=["storeID", "store_name"]) - - # %% ../nbs/30_ETL_db_writers.ipynb 21 def write_datapoints(sales: pd.DataFrame, company_id: int) -> None: """ @@ -583,7 +606,9 @@ def write_datapoints(sales: pd.DataFrame, company_id: int) -> None: with psycopg2.connect(db_credentials) as conn: with conn.cursor() as cur: # Step 1: Resolve SKU IDs - store_product_names = sales[["store_name", "item_name"]].drop_duplicates() + store_product_names = sales[ + ["store_name", "item_name"] + ].drop_duplicates() sku_ids = get_sku_ids(cur, store_product_names, company_id) # Merge SKU IDs into the sales DataFrame based on `store_name` and `item_name` @@ -600,16 +625,26 @@ def write_datapoints(sales: pd.DataFrame, company_id: int) -> None: # Check for unmatched rows if sales["skuID"].isnull().any(): - unmatched_skus = sales.loc[sales["skuID"].isnull(), ["store_name", "item_name"]].drop_duplicates() - raise ValueError(f"Unmatched SKUs: {unmatched_skus.to_dict(orient='records')}") + unmatched_skus = sales.loc[ + sales["skuID"].isnull(), ["store_name", "item_name"] + ].drop_duplicates() + raise ValueError( + f"Unmatched SKUs: {unmatched_skus.to_dict(orient='records')}" + ) if sales["dateID"].isnull().any(): - unmatched_dates = sales.loc[sales["dateID"].isnull(), "date"].unique() + unmatched_dates = sales.loc[ + sales["dateID"].isnull(), "date" + ].unique() raise ValueError(f"Unmatched dates: {unmatched_dates}") # Check for duplicate rows in `skuID` and `dateID` if sales[["skuID", "dateID"]].duplicated().any(): - duplicate_rows = sales[sales[["skuID", "dateID"]].duplicated(keep=False)] - raise ValueError(f"Duplicate rows found in the data: {duplicate_rows}") + duplicate_rows = sales[ + sales[["skuID", "dateID"]].duplicated(keep=False) + ] + raise ValueError( + f"Duplicate rows found in the data: {duplicate_rows}" + ) datapoints_data = sales[["skuID", "dateID"]] @@ -622,7 +657,7 @@ def write_datapoints(sales: pd.DataFrame, company_id: int) -> None: cur=cur, conn=conn, return_with_ids=True, - unique_columns=["skuID", "dateID"] + unique_columns=["skuID", "dateID"], ) return datapoint_ids @@ -632,7 +667,6 @@ def write_datapoints(sales: pd.DataFrame, company_id: int) -> None: # %% ../nbs/30_ETL_db_writers.ipynb 23 def write_sales(sales: pd.DataFrame, company_id, datapoint_ids) -> None: - """ This function writes the sales data to the database. @@ -640,37 +674,37 @@ def write_sales(sales: pd.DataFrame, company_id, datapoint_ids) -> None: write_SKU_date_specific_data( data=sales, - datapoint_ids = datapoint_ids, + datapoint_ids=datapoint_ids, variable_name="sales", variable_type=float, table_name="sales", company_id=company_id, ) -def write_prices(prices: pd.DataFrame, company_id, datapoint_ids) -> None: +def write_prices(prices: pd.DataFrame, company_id, datapoint_ids) -> None: """ This function writes the prices data to the database. """ write_SKU_date_specific_data( data=prices, - datapoint_ids = datapoint_ids, + datapoint_ids=datapoint_ids, variable_name="price", variable_type=float, table_name="prices", company_id=company_id, ) -def write_sold_flag(sold_flags: pd.DataFrame, company_id, datapoint_ids) -> None: +def write_sold_flag(sold_flags: pd.DataFrame, company_id, datapoint_ids) -> None: """ This function writes the sold flag data to the database. """ write_SKU_date_specific_data( data=sold_flags, - datapoint_ids = datapoint_ids, + datapoint_ids=datapoint_ids, variable_name="name", variable_type=str, table_name="flags", @@ -686,7 +720,7 @@ def write_SKU_date_specific_data( variable_type: callable, table_name: str, company_id: int, - name_in_df=None + name_in_df=None, ) -> None: """ Writes SKU and date-specific data to the database using the new `datapointID` schema. @@ -706,51 +740,75 @@ def write_SKU_date_specific_data( with conn.cursor() as cur: # Fetch `skuID` mappings logger.info("-- in write SKU date specific data -- getting sku IDs") - store_product_names = data[["store_name", "item_name"]].drop_duplicates() + store_product_names = data[ + ["store_name", "item_name"] + ].drop_duplicates() sku_mapping = get_sku_ids(cur, store_product_names, company_id) # Fetch `dateID` mappings logger.info("-- in write SKU date specific data -- getting date IDs") unique_dates = data["date"].drop_duplicates() date_mapping = get_date_ids(cur, unique_dates) - date_mapping["date"] = pd.to_datetime(date_mapping["date"], errors="coerce") - + date_mapping["date"] = pd.to_datetime( + date_mapping["date"], errors="coerce" + ) + # Merge `skuID` and `dateID` into the input data - logger.info("-- in write SKU date specific data -- merging sku IDs and date IDs") - data = data.merge(sku_mapping, on=["store_name", "item_name"], how="left") + logger.info( + "-- in write SKU date specific data -- merging sku IDs and date IDs" + ) + data = data.merge( + sku_mapping, on=["store_name", "item_name"], how="left" + ) data["date"] = pd.to_datetime(data["date"], errors="coerce") data = data.merge(date_mapping, on="date", how="left") data.drop(columns=["store_name", "item_name", "date"], inplace=True) # show ram usage of data - logger.info(f"-- in write SKU date specific data -- Memory usage of data: {data.memory_usage().sum() / 1024 / 1024 ** 2:.2f} GB") + logger.info( + f"-- in write SKU date specific data -- Memory usage of data: {data.memory_usage().sum() / 1024 / 1024 ** 2:.2f} GB" + ) # Check for unmatched mappings if data["skuID"].isnull().any(): - unmatched_skus = data.loc[data["skuID"].isnull(), ["store_name", "item_name"]].drop_duplicates() - raise ValueError(f"Unmatched SKUs: {unmatched_skus.to_dict(orient='records')}") + unmatched_skus = data.loc[ + data["skuID"].isnull(), ["store_name", "item_name"] + ].drop_duplicates() + raise ValueError( + f"Unmatched SKUs: {unmatched_skus.to_dict(orient='records')}" + ) if data["dateID"].isnull().any(): unmatched_dates = data.loc[data["dateID"].isnull(), "date"].unique() raise ValueError(f"Unmatched dates: {unmatched_dates}") - logger.info("-- in write SKU date specific data -- getting checking for duplicates") + logger.info( + "-- in write SKU date specific data -- getting checking for duplicates" + ) # Fetch `datapointID` for `skuID` and `dateID` combinations datapoint_combinations = data[["skuID", "dateID"]] # Check for duplicate combinations of `skuID` and `dateID` if datapoint_combinations.duplicated().any(): - duplicate_rows = datapoint_combinations[datapoint_combinations.duplicated(keep=False)] - raise ValueError(f"Duplicate rows found in the data: {duplicate_rows}") + duplicate_rows = datapoint_combinations[ + datapoint_combinations.duplicated(keep=False) + ] + raise ValueError( + f"Duplicate rows found in the data: {duplicate_rows}" + ) data.drop(columns=["storeID", "productID"], inplace=True) # Merge `datapointID` into the input data # rename column ID to datapointID in the datapoint_IDs - logger.info("-- in write SKU date specific data -- merging datapoint IDs") + logger.info( + "-- in write SKU date specific data -- merging datapoint IDs" + ) datapoint_ids = datapoint_ids.rename(columns={"ID": "datapointID"}) data = data.merge(datapoint_ids, on=["skuID", "dateID"], how="left") # Check for unmatched `datapointID` - logger.info("-- in write SKU date specific data -- checking for unmatched datapoints") + logger.info( + "-- in write SKU date specific data -- checking for unmatched datapoints" + ) # if data["datapointID"].isnull().any(): # # unmatched_datapoints = data.loc[data["datapointID"].isnull(), ["skuID", "dateID"]].drop_duplicates() # # raise ValueError(f"Unmatched datapoints: {unmatched_datapoints.to_dict(orient='records')}") @@ -758,11 +816,15 @@ def write_SKU_date_specific_data( data.drop(columns=["skuID", "dateID"], inplace=True) # Prepare data for insertion - logger.info("-- in write SKU date specific data -- preparing data for insertion") + logger.info( + "-- in write SKU date specific data -- preparing data for insertion" + ) if name_in_df is None: name_in_df = variable_name data_to_write = data[["datapointID", name_in_df]].copy() - data_to_write[name_in_df] = data_to_write[name_in_df].astype(variable_type) + data_to_write[name_in_df] = data_to_write[name_in_df].astype( + variable_type + ) # Insert data into the specified table logger.info("-- in write SKU date specific data -- inserting data") @@ -772,16 +834,17 @@ def write_SKU_date_specific_data( column_names=["datapointID", variable_name], types=[int, variable_type], cur=cur, - conn=conn + conn=conn, ) except Exception as e: logger.error(f"Error while writing {variable_name} data to the database: {e}") raise e - # %% ../nbs/30_ETL_db_writers.ipynb 26 -def get_sku_ids(cur, store_product_names: pd.DataFrame, company_id: int) -> pd.DataFrame: +def get_sku_ids( + cur, store_product_names: pd.DataFrame, company_id: int +) -> pd.DataFrame: """ Fetch skuIDs for given combinations of `store_name` and `item_name`. @@ -797,26 +860,32 @@ def get_sku_ids(cur, store_product_names: pd.DataFrame, company_id: int) -> pd.D store_ids = get_store_ids( cur=cur, company_id=company_id, - store_name_list=store_product_names["store_name"].unique().tolist() + store_name_list=store_product_names["store_name"].unique().tolist(), ) # Step 2: Resolve product IDs product_ids = get_product_ids( cur=cur, company_id=company_id, - item_name_list=store_product_names["item_name"].unique().tolist() + item_name_list=store_product_names["item_name"].unique().tolist(), ) # Step 3: Merge store and product IDs with input DataFrame - store_product_ids = store_product_names.merge(store_ids, on="store_name", how="left") + store_product_ids = store_product_names.merge( + store_ids, on="store_name", how="left" + ) store_product_ids = store_product_ids.merge(product_ids, on="item_name", how="left") # Check for unmatched rows if store_product_ids["storeID"].isnull().any(): - unmatched_stores = store_product_ids.loc[store_product_ids["storeID"].isnull(), "store_name"].unique() + unmatched_stores = store_product_ids.loc[ + store_product_ids["storeID"].isnull(), "store_name" + ].unique() raise ValueError(f"Unmatched store names: {unmatched_stores}") if store_product_ids["productID"].isnull().any(): - unmatched_products = store_product_ids.loc[store_product_ids["productID"].isnull(), "item_name"].unique() + unmatched_products = store_product_ids.loc[ + store_product_ids["productID"].isnull(), "item_name" + ].unique() raise ValueError(f"Unmatched product names: {unmatched_products}") # Step 4: Use a temporary table for efficient querying @@ -824,12 +893,14 @@ def get_sku_ids(cur, store_product_names: pd.DataFrame, company_id: int) -> pd.D temp_data = store_product_ids[["storeID", "productID"]] # Create temporary table - cur.execute(f""" + cur.execute( + f""" CREATE TEMP TABLE {temp_table_name} ( storeID INT, productID INT ) ON COMMIT DROP; - """) + """ + ) # Insert data into the temporary table psycopg2.extras.execute_batch( @@ -838,25 +909,29 @@ def get_sku_ids(cur, store_product_names: pd.DataFrame, company_id: int) -> pd.D INSERT INTO {temp_table_name} (storeID, productID) VALUES (%s, %s); """, - temp_data.values.tolist() + temp_data.values.tolist(), ) # Query sku_table using a JOIN - cur.execute(f""" + cur.execute( + f""" SELECT sku_table."ID" AS skuID, sku_table."storeID", sku_table."productID" FROM sku_table INNER JOIN {temp_table_name} ON sku_table."storeID" = {temp_table_name}.storeID AND sku_table."productID" = {temp_table_name}.productID; - """) + """ + ) # Fetch and return results result = cur.fetchall() sku_df = pd.DataFrame(result, columns=["skuID", "storeID", "productID"]) # Merge the original store_name and item_name back into the results - final_result = sku_df.merge(store_product_ids, on=["storeID", "productID"], how="left") - + final_result = sku_df.merge( + store_product_ids, on=["storeID", "productID"], how="left" + ) + return final_result[["skuID", "storeID", "productID", "store_name", "item_name"]] @@ -876,16 +951,20 @@ def get_datapoint_ids(cur, datapoint_combinations: pd.DataFrame) -> pd.DataFrame raise ValueError("Input DataFrame must contain 'skuID' and 'dateID' columns.") # Convert combinations to a list of tuples for use in the query - combinations_list = datapoint_combinations[["skuID", "dateID"]].drop_duplicates().values.tolist() + combinations_list = ( + datapoint_combinations[["skuID", "dateID"]].drop_duplicates().values.tolist() + ) try: # Create a temporary table to store the combinations - cur.execute(""" + cur.execute( + """ CREATE TEMP TABLE temp_datapoints ( "skuID" INTEGER, "dateID" INTEGER ) ON COMMIT DROP; - """) + """ + ) logger.info("Adding into table for temp_datapoints") # Insert the combinations into the temporary table @@ -895,16 +974,18 @@ def get_datapoint_ids(cur, datapoint_combinations: pd.DataFrame) -> pd.DataFrame INSERT INTO temp_datapoints ("skuID", "dateID") VALUES %s; """, - combinations_list + combinations_list, ) # Query for datapointIDs - cur.execute(""" + cur.execute( + """ SELECT d."ID", d."skuID", d."dateID" FROM datapoints d INNER JOIN temp_datapoints t ON d."skuID" = t."skuID" AND d."dateID" = t."dateID"; - """) + """ + ) # Fetch results and return as a DataFrame result = cur.fetchall() @@ -932,7 +1013,7 @@ def get_date_ids(cur, dates_list): FROM dates WHERE date = ANY(%s::date[]); """, - (list(dates_list),) + (list(dates_list),), ) date_id_mapping = cur.fetchall() return pd.DataFrame(date_id_mapping, columns=["dateID", "date"]) @@ -940,17 +1021,15 @@ def get_date_ids(cur, dates_list): # %% ../nbs/30_ETL_db_writers.ipynb 28 def write_time_region_features( time_region_features: pd.DataFrame, - name_description: [str, str], # containing name and description of the feature - company_id: int - + name_description: [str, str], # containing name and description of the feature + company_id: int, ): """ - This function writes data into the database whose values are specific to a + This function writes data into the database whose values are specific to a time-stamps and regions - - """ + """ db_credentials = get_db_credentials()["con"] @@ -958,7 +1037,6 @@ def write_time_region_features( with psycopg2.connect(db_credentials) as conn: with conn.cursor() as cur: - # add name and description to the time_region_features_description table cur.execute( """ @@ -972,7 +1050,7 @@ def write_time_region_features( UNION ALL SELECT "ID" FROM time_region_features_description WHERE "name" = %s; """, - (name_description[0], name_description[1], name_description[0]) + (name_description[0], name_description[1], name_description[0]), ) feature_id = cur.fetchone()[0] @@ -986,27 +1064,33 @@ def write_time_region_features( VALUES (%s, %s) ON CONFLICT DO NOTHING; """, - (company_id, feature_id) + (company_id, feature_id), ) - + # add features to the time_region_features table time_region_features = add_region_ids(time_region_features, cur) # add date features - time_region_features["date"] = pd.to_datetime(time_region_features["date"], errors="coerce") + time_region_features["date"] = pd.to_datetime( + time_region_features["date"], errors="coerce" + ) date_ids = get_date_ids(cur, time_region_features["date"].unique()) date_ids["date"] = pd.to_datetime(date_ids["date"], errors="coerce") - time_region_features = time_region_features.merge(date_ids, on="date", how="left") + time_region_features = time_region_features.merge( + date_ids, on="date", how="left" + ) cur.executemany( - """ INSERT INTO time_region_features ("dateID", "regionID", "trfID", "value") VALUES (%s, %s, %s, %s) ON CONFLICT ("dateID", "regionID", "trfID") DO NOTHING; """, - - (time_region_features[["dateID", "region_id", "trfID", "feature_value"]].values.tolist()) + ( + time_region_features[ + ["dateID", "region_id", "trfID", "feature_value"] + ].values.tolist() + ), ) conn.commit() diff --git a/inventory_foundation_sdk/etl_nodes.py b/inventory_foundation_sdk/etl_nodes.py index 07c0761..178e0ab 100644 --- a/inventory_foundation_sdk/etl_nodes.py +++ b/inventory_foundation_sdk/etl_nodes.py @@ -12,19 +12,17 @@ # %% ../nbs/31_ETL_nodes.ipynb 5 def input_output_node(*inputs): - """ This is a node for cases where the raw data can be directly passed through without processing steps. - + Accepts multiple inputs and returns them unpacked. If there's only one input, it returns the input itself. """ return inputs[0] if len(inputs) == 1 else inputs # %% ../nbs/31_ETL_nodes.ipynb 7 def convert_hirarchic_to_dict(categories: pd.DataFrame, single_leaf_level=True) -> dict: - """ - + This function converts a strictly hirarchic dataframe into a dictioary. Strictly hirarchic means that each column represents a hirarchy level, and each subcategory belongs to exactly one higher level category. In the dataframe, each subcategory belongs to exactly one higher level category. @@ -34,7 +32,7 @@ def convert_hirarchic_to_dict(categories: pd.DataFrame, single_leaf_level=True) Requirements: - IMPORTANT: This function is only for strictly hierarchical categories, i.e., each subcategory belongs to exactly one higher level category. - The categories must be in descending order (i.e., the first columns the highest level category, second column is the second highest level category, etc.) - - The column names can carry a name, if required (e.g., "category", "department", etc.). + - The column names can carry a name, if required (e.g., "category", "department", etc.). - The categories itself will be saved under generic levles ("1", "2", etc.), but the specific names will be returned in separate list for saving Inputs: @@ -42,7 +40,7 @@ def convert_hirarchic_to_dict(categories: pd.DataFrame, single_leaf_level=True) - single_leaf_level: A boolean that indicates if the categories dataframe has only one leaf level. If True, the function will return a dictionary with the leaf level as the last level. If False, leafs may be at different levels. Outputs: - - mappings: A dictionary with the levels as keys and a dictionary as values. + - mappings: A dictionary with the levels as keys and a dictionary as values. The dictionary has the category names as keys and list of parents. This means that the dictionary is more general than the dataframe and is the required input for the write_db_node. - category_level_names: A list of the column names of the categories dataframe. @@ -62,25 +60,31 @@ def convert_hirarchic_to_dict(categories: pd.DataFrame, single_leaf_level=True) level_cats = categories[category_level_names[i]].astype(str).unique() level_cats = {cat: None for cat in level_cats} else: - data = categories.iloc[:, i-1:i+1] + data = categories.iloc[:, i - 1 : i + 1] data = data.drop_duplicates() # Create a defaultdict with lists as the default value type level_cats = defaultdict(list) # Populate the dictionary using column index positions - for key, value in zip(data.iloc[:, 1], data.iloc[:, 0]): # 1 for the second column, 0 for the first column + for key, value in zip( + data.iloc[:, 1], data.iloc[:, 0] + ): # 1 for the second column, 0 for the first column if key not in level_cats: - level_cats[key] = [] # Initialize with an empty list and the leaf value - level_cats[key].append(value) # Append the value to the list of parents + level_cats[key] = ( + [] + ) # Initialize with an empty list and the leaf value + level_cats[key].append( + value + ) # Append the value to the list of parents # Convert to a regular dict if needed level_cats = dict(level_cats) - mappings[i+1] = level_cats + mappings[i + 1] = level_cats else: raise NotImplementedError("Currently only single leaf level is supported.") - + category_level_names = categories.columns.to_list() - + return mappings, category_level_names diff --git a/inventory_foundation_sdk/kedro_orchestration.py b/inventory_foundation_sdk/kedro_orchestration.py index fa23373..259ff5d 100644 --- a/inventory_foundation_sdk/kedro_orchestration.py +++ b/inventory_foundation_sdk/kedro_orchestration.py @@ -6,24 +6,24 @@ __all__ = ['verify_db_write_status'] # %% ../nbs/20_kedro_orchestration.ipynb 3 -#| export +# | export # %% ../nbs/20_kedro_orchestration.ipynb 5 def verify_db_write_status(*args: bool) -> bool: """ Consolidates the outputs of all specific functions that write to the database. - Each input represents whether a specific write operation was successful (True) - or not (False). The function returns True only if all inputs are True; + Each input represents whether a specific write operation was successful (True) + or not (False). The function returns True only if all inputs are True; otherwise, it returns False. This function can be used as a standalone node in a Kedro pipeline Args: - *args: A variable number of boolean arguments, each indicating the success + *args: A variable number of boolean arguments, each indicating the success of a specific database write operation. Returns: - bool: True if all operations were successful (all inputs are True), + bool: True if all operations were successful (all inputs are True), otherwise False. """ return all(args) diff --git a/nbs/30_ETL_db_writers.ipynb b/nbs/30_ETL_db_writers.ipynb index f6f2735..00bef05 100644 --- a/nbs/30_ETL_db_writers.ipynb +++ b/nbs/30_ETL_db_writers.ipynb @@ -168,7 +168,7 @@ " \n", " # Insert additional information for new entry\n", " if additional_info is not None:\n", - " for key, value in additional_info.to_dict().items():\n", + " for key, value in additional_info.items():\n", " # Add column if it doesn't exist\n", " cur.execute(\n", " f\"\"\"\n", @@ -1272,20 +1272,6 @@ "#| hide\n", "import nbdev; nbdev.nbdev_export()" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..8b1e050 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,6 @@ +kedro~=0.19.9 +psycopg2-binary~=2.9.1 +pandas>=1.5.0 +numpy>=1.21.0 +tqdm~=4.67.1 +black~=24.10.0 \ No newline at end of file diff --git a/settings.ini b/settings.ini index 16457a7..16388b9 100644 --- a/settings.ini +++ b/settings.ini @@ -8,7 +8,7 @@ lib_name = %(repo)s version = 0.0.3 min_python = 3.7 license = apache2 -black_formatting = False +black_formatting = True ### nbdev ### doc_path = _docs @@ -38,8 +38,8 @@ status = 3 user = d3group ### Optional ### -# requirements = fastcore pandas -# dev_requirements = +requirements = -r requirements.txt +dev_requirements = -r requirements.txt # console_scripts = # conda_user = # package_data =