Skip to content

Commit

Permalink
Fix track matching
Browse files Browse the repository at this point in the history
I had previously tested the `munkres` -> `lapjv` replacement
extensively, so I was today surprised to find that nothing gets matched
correctly when I tried importing some new tracks.

On the other hand I now remember making a small adjustment in the logic
to make autotagging tests pass which is when I introduced a bug: I did
not realize that `lapjv` returns index '-1' for each unmatched item.

This issue did not get caught by tests because this 'unmatched' item
index '-1' anecdotally ended up pointing to the last (expected) item in
the test making it pass.

This commit adjusts the aforementioned test to catch this issue and
fixes the logic to correctly identify unmatched tracks.
  • Loading branch information
snejus committed Dec 29, 2024
1 parent faf7529 commit 34b1212
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
8 changes: 6 additions & 2 deletions beets/autotag/match.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,11 +131,15 @@ def assign_items(
costs = [[float(track_distance(i, t)) for t in tracks] for i in items]
# Find a minimum-cost bipartite matching.
log.debug("Computing track assignment...")
cost, _, assigned_idxs = lap.lapjv(np.array(costs), extend_cost=True)
cost, _, assigned_item_idxs = lap.lapjv(np.array(costs), extend_cost=True)
log.debug("...done.")

# Produce the output matching.
mapping = {items[i]: tracks[t] for (t, i) in enumerate(assigned_idxs)}
mapping = {
items[iidx]: tracks[tidx]
for tidx, iidx in enumerate(assigned_item_idxs)
if iidx != -1
}
extra_items = list(set(items) - mapping.keys())
extra_items.sort(key=lambda i: (i.disc, i.track, i.title))
extra_tracks = list(set(tracks) - set(mapping.values()))
Expand Down
6 changes: 3 additions & 3 deletions test/test_autotag.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,7 +551,7 @@ def test_order_works_with_invalid_track_numbers(self):
def test_order_works_with_missing_tracks(self):
items = []
items.append(self.item("one", 1))
items.append(self.item("three", 3))
items.append(self.item("two", 2))
trackinfo = []
trackinfo.append(TrackInfo(title="one"))
trackinfo.append(TrackInfo(title="two"))
Expand All @@ -560,8 +560,8 @@ def test_order_works_with_missing_tracks(self):
items, trackinfo
)
assert extra_items == []
assert extra_tracks == [trackinfo[1]]
assert mapping == {items[0]: trackinfo[0], items[1]: trackinfo[2]}
assert extra_tracks == [trackinfo[2]]
assert mapping == {items[0]: trackinfo[0], items[1]: trackinfo[1]}

def test_order_works_with_extra_tracks(self):
items = []
Expand Down

0 comments on commit 34b1212

Please sign in to comment.