From 077383a2990eb3afb2d0206a408af75227c2507a Mon Sep 17 00:00:00 2001 From: lkct Date: Mon, 27 Nov 2023 17:11:48 +0000 Subject: [PATCH] fix notebook to align with previous PRs --- notebooks/digits.ipynb | 79 +++++++++++++++++++++++++++++------------- 1 file changed, 55 insertions(+), 24 deletions(-) diff --git a/notebooks/digits.ipynb b/notebooks/digits.ipynb index 32af6bfa..9a6c9208 100644 --- a/notebooks/digits.ipynb +++ b/notebooks/digits.ipynb @@ -28,7 +28,7 @@ "metadata": {}, "outputs": [], "source": [ - "device = torch.device(\"cpu\") # The device to use, e.g., \"cpu\", \"cuda\", \"cuda:1\"" + "device = torch.device(\"cuda\") # The device to use, e.g., \"cpu\", \"cuda\", \"cuda:1\"" ] }, { @@ -231,30 +231,60 @@ "output_type": "stream", "text": [ "TensorizedPC(\n", - " (input_layer): CategoricalLayer()\n", + " (input_layer): CategoricalLayer(\n", + " (params): ReparamEFCategorical()\n", + " )\n", " (scope_layer): ScopeLayer()\n", " (inner_layers): ModuleList(\n", - " (0-1): 2 x CPLayer()\n", - " (2): MixingLayer()\n", - " (3): CPLayer()\n", - " (4): MixingLayer()\n", - " (5): CPLayer()\n", - " (6): MixingLayer()\n", - " (7): CPLayer()\n", - " (8): MixingLayer()\n", - " (9): CPLayer()\n", - " (10): MixingLayer()\n", - " (11-12): 2 x CPLayer()\n", - " (13): MixingLayer()\n", - " (14-15): 2 x CPLayer()\n", - " (16): MixingLayer()\n", + " (0-1): 2 x CollapsedCPLayer(\n", + " (params_in): ReparamExp()\n", + " )\n", + " (2): SumLayer(\n", + " (params): ReparamExp()\n", + " )\n", + " (3): CollapsedCPLayer(\n", + " (params_in): ReparamExp()\n", + " )\n", + " (4): SumLayer(\n", + " (params): ReparamExp()\n", + " )\n", + " (5): CollapsedCPLayer(\n", + " (params_in): ReparamExp()\n", + " )\n", + " (6): SumLayer(\n", + " (params): ReparamExp()\n", + " )\n", + " (7): CollapsedCPLayer(\n", + " (params_in): ReparamExp()\n", + " )\n", + " (8): SumLayer(\n", + " (params): ReparamExp()\n", + " )\n", + " (9): CollapsedCPLayer(\n", + " (params_in): ReparamExp()\n", + " )\n", + " 
(10): SumLayer(\n", + " (params): ReparamExp()\n", + " )\n", + " (11-12): 2 x CollapsedCPLayer(\n", + " (params_in): ReparamExp()\n", + " )\n", + " (13): SumLayer(\n", + " (params): ReparamExp()\n", + " )\n", + " (14-15): 2 x CollapsedCPLayer(\n", + " (params_in): ReparamExp()\n", + " )\n", + " (16): SumLayer(\n", + " (params): ReparamExp()\n", + " )\n", " )\n", ")\n" ] } ], "source": [ - "from cirkit.utils.reparams import reparam_exp\n", + "from cirkit.reparams.leaf import ReparamExp\n", "from cirkit.models.tensorized_circuit import TensorizedPC\n", "pc = TensorizedPC.from_region_graph(\n", " region_graph,\n", @@ -265,7 +295,7 @@ " layer_cls=layer_cls,\n", " layer_kwargs=layer_kwargs,\n", " num_classes=1,\n", - " reparam=reparam_exp\n", + " reparam=ReparamExp\n", ")\n", "pc.to(device)\n", "print(pc)" ] }, { @@ -351,8 +381,9 @@ "for epoch_idx in range(num_epochs):\n", " running_loss = 0.0\n", " for batch, _ in train_dataloader:\n", - " log_score = pc.forward(batch)\n", - " log_pf = pc_pf() # Compute the partition function\n", + " batch = batch.to(device).unsqueeze(dim=-1) # Add a channel dimension\n", + " log_score = pc(batch)\n", + " log_pf = pc_pf(batch) # Compute the partition function\n", " lls = log_score - log_pf # Compute the log-likelihood\n", " loss = -torch.mean(lls) # The loss is the negative average log-likelihood\n", " loss.backward()\n", @@ -385,18 +416,18 @@ "name": "stdout", "output_type": "stream", "text": [ - "Average test LL: -833.413\n", - "Bits per dimension: 1.533623705697752\n" + "Average test LL: -833.414\n", + "Bits per dimension: 1.5336248522113194\n" ] } ], "source": [ "with torch.no_grad():\n", " pc.eval()\n", - " log_pf = pc_pf() # Compute the partition function once for testing\n", + " log_pf = pc_pf(torch.empty((), device=device)) # Compute the partition function once for testing\n", " test_lls = 0.0\n", " for batch, _ in test_dataloader:\n", - " log_score = pc.forward(batch)\n", + " log_score = pc(batch.to(device).unsqueeze(dim=-1))\n", " 
lls = log_score - log_pf\n", " test_lls += lls.sum().item()\n", " average_ll = test_lls / len(data_test)\n",