Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge new features of master into current feed forward for development #158

Open
wants to merge 22 commits into
base: graph-inn-feed-forward
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
d71863b
restructured spline classes, added 1d spline. Added conditional netwo…
RussellALA Feb 21, 2023
d215734
fixed constrain parameters after BinnedSplineCoupling refactor, fixed…
RussellALA Feb 28, 2023
865fb25
renamed RationalQuadraticSpline_1D to ElementwiseRationalQuadraticSpl…
RussellALA Mar 2, 2023
f96753c
fixed conflicts in tests/test_splines.py
RussellALA Mar 2, 2023
af21851
past push got reverted somehow?
RussellALA Mar 2, 2023
ebba9d2
Merge pull request #151 from RussellALA/master
RussellALA Mar 3, 2023
50dde7b
Generate Python docs from ebba9d270f261cf550dcab431980ac055afbd835
github-actions[bot] Mar 3, 2023
16d8e3c
feat: plot graphInn structure
zimea Mar 7, 2023
4cb7822
feat: plot graphInn structure
zimea Mar 7, 2023
892e1ae
Merge branch 'feat/plot_graphINN' of github.com:vislearn/FrEIA into f…
zimea Mar 7, 2023
52979ec
refactor: cleanup
zimea Mar 7, 2023
bf3af23
fix: add dependency setup.py
zimea Mar 7, 2023
9f305e0
fix: exception handling
zimea Mar 7, 2023
7f63827
Optional validate_args for standard normal
fdraxler Mar 8, 2023
e71cb7c
Generate Python docs from 7f63827f3a7c33ec6acde0fe79a6a6da5cf6c8fe
github-actions[bot] Mar 8, 2023
0f65160
fixed section size completion in FrEIA.modules.Split
RussellALA Mar 8, 2023
22f382f
fixed conflicts
RussellALA Mar 8, 2023
babc92d
fix: cehck graphviz backend installed
zimea Mar 8, 2023
b179f92
fix: adapt tets
zimea Mar 8, 2023
1ec30cd
Merge pull request #157 from vislearn/feat/plot-graphINN
zimea Mar 8, 2023
32d753b
Generate Python docs from 1ec30cdaa4c2eba8f7460c464ab23dd4cdf1d7e8
github-actions[bot] Mar 8, 2023
f4e5f93
remove leftover graph_inn.py
RussellALA Mar 8, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions FrEIA/distributions/normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,12 @@


class StandardNormalDistribution(Independent):
def __init__(self, *event_shape: int, device=None, dtype=None):
def __init__(self, *event_shape: int, device=None, dtype=None, validate_args=True):
loc = torch.tensor(0., device=device, dtype=dtype).repeat(event_shape)
scale = torch.tensor(1., device=device, dtype=dtype).repeat(event_shape)

super().__init__(Normal(loc, scale), len(event_shape))
super().__init__(
Normal(loc, scale, validate_args=validate_args),
len(event_shape),
validate_args=validate_args
)
3 changes: 3 additions & 0 deletions FrEIA/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
* GINCouplingBlock
* AffineCouplingOneSided
* ConditionalAffineTransform
* RationalQuadraticSpline

Reshaping:

Expand Down Expand Up @@ -43,6 +44,7 @@
* LearnedElementwiseScaling
* OrthogonalTransform
* HouseholderPerm
* ElementwiseRationalQuadraticSpline

Fixed (non-learned) transforms:

