Faster symmetry breaking (#1502)
* Faster symmetry breaking

* Address comments
brahmaneya authored Oct 30, 2019
1 parent 36a5456 commit 056bf91
Showing 5 changed files with 69 additions and 78 deletions.
1 change: 1 addition & 0 deletions requirements.txt
@@ -19,6 +19,7 @@ tqdm>=4.33.0,<5.0.0
# Internal models
scikit-learn>=0.20.2,<0.22.0
torch>=1.1.0,<1.2.0
munkres==1.1.2

# LF dependency learning
networkx>=2.2,<2.4
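As a quick illustration of what the newly pinned package provides (this snippet is not part of the commit), munkres exposes a Munkres class that solves the minimum-cost assignment problem used later for symmetry breaking:

# Hypothetical sanity check, not part of this diff.
from munkres import Munkres

cost = [[4, 1], [2, 3]]           # cost[i][j]: cost of assigning row i to column j
pairs = Munkres().compute(cost)   # minimum-cost (row, column) assignment, in row order
assert pairs == [(0, 1), (1, 0)]  # cheapest assignment costs 1 + 2 = 3
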
8 changes: 5 additions & 3 deletions scripts/check_requirements.py
@@ -103,9 +103,11 @@ def parse_setup() -> Tuple[PackagesType, PackagesType, Set[str], Set[str]]:
def main() -> int:
exit_code = 0

requirements_essential, requirements_other, requirements_duplicate = (
parse_requirements()
)
(
requirements_essential,
requirements_other,
requirements_duplicate,
) = parse_requirements()
requirements_all = dict(requirements_essential, **requirements_other)
setup_essential, setup_test, essential_duplicates, test_duplicates = parse_setup()

1 change: 1 addition & 0 deletions setup.py
@@ -35,6 +35,7 @@
packages=find_packages(),
include_package_data=True,
install_requires=[
"munkres==1.1.2",
"numpy>=1.16.0,<2.0.0",
"scipy>=1.2.0,<2.0.0",
"pandas>=0.25.0,<0.26.0",
89 changes: 33 additions & 56 deletions snorkel/labeling/model/label_model.py
@@ -1,14 +1,15 @@
import logging
import pickle
import random
from collections import Counter
from itertools import chain, permutations
from typing import Any, Dict, List, NamedTuple, Optional, Set, Tuple, Union
from collections import Counter, defaultdict
from itertools import chain
from typing import Any, DefaultDict, Dict, List, NamedTuple, Optional, Set, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from munkres import Munkres # type: ignore

from snorkel.analysis import Scorer
from snorkel.labeling.analysis import LFAnalysis
@@ -755,33 +756,6 @@ def _clamp_params(self) -> None:
mu_eps = min(0.01, 1 / 10 ** np.ceil(np.log10(self.n)))
self.mu.data = self.mu.clamp(mu_eps, 1 - mu_eps) # type: ignore

def _count_accurate_lfs(self, mu: np.ndarray) -> int:
r"""Count the number of LFs that are estimated to be better than random.
Return the number of LFs that are estimated to be more accurate than not when not
abstaining, i.e., where
P(\lf = Y) > P(\lf != Y, \lf != -1).
Parameters
----------
mu
An [m * k, k] np.ndarray with entries in [0, 1]
Returns
-------
int
Number of LFs better than random
"""
P = self.P.cpu().detach().numpy()
cprobs = self._get_conditional_probs(mu)
count = 0
for i in range(self.m):
probs = cprobs[i, 1:] @ P
if 2 * np.diagonal(probs).sum() - probs.sum() > 0:
count += 1
return count

def _break_col_permutation_symmetry(self) -> None:
r"""Heuristically choose amongst (possibly) several valid mu values.
@@ -794,38 +768,41 @@ def _break_col_permutation_symmetry(self) -> None:
2. diag(O) = sum(mu @ P, axis=1)
Then any column permutation matrix Z that commutes with P will also equivalently
satisfy these objectives, and thus is an equally valid (symmetric) solution.
Therefore, we select the solution where the most LFs are estimated to be more
accurate than not when not abstaining, i.e., where for the majority of LFs,
P(\lf = Y) > P(\lf != Y, \lf != -1).
Therefore, we select the solution that maximizes the summed probability of the
LFs being accurate when not abstaining.
This is the standard assumption we have made in algorithmic and theoretical
work to date. Note however that this is not the only possible heuristic /
assumption that we could use, and in practice this may require further
iteration here.
\sum_lf \sum_{y=1}^{cardinality} P(\lf = y, Y = y)
"""
mu = self.mu.cpu().detach().numpy()
P = self.P.cpu().detach().numpy()
d, k = mu.shape

# Iterate through the possible permutation matrices and track heuristic scores
Zs = []
scores = []
for idxs in permutations(range(k)):
Z = np.eye(k)[:, idxs]
Zs.append(Z)

# If Z and P commute, get heuristic score, else skip
if np.allclose(Z @ P, P @ Z):
scores.append(self._count_accurate_lfs(mu @ Z))
else:
scores.append(-1)

# Set mu according to highest-scoring permutation
# We want to maximize the sum of diagonals of matrices for each LF. So
# we start by computing the sum of conditional probabilities here.
probs_sum = sum([mu[i : i + k] for i in range(0, self.m * k, k)]) @ P

munkres_solver = Munkres()
Z = np.zeros([k, k])

# Compute groups of indices with equal prior in P.
groups: DefaultDict[float, List[int]] = defaultdict(list)
for i, f in enumerate(P.diagonal()):
groups[np.around(f, 3)].append(i)
for group in groups.values():
if len(group) == 1:
Z[group[0], group[0]] = 1.0 # Identity permutation
continue
# Compute submatrix corresponding to the group.
probs_proj = probs_sum[[[g] for g in group], group]
# Use the Munkres algorithm to find the optimal permutation.
# We use minus because we want to maximize diagonal sum, not minimize,
# and transpose because we want to permute columns, not rows.
permutation_pairs = munkres_solver.compute(-probs_proj.T)
for i, j in permutation_pairs:
Z[group[i], group[j]] = 1.0

# Set mu according to permutation
self.mu = nn.Parameter(
torch.Tensor(mu @ Zs[np.argmax(scores)]).to( # type: ignore
self.config.device
)
torch.Tensor(mu @ Z).to(self.config.device) # type: ignore
)

def fit(
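To make the new matching step concrete, here is a small, self-contained sketch (not part of the commit; the 2x2 values are invented) of how the Munkres solver picks the column permutation Z that maximizes the trace of probs_sum @ Z:

# Illustrative sketch only; the probability values below are made up.
import numpy as np
from munkres import Munkres

# Hypothetical k x k matrix playing the role of probs_sum = (sum of mu blocks) @ P;
# entry [y1, y2] roughly sums P(lf = y1, Y = y2) over all LFs.
probs_sum = np.array([[0.15, 0.60],
                      [0.55, 0.20]])

# Munkres minimizes cost, so negate to maximize; transpose so the returned
# (i, j) pairs describe a column permutation rather than a row permutation.
pairs = Munkres().compute((-probs_sum.T).tolist())

k = probs_sum.shape[0]
Z = np.zeros((k, k))
for i, j in pairs:
    Z[i, j] = 1.0

print(Z)                        # [[0. 1.], [1. 0.]] -> swap the two columns
print(np.trace(probs_sum @ Z))  # ~1.15 (0.60 + 0.55), the largest achievable diagonal sum

In _break_col_permutation_symmetry this matching is run separately on each group of classes whose priors in P are (approximately) equal, so that the resulting Z (approximately) commutes with P.
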
48 changes: 29 additions & 19 deletions test/labeling/model/test_label_model.py
@@ -417,7 +417,7 @@ def test_set_mu_eps(self):
label_model.fit(L, mu_eps=mu_eps)
self.assertAlmostEqual(label_model.get_conditional_probs()[0, 1, 0], mu_eps)

def test_count_accurate_lfs(self):
def test_symmetry_breaking(self):
mu = np.array(
[
# LF 0
@@ -431,52 +431,62 @@ def test_count_accurate_lfs(self):
[0.25, 0.75],
]
)
mu = mu[:, [1, 0]]

# First test: Two "good" LFs
label_model = LabelModel(verbose=False)
label_model._set_class_balance(None, None)
label_model.m = 3
self.assertEqual(label_model._count_accurate_lfs(mu), 2)
label_model.mu = nn.Parameter(torch.from_numpy(mu))
label_model._break_col_permutation_symmetry()
self.assertEqual(label_model.mu.data[0, 0], 0.75)

# Second test: Now they should all be "good" due to class balance, since we're
# counting accuracy (not conditional probabilities)
# Test with non-uniform class balance
# It should not consider the "correct" permutation as it does not commute now
label_model = LabelModel(verbose=False)
label_model._set_class_balance([0.9, 0.1], None)
label_model.m = 3
self.assertEqual(label_model._count_accurate_lfs(mu), 3)
label_model.mu = nn.Parameter(torch.from_numpy(mu))
label_model._break_col_permutation_symmetry()
self.assertEqual(label_model.mu.data[0, 0], 0.25)

def test_symmetry_breaking(self):
def test_symmetry_breaking_multiclass(self):
mu = np.array(
[
# LF 0
[0.75, 0.25],
[0.25, 0.75],
[0.75, 0.15, 0.1],
[0.20, 0.75, 0.3],
[0.05, 0.10, 0.6],
# LF 1
[0.25, 0.75],
[0.15, 0.25],
[0.25, 0.55, 0.3],
[0.15, 0.45, 0.4],
[0.20, 0.00, 0.3],
# LF 2
[0.75, 0.25],
[0.25, 0.75],
[0.5, 0.15, 0.2],
[0.3, 0.65, 0.2],
[0.2, 0.20, 0.6],
]
)
mu = mu[:, [1, 0]]
mu = mu[:, [1, 2, 0]]

# First test: Two "good" LFs
label_model = LabelModel(verbose=False)
label_model = LabelModel(cardinality=3, verbose=False)
label_model._set_class_balance(None, None)
label_model.m = 3
label_model.mu = nn.Parameter(torch.from_numpy(mu))
label_model._break_col_permutation_symmetry()
self.assertEqual(label_model.mu.data[0, 0], 0.75)
self.assertEqual(label_model.mu.data[1, 1], 0.75)

# Test with non-uniform class balance
# It should not consider the "correct" permutation as does not commute now
label_model = LabelModel(verbose=False)
label_model._set_class_balance([0.9, 0.1], None)
# It should not consider the "correct" permutation as it does not commute
label_model = LabelModel(cardinality=3, verbose=False)
label_model._set_class_balance([0.7, 0.2, 0.1], None)
label_model.m = 3
label_model.mu = nn.Parameter(torch.from_numpy(mu))
label_model._break_col_permutation_symmetry()
self.assertEqual(label_model.mu.data[0, 0], 0.25)
self.assertEqual(label_model.mu.data[0, 0], 0.15)
self.assertEqual(label_model.mu.data[1, 1], 0.3)


@pytest.mark.complex
@@ -528,7 +538,7 @@ def test_label_model_sparse(self) -> None:

# Test predicted labels *only on non-abstained data points*
Y_pred = label_model.predict(L, tie_break_policy="abstain")
idx, = np.where(Y_pred != -1)
(idx,) = np.where(Y_pred != -1)
acc = np.where(Y_pred[idx] == Y[idx], 1, 0).sum() / len(idx)
self.assertGreaterEqual(acc, 0.65)

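A side note on the updated tests: the expressions mu[:, [1, 0]] and mu[:, [1, 2, 0]] scramble the columns of mu via NumPy fancy indexing, producing exactly the kind of symmetric solution that _break_col_permutation_symmetry is expected to undo. A tiny illustration, reusing the LF 0 block from the multiclass test above:

# Illustrative only: column permutation via NumPy fancy indexing.
import numpy as np

mu_lf0 = np.array([[0.75, 0.15, 0.10],
                   [0.20, 0.75, 0.30],
                   [0.05, 0.10, 0.60]])

scrambled = mu_lf0[:, [1, 2, 0]]  # old column 1 -> position 0, 2 -> 1, 0 -> 2
print(scrambled[0])               # -> [0.15, 0.1, 0.75]

# Applying the inverse permutation recovers the original column order.
assert np.allclose(scrambled[:, [2, 0, 1]], mu_lf0)
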
