diff --git a/smart_importer/predictor.py b/smart_importer/predictor.py index 08b171f..ffe9360 100644 --- a/smart_importer/predictor.py +++ b/smart_importer/predictor.py @@ -53,10 +53,12 @@ def __init__( predict=True, overwrite=False, string_tokenizer: Callable[[str], list] | None = None, + denylist_accounts: list[str] = [] ): super().__init__() self.training_data = None self.open_accounts: dict[str, str] = {} + self.denylist_accounts = denylist_accounts self.pipeline: Pipeline | None = None self.is_fitted = False self.lock = threading.Lock() @@ -132,6 +134,8 @@ def training_data_filter(self, txn): for pos in txn.postings: if pos.account not in self.open_accounts: return False + if pos.account in self.denylist_accounts: + return False if self.account == pos.account: found_import_account = True return found_import_account or not self.account diff --git a/tests/predictors_test.py b/tests/predictors_test.py index c825017..d3280c4 100644 --- a/tests/predictors_test.py +++ b/tests/predictors_test.py @@ -31,6 +31,9 @@ 2017-01-13 * "Gas Quick" Assets:US:BofA:Checking -17.45 USD + +2017-01-14 * "Axe Throwing with Joe" + Assets:US:BofA:Checking -13.37 USD """ ) @@ -42,6 +45,7 @@ 2016-01-01 open Expenses:Auto:Gas USD 2016-01-01 open Expenses:Food:Groceries USD 2016-01-01 open Expenses:Food:Restaurant USD +2016-01-01 open Expenses:Denylisted USD 2016-01-06 * "Farmer Fresh" "Buying groceries" Assets:US:BofA:Checking -2.50 USD @@ -92,6 +96,11 @@ 2016-01-12 * "Gas Quick" Assets:US:BofA:Checking -24.09 USD Expenses:Auto:Gas + +2016-01-08 * "Axe Throwing with Joe" + Assets:US:BofA:Checking -38.36 USD + Expenses:Denylisted + """ ) @@ -104,6 +113,7 @@ "Gimme Coffee", "Uncle Boons", None, + None, ] ACCOUNT_PREDICTIONS = [ @@ -115,8 +125,10 @@ "Expenses:Food:Coffee", "Expenses:Food:Groceries", "Expenses:Auto:Gas", + "Expenses:Food:Groceries", ] +DENYLISTED_ACCOUNTS = ["Expenses:Denylisted"] class BasicTestImporter(ImporterProtocol): def extract(self, file, existing_entries=None): @@ -132,7 +144,7 @@ def file_account(self, file): PAYEE_IMPORTER = apply_hooks(BasicTestImporter(), [PredictPayees()]) -POSTING_IMPORTER = apply_hooks(BasicTestImporter(), [PredictPostings()]) +POSTING_IMPORTER = apply_hooks(BasicTestImporter(), [PredictPostings(denylist_accounts=denylist_accounts)]) def test_empty_training_data():