Expand Down Expand Up @@ -106,4 +108,5 @@
'GaussianMixtureModel',
'LinearSpline',
'RationalQuadraticSpline',
'ElementwiseRationalQuadraticSpline',
]
14 changes: 7 additions & 7 deletions FrEIA/modules/graph_topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,13 @@ def __init__(self,
else:
if isinstance(section_sizes, int):
assert section_sizes < l_dim, "'section_sizes' too large"
else:
assert isinstance(section_sizes, (list, tuple)), \
"'section_sizes' must be either int or list/tuple of int"
assert sum(section_sizes) <= l_dim, "'section_sizes' too large"
if sum(section_sizes) < l_dim:
warnings.warn("'section_sizes' too small, adding additional section")
section_sizes = list(section_sizes).append(l_dim - sum(section_sizes))
section_sizes = (section_sizes,)
assert isinstance(section_sizes, (list, tuple)), \
"'section_sizes' must be either int or list/tuple of int"
assert sum(section_sizes) <= l_dim, "'section_sizes' too large"
if sum(section_sizes) < l_dim:
warnings.warn("'section_sizes' too small, adding additional section")
section_sizes = list(section_sizes) + [l_dim - sum(section_sizes)]
self.split_size_or_sections = section_sizes

def forward(self, x, rev=False, jac=True):
Expand Down
87 changes: 50 additions & 37 deletions FrEIA/modules/splines/binned.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,63 @@
from itertools import chain

from FrEIA.modules.coupling_layers import _BaseCouplingBlock
from FrEIA.modules.base import InvertibleModule
from FrEIA import utils


class BinnedSpline(_BaseCouplingBlock):
def __init__(self, dims_in, dims_c=None, subnet_constructor: callable = None,
split_len: Union[float, int] = 0.5, **kwargs) -> None:
if dims_c is None:
dims_c = []

super().__init__(dims_in, dims_c, clamp=0.0, clamp_activation=lambda u: u, split_len=split_len)


self.spline_base = BinnedSplineBase(dims_in, dims_c, **kwargs)

num_params = sum(self.spline_base.parameter_counts.values())
self.subnet1 = subnet_constructor(self.split_len2 + self.condition_length, self.split_len1 * num_params)
self.subnet2 = subnet_constructor(self.split_len1 + self.condition_length, self.split_len2 * num_params)

def _spline1(self, x1: torch.Tensor, parameters: Dict[str, torch.Tensor], rev: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError

def _spline2(self, x2: torch.Tensor, parameters: Dict[str, torch.Tensor], rev: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError

def _coupling1(self, x1: torch.Tensor, u2: torch.Tensor, rev: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
"""
The full coupling consists of:
1. Querying the parameter tensor from the subnetwork
2. Splitting this tensor into the semantic parameters
3. Constraining the parameters
4. Performing the actual spline for each bin, given the parameters
"""
parameters = self.subnet1(u2)
parameters = self.spline_base.split_parameters(parameters, self.split_len1)
parameters = self.constrain_parameters(parameters)

return self.spline_base.binned_spline(x=x1, parameters=parameters, spline=self._spline1, rev=rev)

def _coupling2(self, x2: torch.Tensor, u1: torch.Tensor, rev: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
parameters = self.subnet2(u1)
parameters = self.spline_base.split_parameters(parameters, self.split_len2)
parameters = self.constrain_parameters(parameters)

return self.spline_base.binned_spline(x=x2, parameters=parameters, spline=self._spline2, rev=rev)

def constrain_parameters(self, parameters: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
return self.spline_base.constrain_parameters(parameters)

class BinnedSplineBase(InvertibleModule):
"""
Base Class for Splines
Implements input-binning, where bin knots are jointly predicted along with spline parameters
by a non-invertible coupling subnetwork
"""

def __init__(self, dims_in, dims_c=None, subnet_constructor: callable = None, split_len: Union[float, int] = 0.5,
bins: int = 10, parameter_counts: Dict[str, int] = None, min_bin_sizes: Tuple[float] = (0.1, 0.1),
default_domain: Tuple[float] = (-3.0, 3.0, -3.0, 3.0)) -> None:
def __init__(self, dims_in, dims_c=None, bins: int = 10, parameter_counts: Dict[str, int] = None,
min_bin_sizes: Tuple[float] = (0.1, 0.1), default_domain: Tuple[float] = (-3.0, 3.0, -3.0, 3.0)) -> None:
"""
Args:
bins: number of bins to use
Expand All @@ -35,7 +79,7 @@ def __init__(self, dims_in, dims_c=None, subnet_constructor: callable = None, sp
if parameter_counts is None:
parameter_counts = {}

super().__init__(dims_in, dims_c, clamp=0.0, clamp_activation=lambda u: u, split_len=split_len)
super().__init__(dims_in, dims_c)

assert bins >= 1, "need at least one bin"
assert all(s >= 0 for s in min_bin_sizes), "minimum bin size cannot be negative"
Expand All @@ -61,40 +105,9 @@ def __init__(self, dims_in, dims_c=None, subnet_constructor: callable = None, sp
# merge parameter counts with child classes
self.parameter_counts = {**default_parameter_counts, **parameter_counts}

num_params = sum(self.parameter_counts.values())
self.subnet1 = subnet_constructor(self.split_len2 + self.condition_length, self.split_len1 * num_params)
self.subnet2 = subnet_constructor(self.split_len1 + self.condition_length, self.split_len2 * num_params)

def _spline1(self, x1: torch.Tensor, parameters: Dict[str, torch.Tensor], rev: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError

def _spline2(self, x2: torch.Tensor, parameters: Dict[str, torch.Tensor], rev: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError

def _coupling1(self, x1: torch.Tensor, u2: torch.Tensor, rev: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
"""
The full coupling consists of:
1. Querying the parameter tensor from the subnetwork
2. Splitting this tensor into the semantic parameters
3. Constraining the parameters
4. Performing the actual spline for each bin, given the parameters
"""
parameters = self.subnet1(u2)
parameters = self.split_parameters(parameters, self.split_len1)
parameters = self.constrain_parameters(parameters)

return self.binned_spline(x=x1, parameters=parameters, spline=self._spline1, rev=rev)

def _coupling2(self, x2: torch.Tensor, u1: torch.Tensor, rev: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
parameters = self.subnet2(u1)
parameters = self.split_parameters(parameters, self.split_len2)
parameters = self.constrain_parameters(parameters)

return self.binned_spline(x=x2, parameters=parameters, spline=self._spline2, rev=rev)

def split_parameters(self, parameters: torch.Tensor, split_len: int) -> Dict[str, torch.Tensor]:
"""
Split network output into semantic parameters, as given by self.parameter_counts
Split parameter tensor into semantic parameters, as given by self.parameter_counts
"""
parameters = parameters.movedim(1, -1)
parameters = parameters.reshape(*parameters.shape[:-1], split_len, -1)
Expand Down
68 changes: 65 additions & 3 deletions FrEIA/modules/splines/rational_quadratic.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
from typing import Dict, List, Tuple
from typing import Dict, Tuple, Callable, Iterable, List

import torch
from torch import nn
import torch.nn.functional as F
import numpy as np

from .binned import BinnedSpline
from .binned import BinnedSplineBase, BinnedSpline


class RationalQuadraticSpline(BinnedSpline):
def __init__(self, *args, bins: int = 10, **kwargs):
# parameter constraints count
# 1. the derivative at the edge of each inner bin positive #bins - 1
super().__init__(*args, **kwargs, bins=bins, parameter_counts={"deltas": bins - 1})
super().__init__(*args, bins=bins, parameter_counts={"deltas": bins - 1}, **kwargs)

def constrain_parameters(self, parameters: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
parameters = super().constrain_parameters(parameters)
Expand Down Expand Up @@ -44,6 +45,65 @@ def _spline2(self, x: torch.Tensor, parameters: Dict[str, torch.Tensor], rev: bo
return rational_quadratic_spline(x, left, right, bottom, top, deltas_left, deltas_right, rev=rev)


class ElementwiseRationalQuadraticSpline(BinnedSplineBase):
def __init__(self, dims_in, dims_c=[], subnet_constructor: Callable = None,
bins: int = 10, **kwargs) -> None:
super().__init__(dims_in, dims_c, bins=bins, parameter_counts={"deltas": bins - 1}, **kwargs)

self.channels = dims_in[0][0]
self.condition_length = sum([dims_c[i][0] for i in range(len(dims_c))])
self.conditional = len(dims_c) > 0

num_params = sum(self.parameter_counts.values())


if self.conditional:
self.subnet = subnet_constructor(self.condition_length, self.channels * num_params)
else:
self.spline_parameters = nn.Parameter(torch.zeros(self.channels * num_params, *dims_in[0][1:]))


def constrain_parameters(self, parameters: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
parameters = super().constrain_parameters(parameters)
# we additionally want positive derivatives to preserve monotonicity
# the derivative must also match the tails at the spline boundaries
deltas = parameters["deltas"]

# shifted softplus such that network output 0 -> delta = scale
shift = np.log(np.e - 1)
deltas = F.softplus(deltas + shift)

# boundary condition: derivative is equal to affine scale at spline boundaries
scale = torch.sum(parameters["heights"], dim=-1, keepdim=True) / torch.sum(parameters["widths"], dim=-1, keepdim=True)
scale = scale.expand(*scale.shape[:-1], 2)

deltas = torch.cat((deltas, scale), dim=-1).roll(1, dims=-1)

parameters["deltas"] = deltas

return parameters

def output_dims(self, input_dims: List[Tuple[int]]) -> List[Tuple[int]]:
return input_dims

def forward(self, x_or_z: Iterable[torch.Tensor], c: Iterable[torch.Tensor] = None,
rev: bool = False, jac: bool = True) \
-> Tuple[Tuple[torch.Tensor], torch.Tensor]:
if self.conditional:
parameters = self.subnet(torch.cat(c, dim=1).float())
else:
parameters = self.spline_parameters.unsqueeze(0).repeat_interleave(x_or_z[0].shape[0], dim=0)
parameters = self.split_parameters(parameters, self.channels)
parameters = self.constrain_parameters(parameters)

y, jac = self.binned_spline(x=x_or_z[0], parameters=parameters, spline=self.spline, rev=rev)
return (y,), jac

def spline(self, x: torch.Tensor, parameters: Dict[str, torch.Tensor], rev: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
left, right, bottom, top = parameters["left"], parameters["right"], parameters["bottom"], parameters["top"]
deltas_left, deltas_right = parameters["deltas_left"], parameters["deltas_right"]
return rational_quadratic_spline(x, left, right, bottom, top, deltas_left, deltas_right, rev=rev)

def rational_quadratic_spline(x: torch.Tensor,
left: torch.Tensor,
right: torch.Tensor,
Expand Down Expand Up @@ -116,3 +176,5 @@ def rational_quadratic_spline(x: torch.Tensor,
log_jac = torch.log(numerator) - torch.log(denominator)

return out, log_jac


17 changes: 16 additions & 1 deletion docs/_build/html/FrEIA.framework.html
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,20 @@ <h1>FrEIA.framework package<a class="headerlink" href="#freia-framework-package"
<dd><p>Approximate log Jacobian determinant via finite differences.</p>
</dd></dl>

<dl class="py method">
<dt class="sig sig-object py" id="FrEIA.framework.GraphINN.plot">
<span class="sig-name descname"><span class="pre">plot</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">path</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">str</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">filename</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">str</span></span></em><span class="sig-paren">)</span> <span class="sig-return"><span class="sig-return-icon">&#x2192;</span> <span class="sig-return-typehint"><span class="pre">None</span></span></span><a class="reference internal" href="_modules/FrEIA/framework/graph_inn.html#GraphINN.plot"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#FrEIA.framework.GraphINN.plot" title="Permalink to this definition">#</a></dt>
<dd><p>Generates a plot of the GraphINN and stores it as pdf and dot file</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters<span class="colon">:</span></dt>
<dd class="field-odd"><ul class="simple">
<li><p><strong>path</strong> – Directory to store the plots in. Must exist previous to plotting</p></li>
<li><p><strong>filename</strong> – Name of the newly generated plots</p></li>
</ul>
</dd>
</dl>
</dd></dl>

</dd></dl>

<dl class="py class">
Expand Down Expand Up @@ -637,6 +651,7 @@ <h1>FrEIA.framework package<a class="headerlink" href="#freia-framework-package"
<li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#FrEIA.framework.GraphINN.get_module_by_name"><code class="docutils literal notranslate"><span class="pre">GraphINN.get_module_by_name()</span></code></a></li>
<li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#FrEIA.framework.GraphINN.get_node_by_name"><code class="docutils literal notranslate"><span class="pre">GraphINN.get_node_by_name()</span></code></a></li>
<li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#FrEIA.framework.GraphINN.log_jacobian_numerical"><code class="docutils literal notranslate"><span class="pre">GraphINN.log_jacobian_numerical()</span></code></a></li>
<li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#FrEIA.framework.GraphINN.plot"><code class="docutils literal notranslate"><span class="pre">GraphINN.plot()</span></code></a></li>
</ul>
</li>
<li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#FrEIA.framework.InputNode"><code class="docutils literal notranslate"><span class="pre">InputNode</span></code></a><ul class="nav section-nav flex-column">
Expand Down Expand Up @@ -724,7 +739,7 @@ <h1>FrEIA.framework package<a class="headerlink" href="#freia-framework-package"
<div class="footer-items__end">

<div class="footer-item"><p class="theme-version">
Built with the <a href="https://pydata-sphinx-theme.readthedocs.io/en/stable/index.html">PyData Sphinx Theme</a> 0.13.0.
Built with the <a href="https://pydata-sphinx-theme.readthedocs.io/en/stable/index.html">PyData Sphinx Theme</a> 0.13.1.
</p></div>

</div>
Expand Down
4 changes: 3 additions & 1 deletion docs/_build/html/FrEIA.modules.html
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,7 @@
<li><p>GINCouplingBlock</p></li>
<li><p>AffineCouplingOneSided</p></li>
<li><p>ConditionalAffineTransform</p></li>
<li><p>RationalQuadraticSpline</p></li>
</ul>
<p>Reshaping:</p>
<ul class="simple">
Expand Down Expand Up @@ -388,6 +389,7 @@
<li><p>LearnedElementwiseScaling</p></li>
<li><p>OrthogonalTransform</p></li>
<li><p>HouseholderPerm</p></li>
<li><p>ElementwiseRationalQuadraticSpline</p></li>
</ul>
<p>Fixed (non-learned) transforms:</p>
<ul class="simple">
Expand Down Expand Up @@ -1727,7 +1729,7 @@ <h2>Approximately- or semi-invertible transforms<a class="headerlink" href="#app
<div class="footer-items__end">

<div class="footer-item"><p class="theme-version">
Built with the <a href="https://pydata-sphinx-theme.readthedocs.io/en/stable/index.html">PyData Sphinx Theme</a> 0.13.0.
Built with the <a href="https://pydata-sphinx-theme.readthedocs.io/en/stable/index.html">PyData Sphinx Theme</a> 0.13.1.
</p></div>

</div>
Expand Down
2 changes: 1 addition & 1 deletion docs/_build/html/_modules/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ <h1>All modules for which code is available</h1>
<div class="footer-items__end">

<div class="footer-item"><p class="theme-version">
Built with the <a href="https://pydata-sphinx-theme.readthedocs.io/en/stable/index.html">PyData Sphinx Theme</a> 0.13.0.
Built with the <a href="https://pydata-sphinx-theme.readthedocs.io/en/stable/index.html">PyData Sphinx Theme</a> 0.13.1.
</p></div>

</div>
Expand Down
8 changes: 5 additions & 3 deletions docs/_build/html/genindex.html
Original file line number Diff line number Diff line change
Expand Up @@ -633,11 +633,13 @@ <h2 id="P">P</h2>
<td style="width: 33%; vertical-align: top;"><ul>
<li><a href="FrEIA.framework.html#FrEIA.framework.Node.parse_inputs">parse_inputs() (FrEIA.framework.Node method)</a>
</li>
</ul></td>
<td style="width: 33%; vertical-align: top;"><ul>
<li><a href="FrEIA.modules.html#FrEIA.modules.PermuteRandom">PermuteRandom (class in FrEIA.modules)</a>
</li>
</ul></td>
<td style="width: 33%; vertical-align: top;"><ul>
<li><a href="FrEIA.modules.html#FrEIA.modules.GaussianMixtureModel.pick_mixture_component">pick_mixture_component() (FrEIA.modules.GaussianMixtureModel static method)</a>
</li>
<li><a href="FrEIA.framework.html#FrEIA.framework.GraphINN.plot">plot() (FrEIA.framework.GraphINN method)</a>
</li>
</ul></td>
</tr></table>
Expand Down Expand Up @@ -729,7 +731,7 @@ <h2 id="S">S</h2>
<div class="footer-items__end">

<div class="footer-item"><p class="theme-version">
Built with the <a href="https://pydata-sphinx-theme.readthedocs.io/en/stable/index.html">PyData Sphinx Theme</a> 0.13.0.
Built with the <a href="https://pydata-sphinx-theme.readthedocs.io/en/stable/index.html">PyData Sphinx Theme</a> 0.13.1.
</p></div>

</div>
Expand Down
Loading