Skip to content

Commit

Permalink
Merge pull request #153 from april-tools/fix_notebook
Browse files Browse the repository at this point in the history
Fix notebook to align with previous PRs
  • Loading branch information
lkct authored Nov 27, 2023
2 parents 3a96e96 + 077383a commit 7e1aa99
Showing 1 changed file with 55 additions and 24 deletions.
79 changes: 55 additions & 24 deletions notebooks/digits.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
"metadata": {},
"outputs": [],
"source": [
"device = torch.device(\"cpu\") # The device to use, e.g., \"cpu\", \"cuda\", \"cuda:1\""
"device = torch.device(\"cuda\") # The device to use, e.g., \"cpu\", \"cuda\", \"cuda:1\""
]
},
{
Expand Down Expand Up @@ -231,30 +231,60 @@
"output_type": "stream",
"text": [
"TensorizedPC(\n",
" (input_layer): CategoricalLayer()\n",
" (input_layer): CategoricalLayer(\n",
" (params): ReparamEFCategorical()\n",
" )\n",
" (scope_layer): ScopeLayer()\n",
" (inner_layers): ModuleList(\n",
" (0-1): 2 x CPLayer()\n",
" (2): MixingLayer()\n",
" (3): CPLayer()\n",
" (4): MixingLayer()\n",
" (5): CPLayer()\n",
" (6): MixingLayer()\n",
" (7): CPLayer()\n",
" (8): MixingLayer()\n",
" (9): CPLayer()\n",
" (10): MixingLayer()\n",
" (11-12): 2 x CPLayer()\n",
" (13): MixingLayer()\n",
" (14-15): 2 x CPLayer()\n",
" (16): MixingLayer()\n",
" (0-1): 2 x CollapsedCPLayer(\n",
" (params_in): ReparamExp()\n",
" )\n",
" (2): SumLayer(\n",
" (params): ReparamExp()\n",
" )\n",
" (3): CollapsedCPLayer(\n",
" (params_in): ReparamExp()\n",
" )\n",
" (4): SumLayer(\n",
" (params): ReparamExp()\n",
" )\n",
" (5): CollapsedCPLayer(\n",
" (params_in): ReparamExp()\n",
" )\n",
" (6): SumLayer(\n",
" (params): ReparamExp()\n",
" )\n",
" (7): CollapsedCPLayer(\n",
" (params_in): ReparamExp()\n",
" )\n",
" (8): SumLayer(\n",
" (params): ReparamExp()\n",
" )\n",
" (9): CollapsedCPLayer(\n",
" (params_in): ReparamExp()\n",
" )\n",
" (10): SumLayer(\n",
" (params): ReparamExp()\n",
" )\n",
" (11-12): 2 x CollapsedCPLayer(\n",
" (params_in): ReparamExp()\n",
" )\n",
" (13): SumLayer(\n",
" (params): ReparamExp()\n",
" )\n",
" (14-15): 2 x CollapsedCPLayer(\n",
" (params_in): ReparamExp()\n",
" )\n",
" (16): SumLayer(\n",
" (params): ReparamExp()\n",
" )\n",
" )\n",
")\n"
]
}
],
"source": [
"from cirkit.utils.reparams import reparam_exp\n",
"from cirkit.reparams.leaf import ReparamExp\n",
"from cirkit.models.tensorized_circuit import TensorizedPC\n",
"pc = TensorizedPC.from_region_graph(\n",
" region_graph,\n",
Expand All @@ -265,7 +295,7 @@
" layer_cls=layer_cls,\n",
" layer_kwargs=layer_kwargs,\n",
" num_classes=1,\n",
" reparam=reparam_exp\n",
" reparam=ReparamExp\n",
")\n",
"pc.to(device)\n",
"print(pc)"
Expand Down Expand Up @@ -351,8 +381,9 @@
"for epoch_idx in range(num_epochs):\n",
" running_loss = 0.0\n",
" for batch, _ in train_dataloader:\n",
" log_score = pc.forward(batch)\n",
" log_pf = pc_pf() # Compute the partition function\n",
" batch = batch.to(device).unsqueeze(dim=-1) # Add a channel dimension\n",
" log_score = pc(batch)\n",
" log_pf = pc_pf(batch) # Compute the partition function\n",
" lls = log_score - log_pf # Compute the log-likelihood\n",
" loss = -torch.mean(lls) # The loss is the negative average log-likelihood\n",
" loss.backward()\n",
Expand Down Expand Up @@ -385,18 +416,18 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Average test LL: -833.413\n",
"Bits per dimension: 1.533623705697752\n"
"Average test LL: -833.414\n",
"Bits per dimension: 1.5336248522113194\n"
]
}
],
"source": [
"with torch.no_grad():\n",
" pc.eval()\n",
" log_pf = pc_pf() # Compute the partition function once for testing\n",
" log_pf = pc_pf(torch.empty((), device=device)) # Compute the partition function once for testing\n",
" test_lls = 0.0\n",
" for batch, _ in test_dataloader:\n",
" log_score = pc.forward(batch)\n",
" log_score = pc(batch.to(device).unsqueeze(dim=-1))\n",
" lls = log_score - log_pf\n",
" test_lls += lls.sum().item()\n",
" average_ll = test_lls / len(data_test)\n",
Expand Down

0 comments on commit 7e1aa99

Please sign in to comment.