diff --git a/smart_importer/entries.py b/smart_importer/entries.py index c336550..55298b9 100644 --- a/smart_importer/entries.py +++ b/smart_importer/entries.py @@ -9,14 +9,10 @@ def update_postings(transaction, accounts): if len(transaction.postings) != 1: return transaction - new_postings = [ - Posting(account, None, None, None, None, None) for account in accounts + new_postings = [transaction.postings[0]] + [ + Posting(account, None, None, None, None, None) + for account in accounts if account != transaction.postings[0].account ] - for posting in transaction.postings: - if posting.account in accounts: - new_postings[accounts.index(posting.account)] = posting - else: - new_postings.append(posting) return transaction._replace(postings=new_postings) diff --git a/tests/entries_test.py b/tests/entries_test.py new file mode 100644 index 0000000..cd023d7 --- /dev/null +++ b/tests/entries_test.py @@ -0,0 +1,73 @@ +# pylint: disable=missing-docstring +import textwrap + +from beancount.parser import parser + +from smart_importer.entries import update_postings + + +def test_update_postings_regular(): + + test_entries, _, _ = parser.parse_string( + textwrap.dedent(""" +2024-04-04 * "Supermarket ABC" "Groceries" + Assets:US:BofA:Checking -100.00 USD + """)) + + transaction = test_entries[0] + accounts = "Assets:US:BofA:Checking", "Expenses:Food" + + updated_transaction = update_postings(transaction, accounts) + + assert len(updated_transaction.postings) == 2 + + # Check if the first posting is unchanged + assert updated_transaction.postings[0] == transaction.postings[0] + + # Check if the remaining postings are created correctly + assert updated_transaction.postings[1].account == "Expenses:Food" + + +def test_update_postings_inverse(): + + test_entries, _, _ = parser.parse_string( + textwrap.dedent(""" +2024-04-04 * "Supermarket ABC" "Groceries" + Assets:US:BofA:Checking -100.00 USD + """)) + + transaction = test_entries[0] + accounts = "Expenses:Food", "Assets:US:BofA:Checking" + + updated_transaction = update_postings(transaction, accounts) + + assert len(updated_transaction.postings) == 2 + + # Check if the first posting is unchanged + assert updated_transaction.postings[0] == transaction.postings[0] + + # Check if the remaining postings are created correctly + assert updated_transaction.postings[1].account == "Expenses:Food" + + +def test_update_postings_multi(): + + test_entries, _, _ = parser.parse_string( + textwrap.dedent(""" +2024-04-04 * "Supermarket ABC" "Groceries" + Assets:US:BofA:Checking -100.00 USD + """)) + + transaction = test_entries[0] + accounts = "Assets:US:BofA:Checking", "Expenses:Food", "Expenses:Clothing" + + updated_transaction = update_postings(transaction, accounts) + + assert len(updated_transaction.postings) == 3 + + # Check if the first posting is unchanged + assert updated_transaction.postings[0] == transaction.postings[0] + + # Check if the remaining postings are created correctly + assert updated_transaction.postings[1].account == "Expenses:Food" + assert updated_transaction.postings[2].account == "Expenses:Clothing"