
Commit fdccbf3
Merge pull request #83 from masa-su/bugfix
Fix the way to raise errors
masa-su authored Aug 6, 2019
2 parents b1fed22 + e2b8a80 commit fdccbf3
Showing 8 changed files with 38 additions and 38 deletions.
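
A note on the change itself: in Python, `raise ValueError` and `raise ValueError()` behave identically, because `raise` instantiates a bare exception class with no arguments. The instantiated form adopted throughout this commit is a consistency fix that also leaves every raise site one edit away from carrying a message. A minimal sketch:

    # Both forms raise an identical ValueError: `raise` instantiates
    # a bare exception class with no arguments automatically.
    try:
        raise ValueError        # class form (before this commit)
    except ValueError:
        pass

    try:
        raise ValueError()      # instance form (after this commit)
    except ValueError:
        pass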
30 changes: 15 additions & 15 deletions pixyz/distributions/distributions.py
@@ -256,7 +256,7 @@ def get_params(self, params_dict={}):
{'scale': tensor(1.), 'loc': tensor([0.])}
"""
- raise NotImplementedError
+ raise NotImplementedError()

def sample(self, x_dict={}, batch_n=None, sample_shape=torch.Size(), return_all=True,
reparam=False):
@@ -334,7 +334,7 @@ def sample(self, x_dict={}, batch_n=None, sample_shape=torch.Size(), return_all=
0.3686, 0.6311, -1.1208, 0.3656, -0.6683]])}
"""
- raise NotImplementedError
+ raise NotImplementedError()

def sample_mean(self, x_dict={}):
"""Return the mean of the distribution.
@@ -365,7 +365,7 @@ def sample_mean(self, x_dict={}):
1.2810, -0.6681]])
"""
- raise NotImplementedError
+ raise NotImplementedError()

def sample_variance(self, x_dict={}):
"""Return the variance of the distribution.
@@ -395,7 +395,7 @@ def sample_variance(self, x_dict={}):
tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])
"""
- raise NotImplementedError
+ raise NotImplementedError()

def get_log_prob(self, x_dict, sum_features=True, feature_dims=None):
"""Giving variables, this method returns values of log-pdf.
@@ -435,7 +435,7 @@ def get_log_prob(self, x_dict, sum_features=True, feature_dims=None):
tensor([-21.5251])
"""
- raise NotImplementedError
+ raise NotImplementedError()

def get_entropy(self, x_dict={}, sum_features=True, feature_dims=None):
"""Giving variables, this method returns values of entropy.
@@ -474,7 +474,7 @@ def get_entropy(self, x_dict={}, sum_features=True, feature_dims=None):
tensor([14.1894])
"""
- raise NotImplementedError
+ raise NotImplementedError()

def log_prob(self, sum_features=True, feature_dims=None):
"""Return an instance of :class:`pixyz.losses.LogProb`.
@@ -556,7 +556,7 @@ def prob(self, sum_features=True, feature_dims=None):
def forward(self, *args, **kwargs):
"""When this class is inherited by DNNs, this method should be overrided."""

- raise NotImplementedError
+ raise NotImplementedError()

def replace_var(self, **replace_dict):
"""Return an instance of :class:`pixyz.distributions.ReplaceVarDistribution`.
@@ -652,13 +652,13 @@ def _set_buffers(self, **params_dict):
if params_dict[key] in self._cond_var:
self.replace_params_dict[params_dict[key]] = key
else:
- raise ValueError
+ raise ValueError()
elif isinstance(params_dict[key], torch.Tensor):
features = params_dict[key]
features_checked = self._check_features_shape(features)
self.register_buffer(key, features_checked)
else:
- raise ValueError
+ raise ValueError()

def _check_features_shape(self, features):
# scalar
@@ -678,12 +678,12 @@ def _check_features_shape(self, features):
@property
def params_keys(self):
"""list: Return the list of parameter names for this distribution."""
- raise NotImplementedError
+ raise NotImplementedError()

@property
def distribution_torch_class(self):
"""Return the class of PyTorch distribution."""
- raise NotImplementedError
+ raise NotImplementedError()

@property
def dist(self):
@@ -712,7 +712,7 @@ def set_dist(self, x_dict={}, sampling=False, batch_n=None, **kwargs):
"""
params = self.get_params(x_dict, **kwargs)
if set(self.params_keys) != set(params.keys()):
- raise ValueError
+ raise ValueError()

self._dist = self.distribution_torch_class(**params)

@@ -724,7 +724,7 @@ def set_dist(self, x_dict={}, sampling=False, batch_n=None, **kwargs):
elif batch_shape[0] == batch_n:
return
else:
- raise ValueError
+ raise ValueError()

def get_sample(self, reparam=False, sample_shape=torch.Size()):
"""Get a sample_shape shaped sample from :attr:`dist`.
@@ -746,7 +746,7 @@ def get_sample(self, reparam=False, sample_shape=torch.Size()):
if reparam:
try:
_samples = self.dist.rsample(sample_shape=sample_shape)
- except NotImplementedError:
+ except NotImplementedError():
raise ValueError("You cannot use the re-parameterization trick for this distribution.")
else:
_samples = self.dist.sample(sample_shape=sample_shape)
@@ -1042,7 +1042,7 @@ def __init__(self, p, replace_dict):
all_vars = _cond_var + _var

if not (set(replace_dict.keys()) <= set(all_vars)):
- raise ValueError
+ raise ValueError()

_replace_inv_cond_var_dict = {replace_dict[var]: var for var in _cond_var if var in replace_dict.keys()}
_replace_inv_dict = {value: key for key, value in replace_dict.items()}
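One hunk above deserves a caution: in `get_sample`, the replacement also touched an `except` clause. Unlike `raise`, `except` must name an exception class (or a tuple of classes), so `except NotImplementedError():` hands Python an instance and fails with "TypeError: catching classes that do not inherit from BaseException is not allowed" the moment an exception reaches that handler. A minimal sketch of the class-form pattern, using a torch distribution that genuinely lacks `rsample`:

    import torch
    from torch.distributions import Bernoulli

    # Bernoulli defines no rsample(), so the base Distribution raises
    # NotImplementedError when the re-parameterization trick is requested.
    dist = Bernoulli(probs=torch.tensor([0.5]))
    try:
        _samples = dist.rsample(sample_shape=torch.Size())
    except NotImplementedError:  # the class itself, never NotImplementedError()
        print("You cannot use the re-parameterization trick for this distribution.")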
4 changes: 2 additions & 2 deletions pixyz/distributions/exponential_distributions.py
@@ -87,7 +87,7 @@ def set_dist(self, x_dict={}, sampling=True, batch_n=None, **kwargs):
elif batch_shape[0] == batch_n:
return
else:
- raise ValueError
+ raise ValueError()


class FactorizedBernoulli(Bernoulli):
@@ -169,7 +169,7 @@ def set_dist(self, x_dict={}, sampling=True, batch_n=None, **kwargs):
elif batch_shape[0] == batch_n:
return
else:
- raise ValueError
+ raise ValueError()

def sample_mean(self, x_dict={}):
self.set_dist(x_dict, sampling=False)
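Both `set_dist` hunks in this file guard the same invariant: a caller-supplied `batch_n` must agree with the batch dimension of the already-built torch distribution. A standalone restatement of that guard, with a hypothetical broadcasting branch for the singleton case the excerpt truncates (`torch.distributions.Distribution.expand` is one way to realize it; pixyz's actual branch is not shown here):

    import torch
    from torch.distributions import Normal

    dist = Normal(loc=torch.zeros(1, 10), scale=torch.ones(1, 10))
    batch_n = 4

    batch_shape = dist.batch_shape
    if batch_shape[0] == 1:
        # hypothetical branch: broadcast the singleton batch up to batch_n
        dist = dist.expand(torch.Size([batch_n]) + batch_shape[1:])
    elif batch_shape[0] == batch_n:
        pass                   # already the requested batch size
    else:
        raise ValueError()     # incompatible batch size, as in the diff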
4 changes: 2 additions & 2 deletions pixyz/distributions/mixture_distributions.py
@@ -77,7 +77,7 @@ def __init__(self, distributions, prior, name="p"):
"""
if not isinstance(distributions, list):
- raise ValueError
+ raise ValueError()
else:
distributions = nn.ModuleList(distributions)

@@ -236,7 +236,7 @@ def distribution_name(self):
return "Mixture Model (Posterior)"

def sample(self, *args, **kwargs):
- raise NotImplementedError
+ raise NotImplementedError()

def get_log_prob(self, x_dict, **kwargs):
# log p(z|x) = log p(x, z) - log p(x)
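The posterior's `get_log_prob` comment states the identity it relies on: log p(z|x) = log p(x, z) - log p(x), where log p(x) marginalizes the joint over the mixture components. A small numeric sketch of that marginalization (the joint values are made up):

    import torch

    # log p(x, z) for one sample across three components (made-up values)
    log_joint = torch.tensor([-2.3, -1.2, -4.0])

    # log p(x) = logsumexp over z, kept in log space for numerical stability
    log_marginal = torch.logsumexp(log_joint, dim=-1)

    # log p(z|x) = log p(x, z) - log p(x); exponentiates to a proper posterior
    log_posterior = log_joint - log_marginal
    assert torch.isclose(log_posterior.exp().sum(), torch.tensor(1.0))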
14 changes: 7 additions & 7 deletions pixyz/distributions/poe.py
@@ -81,17 +81,17 @@ def __init__(self, p=[], name="p", features_shape=torch.Size()):
"""
p = tolist(p)
if len(p) == 0:
- raise ValueError
+ raise ValueError()

var = p[0].var
cond_var = []

for _p in p:
if _p.var != var:
- raise ValueError
+ raise ValueError()

if _p.distribution_name != "Normal":
- raise ValueError
+ raise ValueError()

cond_var += _p.cond_var

@@ -225,13 +225,13 @@ def _check_input(self, x, var=None):
return checked_x

def log_prob(self, sum_features=True, feature_dims=None):
- raise NotImplementedError
+ raise NotImplementedError()

def prob(self, sum_features=True, feature_dims=None):
- raise NotImplementedError
+ raise NotImplementedError()

def get_log_prob(self, x_dict, sum_features=True, feature_dims=None):
- raise NotImplementedError
+ raise NotImplementedError()


class ElementWiseProductOfNormal(ProductOfNormal):
@@ -302,7 +302,7 @@ def __init__(self, p, name="p", features_shape=torch.Size()):
"""
if len(p.cond_var) != 1:
- raise ValueError
+ raise ValueError()

super().__init__(p=p, name=name, features_shape=features_shape)

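The constructor's checks (every expert must share `var` and be a `Normal`) reflect the closed form this class depends on: a product of Gaussian densities is again Gaussian, with precisions that add and a precision-weighted mean. A standalone sketch of that combination rule (illustrative, not pixyz's internal code):

    import torch

    def product_of_normals(locs, scales):
        """Combine independent Normal experts; dim 0 indexes experts."""
        precision = scales.pow(-2)                 # per-expert precision
        var = 1.0 / precision.sum(dim=0)           # combined variance
        loc = var * (precision * locs).sum(dim=0)  # precision-weighted mean
        return loc, var.sqrt()

    # Two unit-variance experts at 0 and 2 combine to loc 1.0, scale ~0.707.
    loc, scale = product_of_normals(torch.tensor([[0.0], [2.0]]),
                                    torch.tensor([[1.0], [1.0]]))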
6 changes: 3 additions & 3 deletions pixyz/losses/adversarial_loss.py
@@ -16,7 +16,7 @@ def __init__(self, p, q, discriminator, input_var=None,
elif len(q.input_var) > 0:
self.input_dist = q
else:
- raise NotImplementedError
+ raise NotImplementedError()

super().__init__(p, q, input_var=input_var)

@@ -47,7 +47,7 @@ def d_loss(self, y_p, y_q, batch_n):
torch.Tensor
"""
- raise NotImplementedError
+ raise NotImplementedError()

def g_loss(self, y_p, y_q, batch_n):
"""Evaluate a generator loss given outputs of the discriminator.
@@ -66,7 +66,7 @@ def g_loss(self, y_p, y_q, batch_n):
torch.Tensor
"""
- raise NotImplementedError
+ raise NotImplementedError()

def train(self, train_x_dict, **kwargs):
"""Train the evaluation metric (discriminator).
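`d_loss` and `g_loss` stay abstract here; concrete subclasses supply the objectives. For illustration only, the classic binary-cross-entropy GAN pair (assuming the discriminator outputs sigmoid probabilities; this is not necessarily what pixyz's subclasses implement):

    import torch
    import torch.nn.functional as F

    def d_loss(y_p, y_q, batch_n):
        # discriminator: score p-samples as real (1) and q-samples as fake (0)
        return (F.binary_cross_entropy(y_p, torch.ones_like(y_p))
                + F.binary_cross_entropy(y_q, torch.zeros_like(y_q)))

    def g_loss(y_p, y_q, batch_n):
        # generator: push the discriminator to score q-samples as real
        return F.binary_cross_entropy(y_q, torch.ones_like(y_q))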
10 changes: 5 additions & 5 deletions pixyz/losses/losses.py
@@ -87,7 +87,7 @@ def input_var(self):
@property
@abc.abstractmethod
def _symbol(self):
- raise NotImplementedError
+ raise NotImplementedError()

@property
def loss_text(self):
@@ -214,7 +214,7 @@ def eval(self, x_dict={}, return_dict=False, **kwargs):

@abc.abstractmethod
def _get_eval(self, x_dict, **kwargs):
- raise NotImplementedError
+ raise NotImplementedError()


class ValueLoss(Loss):
@@ -264,7 +264,7 @@ class Parameter(Loss):
"""
def __init__(self, input_var):
if not isinstance(input_var, str):
- raise ValueError
+ raise ValueError()
self._input_var = tolist(input_var)

def _get_eval(self, x_dict={}, **kwargs):
@@ -439,14 +439,14 @@ def __init__(self, loss1):
_input_var = []

if isinstance(loss1, type(None)):
- raise ValueError
+ raise ValueError()

if isinstance(loss1, Loss):
_input_var = deepcopy(loss1.input_var)
elif isinstance(loss1, numbers.Number):
loss1 = ValueLoss(loss1)
else:
- raise ValueError
+ raise ValueError()

self._input_var = _input_var
self.loss1 = loss1
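The `Parameter` hunk enforces a string variable name because evaluation is a plain dictionary lookup, and the unary-loss constructor below it shows the companion rule: bare numbers are wrapped in `ValueLoss` so that every operand in loss arithmetic is a `Loss`. A usage sketch consistent with the class as shown (assuming `Parameter` is importable from `pixyz.losses`):

    from pixyz.losses import Parameter

    beta = Parameter("beta")        # anything but a str raises ValueError()
    print(beta.eval({"beta": 4}))   # evaluation is a lookup in x_dict -> 4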
6 changes: 3 additions & 3 deletions pixyz/losses/mmd.py
@@ -41,14 +41,14 @@ def __init__(self, p, q, input_var=None, kernel="gaussian", **kernel_params):
elif len(q.input_var) > 0:
self.input_dist = q
else:
- raise NotImplementedError
+ raise NotImplementedError()

if kernel == "gaussian":
self.kernel = gaussian_rbf_kernel
elif kernel == "inv-multiquadratic":
self.kernel = inverse_multiquadratic_rbf_kernel
else:
- raise NotImplementedError
+ raise NotImplementedError()

self.kernel_params = kernel_params

@@ -97,7 +97,7 @@ def pairwise_distance_matrix(x, y, metric="euclidean"):
if metric == "euclidean":
return torch.sum((x[:, None, :] - y[None, :, :]) ** 2, dim=-1)

- raise NotImplementedError
+ raise NotImplementedError()


def gaussian_rbf_kernel(x, y, sigma_sqr=2., **kwargs):
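`pairwise_distance_matrix` computes its squared-Euclidean case with a broadcasting trick: viewing `x` as (n, 1, d) and `y` as (1, m, d) makes one subtraction produce every pair. A shape-check sketch, with a Gaussian RBF applied on top (the exact bandwidth convention of `gaussian_rbf_kernel` may differ from the exp(-d^2 / (2 * sigma_sqr)) form used here):

    import torch

    x = torch.randn(4, 3)   # n = 4 points in d = 3 dimensions
    y = torch.randn(5, 3)   # m = 5 points

    # (4, 1, 3) - (1, 5, 3) broadcasts to (4, 5, 3); summing the last dim
    # leaves a (4, 5) matrix of squared Euclidean distances.
    d2 = torch.sum((x[:, None, :] - y[None, :, :]) ** 2, dim=-1)
    assert d2.shape == (4, 5)

    # one common Gaussian RBF over those distances, with sigma_sqr = 2.
    k = torch.exp(-d2 / (2.0 * 2.0))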
2 changes: 1 addition & 1 deletion pixyz/losses/wasserstein.py
@@ -46,7 +46,7 @@ def __init__(self, p, q, metric=PairwiseDistance(p=2), input_var=None):
elif len(q.input_var) > 0:
self.input_dist = q
else:
- raise NotImplementedError
+ raise NotImplementedError()

self.metric = metric

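The default transport cost here is `torch.nn.PairwiseDistance(p=2)`, which measures the L2 distance between corresponding rows of two equally sized batches (one value per row, not an n-by-m matrix). A quick sketch:

    import torch
    from torch.nn import PairwiseDistance

    metric = PairwiseDistance(p=2)
    xs = torch.randn(8, 16)   # batch of samples from p
    ys = torch.randn(8, 16)   # batch of samples from q
    cost = metric(xs, ys)     # shape (8,): one distance per paired row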
