Skip to content

Commit

Permalink
preparing exp-sos exps
Browse files Browse the repository at this point in the history
  • Loading branch information
loreloc committed Jul 29, 2024
1 parent 840db59 commit fb0fea0
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 9 deletions.
39 changes: 39 additions & 0 deletions econfigs/exp-sum-of-squares.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
{
"common": {
"tboard-path": "tboard-runs/uci-data-exp-sos",
"checkpoint-path": "checkpoints/uci-data-exp-sos",
"save-checkpoint": false,
"num-epochs": 1000,
"device": "cuda",
"num-workers": 2,
"patience-threshold": 1e-3,
"early-stop-patience": 30
},
"datasets": ["miniboone", "hepmass"],
"grid": {
"common": {
"region-graph": "rnd-bt",
"region-graph-sd": true,
"batch-size": 512,
"optimizer": "Adam",
"verbose": true,
"learning-rate": [1e-4, 5e-4, 1e-3, 5e-3]
},
"models": {
"hepmass|miniboone": {
"ExpSOS": {
"real": {
"num-units": 192,
"num-input-units": 128,
"complex": false
},
"complex": {
"num-units": 128,
"num-input-units": 128,
"complex": true
}
}
}
}
}
}
8 changes: 6 additions & 2 deletions src/models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import itertools
from abc import ABC, abstractmethod
from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, cast

Expand Down Expand Up @@ -305,10 +306,13 @@ def __init__(
)

def layers(self) -> Iterator[TorchLayer]:
return iter(self._circuit.layers)
return itertools.chain(self._circuit.layers, self._mono_circuit.layers)

def sum_layers(self) -> Iterator[TorchSumLayer]:
return filter(lambda l: isinstance(l, TorchSumLayer), self._circuit.layers)
return itertools.chain(
filter(lambda l: isinstance(l, TorchSumLayer), self._circuit.layers),
filter(lambda l: isinstance(l, TorchSumLayer), self._mono_circuit.layers)
)

def log_partition(self) -> Tensor:
return self._int_circuit().real
Expand Down
29 changes: 23 additions & 6 deletions src/scripts/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
kde_samples_hmap,
plot_bivariate_discrete_samples_hmap,
)
from models import MPC, PC, SOS
from models import MPC, PC, SOS, ExpSOS
from scripts.logger import Logger
from utilities import (
FLOW_MODELS,
Expand Down Expand Up @@ -452,8 +452,8 @@ def setup_model(
) -> Union[PC, Flow]:
logger.info(f"Building model '{model_name}' ...")

if complex and model_name != "SOS":
raise ValueError("--complex can only be used with SOS circuits")
if complex and model_name not in ["SOS", "ExpSOS"]:
raise ValueError("--complex can only be used with (Exp)SOS circuits")
assert model_name in MODELS
if splines:
raise NotImplementedError()
Expand Down Expand Up @@ -486,7 +486,8 @@ def setup_model(
seed=seed,
)
return model
elif model_name == "SOS":

if model_name == "SOS":
model = SOS(
num_variables,
num_input_units=num_input_units,
Expand All @@ -500,8 +501,24 @@ def setup_model(
seed=seed,
)
return model
else:
raise ValueError(f"Unknown model called {model_name}")

if model_name == "ExpSOS":
model = ExpSOS(
num_variables,
num_input_units=num_input_units,
num_sum_units=num_units,
mono_num_input_units=2,
mono_num_sum_units=2,
input_layer=input_layer,
input_layer_kwargs=input_layer_kwargs,
region_graph=region_graph,
structured_decomposable=structured_decomposable,
complex=complex,
seed=seed,
)
return model

raise ValueError(f"Unknown model called {model_name}")


def setup_flow_model(
Expand Down
2 changes: 1 addition & 1 deletion src/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import torch

PCS_MODELS = ["MPC", "SOS"]
PCS_MODELS = ["MPC", "SOS", "ExpSOS"]

FLOW_MODELS = ["NICE", "MAF", "NSF"]

Expand Down

0 comments on commit fb0fea0

Please sign in to comment.