diff --git a/notebooks/.ipynb_checkpoints/mminici-renorm_trick_test-0.1-checkpoint.ipynb b/notebooks/.ipynb_checkpoints/mminici-renorm_trick_test-0.1-checkpoint.ipynb new file mode 100644 index 0000000..138ed4f --- /dev/null +++ b/notebooks/.ipynb_checkpoints/mminici-renorm_trick_test-0.1-checkpoint.ipynb @@ -0,0 +1,618 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This notebook is largely inspired by this [other notebook](https://colab.research.google.com/drive/18EyozusBSgxa5oUBmlzXrp9fEbPyOUoC#scrollTo=_37NFK0iqCFx) originally published by IAML (Italian Association for Machine Learning)." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "# Add the codebase to the path\n", + "import sys\n", + "sys.path.insert(0,'../')" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# Some useful imports\n", + "import numpy as np\n", + "\n", + "import torch\n", + "from torch import nn, optim\n", + "from torch.utils import data" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load the Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading cora dataset...\n" + ] + } + ], + "source": [ + "from pygcn.utils import load_data\n", + "symmetric_norm_flag = False\n", + "\n", + "adj, features, labels, idx_train, idx_val, idx_test = load_data(path='../../pygcn/data/cora/', symm_norm=symmetric_norm_flag)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([2708, 1433])" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "features.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([2708])\n", + "tensor(6)\n" + ] + } + ], + "source": [ + "print(labels.shape)\n", + "print(torch.max(labels))" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([2708, 2708])" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "adj.shape" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Baseline model (without any info on the graph)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "# Feedforward Neural Network with a single hidden layer\n", + "net = nn.Sequential(nn.Linear(1433, 100), nn.ReLU(), nn.Linear(100, 7))" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "from pygcn.utils import accuracy\n", + "def test(model):\n", + " # Test the model on the test set\n", + " y_pred = model(features[idx_test])\n", + " acc_test = accuracy(y_pred, labels[idx_test])\n", + " print(\"Accuracy:\",\n", + " \"accuracy= {:.4f}\".format(acc_test.item()))" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Accuracy: accuracy= 0.1460\n" + ] + } + ], + "source": [ + "# Accuracy without training\n", + "test(net)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "criterion = nn.CrossEntropyLoss()\n", + "optimizer = optim.Adam(net.parameters())" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 1000/1000 [00:01<00:00, 557.19it/s]\n" + ] + } + ], + "source": [ + "import tqdm\n", + "loss_history = np.zeros(1000)\n", + "\n", + "for epoch in tqdm.trange(1000):\n", + "\n", + " optimizer.zero_grad()\n", + " outputs = net(features[idx_train])\n", + " loss = criterion(outputs, labels[idx_train])\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " loss_history[epoch] = loss.detach().numpy()" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[]" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "plt.plot(loss_history)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Accuracy: accuracy= 0.4890\n" + ] + } + ], + "source": [ + "test(net)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Graph Convolutional Network" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "# Taken (and simplified) from:\n", + "# https://github.com/tkipf/pygcn/blob/master/pygcn/layers.py\n", + "\n", + "import math\n", + "\n", + "import torch\n", + "\n", + "from torch.nn.parameter import Parameter\n", + "from torch.nn.modules.module import Module\n", + "\n", + "\n", + "class GraphConvolution(Module):\n", + " \"\"\"\n", + " Simple GCN layer, similar to https://arxiv.org/abs/1609.02907\n", + " \"\"\"\n", + "\n", + " def __init__(self, in_features, out_features):\n", + " super(GraphConvolution, self).__init__()\n", + " self.weight = Parameter(torch.FloatTensor(in_features, out_features))\n", + " self.bias = Parameter(torch.FloatTensor(out_features))\n", + " self.reset_parameters()\n", + "\n", + " def reset_parameters(self):\n", + " stdv = 1. / math.sqrt(self.weight.size(1))\n", + " self.weight.data.uniform_(-stdv, stdv)\n", + " self.bias.data.uniform_(-stdv, stdv)\n", + "\n", + " def forward(self, input, adj):\n", + " support = torch.mm(input, self.weight) \n", + " output = torch.spmm(adj, support) \n", + " return output + self.bias" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "# Taken (and simplified) from:\n", + "# https://github.com/tkipf/pygcn/blob/master/pygcn/models.py\n", + "\n", + "import torch.nn.functional as F\n", + "\n", + "class GCN(nn.Module):\n", + " def __init__(self, nfeat, nhid, nclass):\n", + " super(GCN, self).__init__()\n", + " self.gc1 = GraphConvolution(nfeat, nhid)\n", + " self.gc2 = GraphConvolution(nhid, nclass)\n", + "\n", + " def forward(self, x, adj):\n", + " x = F.relu(self.gc1(x, adj))\n", + " x = self.gc2(x, adj)\n", + " return F.log_softmax(x, dim=1)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "gcn = GCN(1433, 50, 7)\n", + "optimizer_gcn = optim.Adam(gcn.parameters())" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "def test(model):\n", + " y_pred = model(features, adj) # Using the whole dataset\n", + " acc_test = accuracy(y_pred[idx_test], labels[idx_test]) # Masking on the test set\n", + " print(\"Accuracy:\",\n", + " \"accuracy= {:.4f}\".format(acc_test.item()))" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Accuracy: accuracy= 0.1390\n" + ] + } + ], + "source": [ + "# Testing without training\n", + "test(gcn)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 2500/2500 [00:26<00:00, 96.07it/s]\n" + ] + } + ], + "source": [ + "import tqdm\n", + "loss_history = np.zeros(2500) \n", + "\n", + "for epoch in tqdm.trange(2500): \n", + " \n", + " optimizer_gcn.zero_grad()\n", + " outputs = gcn(features, adj) # Usiamo tutto il dataset\n", + " loss = criterion(outputs[idx_train], labels[idx_train]) # Mascheriamo sulla parte di training\n", + " loss.backward()\n", + " optimizer_gcn.step()\n", + "\n", + " loss_history[epoch] = loss.detach().numpy()" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAD4CAYAAADiry33AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3deXSU933v8fd3RhsgIdDGvojF2IJ6wSrCtRtwmmCc2wTnxLe109Ru6hzqxDlJe9ue697eU+c6vfe0yelyk7hJqcuxk+bacbPiNgnGjmPHCw7CAcxiQGDZ7BL7IrTO9/4xj/AgS2gEIz0zz3xe58yZ5/n9nhl9fxrxmeH3PPM85u6IiEh0xcIuQEREhpeCXkQk4hT0IiIRp6AXEYk4Bb2ISMQVhF1Af6qqqnzmzJlhlyEikjM2btx41N2r++vLyqCfOXMmjY2NYZchIpIzzOztgfo0dSMiEnGDBr2ZTTOz581su5ltM7PP97ONmdlXzKzJzLaY2cKUvnvNbHdwuzfTAxARkUtLZ+qmG/hTd3/dzMqAjWa2zt23p2xzOzA3uDUAXwcazKwCeAioBzx47Bp3P5HRUYiIyIAG/UTv7ofc/fVg+QywA5jSZ7MVwDc9aT0wzswmAbcB69z9eBDu64DlGR2BiIhc0pDm6M1sJnAD8FqfrinAvpT1/UHbQO39PfdKM2s0s8bW1tahlCUiIpeQdtCbWSnwPeCP3f10pgtx91XuXu/u9dXV/R4hJCIilyGtoDezQpIh/213/34/mxwApqWsTw3aBmoXEZERks5RNwb8K7DD3f9+gM3WAPcER98sBk65+yFgLbDMzMab2XhgWdCWce1dPax6cQ8vNx0djqcXEclZ6Rx1czPw+8AbZrYpaPsfwHQAd/8G8GPgQ0AT0AZ8Mug7bmZfBDYEj3vY3Y9nrvx3FcVj/PMLe7l5ThU3z6kajh8hIpKTBg16d38JsEG2ceCBAfpWA6svq7ohiMWMW6+u4Zlth+nuSVAQ13fBREQgYt+M/cA1NZxu72bj2zpMX0SkV6SC/pa51RTGjZ+92RJ2KSIiWSNSQV9aXMDiWZU8u+NI2KWIiGSNSAU9wPuvrmFP6zmaj54LuxQRkawQyaAHeE7TNyIiQASDfkblGObWlPLsdk3fiIhABIMe4Lb5E3ntrWMcP9cZdikiIqGLZNAvXzCRhKNP9SIiRDTo508ey9Txo/jJ1kNhlyIiErpIBr2ZsXz+RF5uOsbp9q6wyxERCVUkgx6S0zedPQme19E3IpLnIhv0C6ePp7qsmLXbDoddiohIqCIb9LGYcdv8CTz/ZivtXT1hlyMiEprIBj3A8vmTON/Vwwu7dGlCEclfkQ76hlkVlI8q5KdbNX0jIvkr0kFfGI/xgWsm8NyOI3T1JMIuR0QkFOlcSnC1mbWY2dYB+v/czDYFt61m1mNmFUFfs5m9EfQ1Zrr4dNw2fwKn27tZv/dYGD9eRCR06XyifwxYPlCnu3/Z3a939+uBvwBe6HO5wFuD/vorK/XyvO+qakYVxnX0jYjkrUGD3t1fBNK9zuvdwBNXVFGGlRTGWXJVNc9sO0Ii4WGXIyIy4jI2R29mo0l+8v9eSrMDz5jZRjNbOcjjV5pZo5k1trZm9iiZ2xZMoOVMB5v2n8zo84qI5IJM7oz9MPByn2mbW9x9IXA78ICZvW+gB7v7Knevd/f66urqDJYF7583gYKYafpGRPJSJoP+LvpM27j7geC+BfgBsCiDPy9t5aMLuWl2JWu3HsZd0zcikl8yEvRmVg4sAX6U0jbGzMp6l4FlQL9H7oyE2+ZPpPlYG7uOnA2rBBGRUKRzeOUTwKvAPDPbb2b3mdn9ZnZ/ymYfBZ5x99QLtU4AXjKzzcAvgf90959msvihWFY3ATM0fSMieadgsA3c/e40tnmM5GGYqW17gesut7BMqxlbwg3TxrF222E+91tzwy5HRGTERPqbsX3dNn8i2w6eZt/xtrBLEREZMXkX9ADrdIlBEckjeRX0M6vGMLt6DM/v1MVIRCR/5FXQA9w6r4bX9h6nrbM77FJEREZE/gX91TV09iR4pUknOROR/JB3QV8/czxjiuKavhGRvJF3QV9cEOfmOVX8fGerviUrInkh74IektM3B06eZ3eLviUrItGXl0G/dF7ypGnPv6npGxGJvrwM+knlo7h6Ypnm6UUkL+Rl0AO8/+oaGptPcLq9K+xSRESGVd4G/dJ5NXQnnFeajoZdiojIsMrboL9h+jjGFMV5SUEvIhGXt0FfGI+xeFYlL+1W0ItItOVt0APcPKeK5mNtOpuliERaXgf9b86tAtD0jYhEWl4H/ZyaUiaMLVbQi0ikpXMpwdVm1mJm/V7v1cyWmtkpM9sU3P4qpW+5me00syYzezCThWeCmXHLnGpeaTpKIqHTIYhINKXzif4xYPkg2/zC3a8Pbg8DmFkceAS4HagD7jazuispdjjcMreSE21dbDt4OuxSRESGxaBB7+4vAscv47kXAU3uvtfdO4EngRWX8TzD6uY5yXn6XzS1hlyJiMjwyNQc/U1mttnMfmJm84O2KcC+lG32B239MrOVZtZoZo2trSMXujVlJVw9sYyXNU8vIhGViaB/HZjh7tcBXwV+eDlP4u6r3L3e3eurq6szUFb6bplTxYbmE7R39YzozxURGQlXHPTuftrdzwbLPwYKzawKOABMS9l0atCWdW6eW0Vnd4INzZczQyUikt2uOOjNbKKZWbC8KHjOY8AGYK6Z1ZpZEXAXsOZKf95wWDSzgnjMWL9XlxcUkegpGGwDM3sCWApUmdl+4CGgEMDdvwHcCXzazLqB88Bdnrx0U7eZfRZYC8SB1e6+bVhGcYXGFBdw7dRy1u/VJ3oRiZ5Bg97d7x6k/2vA1wbo+zHw48srbWTdNKuSVS/u5VxHN2OKB/21iIjkjLz+ZmyqxbMq6U44G98+EXYpIiIZpaAP3DhjPAWapxeRCFLQB96dp1fQi0i0KOhTLJ5VyZb9pzjX0R12KSIiGaOgT6F5ehGJIgV9ivqZmqcXkehR0KcYXVTAddPGKehFJFIU9H0snlWheXoRiRQFfR+apxeRqFHQ96Hj6UUkahT0ffTO07+qoBeRiFDQ96OhtoI39p+irVPz9CKS+xT0/WjQPL2IRIiCvh83zhhPPGa8ptMWi0gEKOj7UVpcwIIp5bz2lubpRST3KegHsLi2gs37Tuk6siKS8wYNejNbbWYtZrZ1gP7fM7MtZvaGmb1iZtel9DUH7ZvMrDGThQ+3hlkVdPYkeP0dzdOLSG5L5xP9Y8DyS/S/BSxx918Dvgis6tN/q7tf7+71l1diOOpnVhAzNE8vIjkvnUsJvmhmMy/R/0rK6npg6pWXFb6xJYXUTR6reXoRyXmZnqO/D/hJyroDz5jZRjNbeakHmtlKM2s0s8bW1tYMl3V5Gmor+dU7J+no1jy9iOSujAW9md1KMuj/e0rzLe6+ELgdeMDM3jfQ4919lbvXu3t9dXV1psq6Ig21FXR0J9i871TYpYiIXLaMBL2ZXQs8Cqxw9wtzHe5+ILhvAX4ALMrEzxspi2orMIPXdDoEEclhVxz0ZjYd+D7w++6+K6V9jJmV9S4Dy4B+j9zJVuNGFzFvQhmvvaUdsiKSuwbdGWtmTwBLgSoz2w88BBQCuPs3gL8CKoF/MjOA7uAImwnAD4K2AuD/uftPh2EMw2rxrEq+s2EfXT0JCuP62oGI5J50jrq5e5D+TwGf6qd9L3Ddex+RWxpqK3jslWa27D/FjTPGh12OiMiQ6SPqIBbVVgDoMEsRyVkK+kFUlhYzt6ZUX5wSkZyloE9Dw6wKGpuP092TCLsUEZEhU9CnoaG2knOdPWw7eDrsUkREhkxBn4aGWZqnF5HcpaBPQ01ZCbOqxmieXkRykoI+TQ2zKvhl83F6Eh52KSIiQ6KgT1NDbSVn2rvZcUjz9CKSWxT0aXp3nl7TNyKSWxT0aZpUPorpFaN1gjMRyTkK+iFoqE3O0yc0Ty8iOURBPwQNsyo52dbFrpYzYZciIpI2Bf0QNATnvVm/R9M3IpI7FPRDMK1iNFPGjdIOWRHJKQr6IWqoreCXbx3HXfP0IpIbFPRD1DCrgmPnOmlqORt2KSIiaUkr6M1stZm1mFm/lwK0pK+YWZOZbTGzhSl995rZ7uB2b6YKD0tDbSUA6zV9IyI5It1P9I8Byy/RfzswN7itBL4OYGYVJC892EDywuAPmVlOX6ZpRuVoJowt1vH0IpIz0gp6d38RuNRH2BXANz1pPTDOzCYBtwHr3P24u58A1nHpN4ysZ2Y01FbymubpRSRHZGqOfgqwL2V9f9A2UPt7mNlKM2s0s8bW1tYMlTU8GmZV0Hqmg7eOngu7FBGRQWXNzlh3X+Xu9e5eX11dHXY5l9Q7T6/DLEUkF2Qq6A8A01LWpwZtA7XntNnVY6gq1Ty9iOSGTAX9GuCe4OibxcApdz8ErAWWmdn4YCfssqAtpyXn6Ss0Ty8iOaEgnY3M7AlgKVBlZvtJHklTCODu3wB+DHwIaALagE8GfcfN7IvAhuCpHnb3SMx3NMyq4D/fOMS+4+eZXjk67HJERAaUVtC7+92D9DvwwAB9q4HVQy8tu717PP0xBb2IZLWs2Rmba+bWlDJ+dKGuIysiWU9Bf5liMWNRbQWvvaUdsiKS3RT0V6ChtpL9J85z4OT5sEsRERmQgv4KXLiOrA6zFJEspqC/AldPHMvYkgLN04tIVlPQX4F4zFhUW8nLe47qeHoRyVoK+iu05Koq9p84r/PeiEjWUtBfoSVX1QDwwq7sPhGbiOQvBf0Vml45mtqqMbyooBeRLKWgz4AlV1Xz6t5jtHf1hF2KiMh7KOgzYMlV1bR3JdjQrKNvRCT7KOgzoGFWBUUFMV7YqekbEck+CvoMGF1UQENtBS/uVtCLSPZR0GfI++ZWs+vIWQ7qdAgikmUU9BmyZF7y8oc6+kZEso2CPkPm1pQyubyE595sCbsUEZGLpBX0ZrbczHaaWZOZPdhP/z+Y2abgtsvMTqb09aT0rclk8dnEzPhA3QR+sbuV8506zFJEssegQW9mceAR4HagDrjbzOpSt3H3P3H36939euCrwPdTus/39rn7RzJYe9b5YN0E2rsSvNR0NOxSREQuSOcT/SKgyd33unsn8CSw4hLb3w08kYnick1DbSVlJQWs23447FJERC5IJ+inAPtS1vcHbe9hZjOAWuBnKc0lZtZoZuvN7I7LrjQHFBXEuHVeDc/taKEnobNZikh2yPTO2LuA77p76iT1DHevBz4O/KOZze7vgWa2MnhDaGxtzd0jVz5YN4Fj5zp5/Z0TYZciIgKkF/QHgGkp61ODtv7cRZ9pG3c/ENzvBX4O3NDfA919lbvXu3t9dXV1GmVlp6XzqimMG+u2Hwm7FBERIL2g3wDMNbNaMysiGebvOXrGzK4GxgOvprSNN7PiYLkKuBnYnonCs1VZSSGLZ1XyzLbDuhiJiGSFQYPe3buBzwJrgR3AU+6+zcweNrPUo2juAp70i9PtGqDRzDYDzwN/4+6RDnqAZfMn0nysjV1HzoZdiogIlo2fOuvr672xsTHsMi5by5l2Fv+f53jg1jn86bJ5YZcjInnAzDYG+0PfQ9+MHQY1ZSXcNLuSpzcf1PSNiIROQT9MPnztZJqPtbHt4OmwSxGRPKegHybLF0ykIGY8vflg2KWISJ5T0A+TcaOL+M25VfzHlkOavhGRUCnoh9GHr5vMgZPnef2dk4NvLCIyTBT0w+iDdRMoKojxo00Dfb9MRGT4KeiHUVlJIcvqJrBm80E6unXqYhEJh4J+mP1O/TROtnXx7HZdkEREwqGgH2Y3z6liUnkJTzXuG3xjEZFhoKAfZvGYceeNU/nF7lYOndKFw0Vk5CnoR8CdN04l4fD917VTVkRGnoJ+BMyoHENDbQVPNe4joQuSiMgIU9CPkI83TOftY228uDt3L6oiIrlJQT9Cbl8wieqyYh5/pTnsUkQkzyjoR0hRQYyPL5rOz3e10nz0XNjliEgeUdCPoI83TCduxr+tfzvsUkQkjyjoR9CEsSUsXzCR7zTuo62zO+xyRCRPpBX0ZrbczHaaWZOZPdhP/x+YWauZbQpun0rpu9fMdge3ezNZfC765M0zOdPezXc26AtUIjIyBg16M4sDjwC3A3XA3WZW18+m33H364Pbo8FjK4CHgAZgEfCQmY3PWPU56MYZFfz6zPH8y4t76exOhF2OiOSBdD7RLwKa3H2vu3cCTwIr0nz+24B17n7c3U8A64Dll1dqdHxm6RwOnmrXWS1FZESkE/RTgNR5hv1BW18fM7MtZvZdM5s2xMdiZivNrNHMGltbo32s+dJ51VwzaSzfeGGPvkAlIsMuUztjnwZmuvu1JD+1Pz7UJ3D3Ve5e7+711dXVGSorO5kZn146mz2t53hm++GwyxGRiEsn6A8A01LWpwZtF7j7MXfvCFYfBW5M97H56kMLJlJbNYZ/WLebHn2qF5FhlE7QbwDmmlmtmRUBdwFrUjcws0kpqx8BdgTLa4FlZjY+2Am7LGjLewXxGH/ywavYeeSMLiAuIsNq0KB3927gsyQDegfwlLtvM7OHzewjwWafM7NtZrYZ+BzwB8FjjwNfJPlmsQF4OGgT4Ld/bRJ1k8by9+t26QgcERk25p590wb19fXe2NgYdhkj4vk3W/jkYxv46zsW8InFM8IuR0RylJltdPf6/vr0zdiQLZ1Xza/PHM8/PrubM+1dYZcjIhGkoA+ZmfE//0sdx8518NWfNYVdjohEkII+C1w3bRy/c+M0Vr/0Fk0tZ8MuR0QiRkGfJf58+TxGFcb5X09vIxv3m4hI7lLQZ4mq0mL++INX8YvdR3l6y6GwyxGRCFHQZ5F7b5rBdVPL+cKabRw72zH4A0RE0qCgzyIF8RhfuvM6zrR38YWnt4ddjohEhII+y8ybWMbn3j+Xpzcf5KdbdR4cEblyCvosdP/S2SyYMpYHv7+FQ6fOh12OiOQ4BX0WKozH+MpdN9DZneDzT2yiu0enRxCRy6egz1Kzqkv53x9dwC+bj/OV53aHXY6I5DAFfRb76A1T+djCqXz1+Sae3X4k7HJEJEcp6LPcX9+xgAWTy/n8k79i5+EzYZcjIjlIQZ/lRhXF+Zd76hlTXMB9j2/Q8fUiMmQK+hwwsbyEVffU03Kmgz98vJFzHd1hlyQiOURBnyOunzaOr959A1sPnGLltxrp6O4JuyQRyRFpBb2ZLTeznWbWZGYP9tP/38xsu5ltMbPnzGxGSl+PmW0Kbmv6PlbSd9v8iXzpY9fyctMxPvfEr+jSYZcikoZBg97M4sAjwO1AHXC3mdX12exXQL27Xwt8F/hSSt95d78+uH0EuSIfu3EqD324jrXbjvCZb79Oe5c+2YvIpaXziX4R0OTue929E3gSWJG6gbs/7+5twep6YGpmy5RUn7y5lodXzGfd9iN86vFG2jo1Zy8iA0sn6KcA+1LW9wdtA7kP+EnKeomZNZrZejO7Y6AHmdnKYLvG1tbWNMrKb/fcNJMv33ktr+w5yu89+hqtZ3Q0joj0L6M7Y83sE0A98OWU5hnBBWs/Dvyjmc3u77Huvsrd6929vrq6OpNlRdZ/rZ/GP/3eQnYcOs0dj7zMjkOnwy5JRLJQOkF/AJiWsj41aLuImX0A+EvgI+5+4eOlux8I7vcCPwduuIJ6pY/lCybx1B/dRHciwZ1ff4WfbtVFS0TkYukE/QZgrpnVmlkRcBdw0dEzZnYD8M8kQ74lpX28mRUHy1XAzYBOtJ5h104dx48euIU5NaXc/2+v89CPtmonrYhcMGjQu3s38FlgLbADeMrdt5nZw2bWexTNl4FS4N/7HEZ5DdBoZpuB54G/cXcF/TCYWF7CU/ffxH231PL4q2/z0X96hd1HdMoEEQHLxgtR19fXe2NjY9hl5Kzndhzhz/59M+c6enjg1jl8eulsigr03TiRKDOzjcH+0PfQv/4I+q1rJvDMnyxh2fwJ/MOzu/jwV1+isfl42GWJSEgU9BFVXVbM1z6+kEfvqed0exd3fuNVPvPtjbxzrG3wB4tIpBSEXYAMrw/UTeA35lTyzy/sZdWLe3l2ewufWDyD+5fMomZsSdjlicgI0Bx9Hjl8qp2/e2Yn33t9PwXxGHf/+jT+aMlsJo8bFXZpInKFLjVHr6DPQ81Hz/H1n+/he6/vxyx5LP49N82gfsZ4zCzs8kTkMijopV8HTp7n0V/s5bsb93OmvZtrJo3lE4un89vXTqZ8VGHY5YnIECjo5ZLaOrv54a8O8s1Xm3nz8BmK4jHef3UNd9wwhVuvrqa4IB52iSIyCAW9pMXd2bL/FD/cdICnNx/k6NlOykoKWDqvhg9cU8PSeTX6pC+SpRT0MmTdPQle3nOM/9h8kJ+92cKxc50UxIxFtRX85txqfmN2JQumlBOPaU5fJBtcKuh1eKX0qyAeY8lV1Sy5qpqehLNp30me3XGE53Yc4W9/+iYAZSUFNNRWctPsShZOH0fd5LGa5hHJQvpEL0PWcqad9XuP8+qeo7y65xjNwZewiuIxrpk8lhumjeP6aeP4tanlzKwco0/9IiNAUzcyrA6dOs+md06yad9JfrXvJG/sP8X54OyZxQUx5tSUMm9iGVdPLGPexLHMm1DGhLHFOpRTJIMU9DKiunsS7Dpylu2HTrPz8GnePHyGnYfP0JJyFaxRhXFmVI5mZuUYZlSNprZyDDMqxzCjcjQ1ZcUUxHV2DpGh0By9jKiCeIy6yWOpmzz2ovYT5zp58/AZmlrO0Hysjeaj59jdcoafvdlCZ0/iwnYxgwljS5hUXsLkcaOYPG4Uk8pLmFSevK8qK6aqtEj7A0TSpKCXETN+TBE3zU7uvE3Vk3AOnjzP28faeOd4G4dOnefAyfMcOtnO1gOneGb7ETq7E+95vrKSAqpLi6ksLaKqtPjCrbK0iPGjiygfVUj5qELGjS5k7KhCyooLiGl/geQhBb2ELh4zplWMZlrF6H773Z3j5zo5dKqdw6faOXq2I7h10nq2g6NnOth15Ayv7DnGqfNdA/6cmMHYUYWMC94AyoM3g9LiOGOKChhTXEBpcfJ+zHva4hf6RhfFtX9BckpaQW9my4H/C8SBR939b/r0FwPfBG4EjgG/6+7NQd9fAPcBPcDn3H1txqqXvGBmVJYWU1lazIIp5ZfctrM7wfFznZw638XJtuR96u1kW8ry+S72HW/jbEc35zq6aetM7/KLZlBSEKekMMaowjglhXGKCy9eLymMJbcpil/YtqQwHvTHKC6IU1hgFMZjFMVjFBYk74sKYhTGYxTGjeILy++297bpSCYZikGD3sziwCPAB4H9wAYzW9PnkoD3ASfcfY6Z3QX8LfC7ZlZH8hqz84HJwLNmdpW764KmMiyKCmJMLC9hYvnQT8Hck3DaOrs519FzIfzPdXQnlzu7OdvRc6GtvauH8109tHclaL/ovofT7V20dyU439lDR3ey73xXDz2JzB34EDMuhH9R/N3wL4wb8ZhREEuuF1xYT7alrvfdLtkWS+kz4nGjsHebYD1uyX4zI24QixkxS97iMYJ2IxbjQvul+pLPRdDeu33yf3oXHh9L9lvfvlhyPWaGkXx+u2j93TYjaA/6CfpTHxsLto+adD7RLwKa3H0vgJk9Cazg4ot8rwC+ECx/F/iaJX9bK4An3b0DeMvMmoLnezUz5YtkTjxmlJUUUlYyPKd56Op5902hsydBV3fyvrP73fWuHqezp4fObqcr6OvqSd46gv7U9s6U5e6E05Pw5H1PcJ9Itnf3JPs6unve3SaRfK7U9Qv3wfP1rmfyTSpX9IZ+zMBIvlP0Lg/0ZvKeNt59c0l9I0rdNhZ0GlA5ppin7r8p42NJJ+inAPtS1vcDDQNt4+7dZnYKqAza1/d57JT+foiZrQRWAkyfPj2d2kVySu80TFkOXu/F/eI3goQ7iQQk3OnxPusJx52Udifhl+5Ltvc+F0F7789KPjbRT1/qc7mDB7V68Ji+bU7yOXr7e8fW29bbjzvOu8+RCPro87wXloPtL/zcYPmiWlLbUn9u8Dx48gCD4ZA1O2PdfRWwCpLH0YdcjoikMAumd3REa05K51spB4BpKetTg7Z+tzGzAqCc5E7ZdB4rIiLDKJ2g3wDMNbNaMysiuXN1TZ9t1gD3Bst3Aj/z5Fdu1wB3mVmxmdUCc4FfZqZ0ERFJx6BTN8Gc+2eBtSQPr1zt7tvM7GGg0d3XAP8KfCvY2Xqc5JsBwXZPkdxx2w08oCNuRERGls51IyISAZc6143OHCUiEnEKehGRiFPQi4hEnIJeRCTisnJnrJm1Am9f5sOrgKMZLCcXaMzRl2/jBY15qGa4e3V/HVkZ9FfCzBoH2vMcVRpz9OXbeEFjziRN3YiIRJyCXkQk4qIY9KvCLiAEGnP05dt4QWPOmMjN0YuIyMWi+IleRERSKOhFRCIuMkFvZsvNbKeZNZnZg2HXk0lm1mxmb5jZJjNrDNoqzGydme0O7scH7WZmXwl+D1vMbGG41afHzFabWYuZbU1pG/IYzezeYPvdZnZvfz8rWwww5i+Y2YHgtd5kZh9K6fuLYMw7zey2lPac+ds3s2lm9ryZbTezbWb2+aA9kq/1JcY7sq9z8lJXuX0jefrkPcAsoAjYDNSFXVcGx9cMVPVp+xLwYLD8IPC3wfKHgJ+QvATlYuC1sOtPc4zvAxYCWy93jEAFsDe4Hx8sjw97bEMc8xeAP+tn27rg77oYqA3+3uO59rcPTAIWBstlwK5gbJF8rS8x3hF9naPyif7CBczdvRPovYB5lK0AHg+WHwfuSGn/pietB8aZ2aQwChwKd3+R5LUMUg11jLcB69z9uLufANYBy4e/+sszwJgHsgJ40t073P0toInk331O/e27+yF3fz1YPgPsIHkd6Ui+1pcY70CG5XWOStD3dwHzS/0yc40Dz5jZxuAi6gAT3P1QsHwYmBAsR+l3MdQxRmXsnw2mKVb3TmEQwTGb2UzgBuA18uC17jNeGMHXOSpBH3W3uPtC4HbgATN7X2qnJ//PF+njZPNhjIGvA7OB64FDwN+FW87wMLNS4GmKtJIAAAFwSURBVHvAH7v76dS+KL7W/Yx3RF/nqAR9pC9C7u4HgvsW4Ack/xt3pHdKJrhvCTaP0u9iqGPM+bG7+xF373H3BPAvJF9riNCYzayQZOh9292/HzRH9rXub7wj/TpHJejTuYB5TjKzMWZW1rsMLAO2cvEF2e8FfhQsrwHuCY5WWAycSvkvca4Z6hjXAsvMbHzwX+FlQVvO6LM/5aMkX2tIjvkuMys2s1pgLvBLcuxv38yM5DWmd7j736d0RfK1Hmi8I/46h71XOlM3knvnd5HcM/2XYdeTwXHNIrmHfTOwrXdsQCXwHLAbeBaoCNoNeCT4PbwB1Ic9hjTH+QTJ/8J2kZx/vO9yxgj8IckdWE3AJ8Me12WM+VvBmLYE/5AnpWz/l8GYdwK3p7TnzN8+cAvJaZktwKbg9qGovtaXGO+Ivs46BYKISMRFZepGREQGoKAXEYk4Bb2ISMQp6EVEIk5BLyIScQp6EZGIU9CLiETc/wcnicHPISAdoAAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "plt.plot(loss_history)\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Accuracy: accuracy= 0.7930\n" + ] + } + ], + "source": [ + "test(gcn)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "First model parameters: 144107\n", + "Second model parameters: 72057\n" + ] + } + ], + "source": [ + "# Snippet taken from: https://stackoverflow.com/questions/49201236/check-the-total-number-of-parameters-in-a-pytorch-model\n", + "net_params = sum(p.numel() for p in net.parameters() if p.requires_grad) \n", + "gcn_params = sum(p.numel() for p in gcn.parameters() if p.requires_grad)\n", + "\n", + "print('First model parameters: ', net_params)\n", + "print('Second model parameters: ', gcn_params)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "What if... we make use of the $D^{-\\frac{1}{2}} A D^{-\\frac{1}{2}}$ normalization formula?" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading cora dataset...\n" + ] + } + ], + "source": [ + "from pygcn.utils import load_data\n", + "symmetric_norm_flag = True # change this flag to enable the D^-1/2 A D^-1/2 formula\n", + "\n", + "adj, features, labels, idx_train, idx_val, idx_test = load_data(path='../../pygcn/data/cora/', symm_norm=symmetric_norm_flag)" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ + "gcn = GCN(1433, 50, 7)\n", + "optimizer_gcn = optim.Adam(gcn.parameters())" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Accuracy: accuracy= 0.1590\n" + ] + } + ], + "source": [ + "# Testing without training\n", + "test(gcn)" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 5000/5000 [00:47<00:00, 104.21it/s]\n" + ] + } + ], + "source": [ + "import tqdm\n", + "loss_history = np.zeros(5000) \n", + "\n", + "for epoch in tqdm.trange(5000): \n", + " \n", + " optimizer_gcn.zero_grad()\n", + " outputs = gcn(features, adj) # Usiamo tutto il dataset\n", + " loss = criterion(outputs[idx_train], labels[idx_train]) # Mascheriamo sulla parte di training\n", + " loss.backward()\n", + " optimizer_gcn.step()\n", + "\n", + " loss_history[epoch] = loss.detach().numpy()" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAD4CAYAAADiry33AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAc40lEQVR4nO3de3Cdd33n8fdHt2NbkmPrEsexHdvJmlJTSEi0uZRMG5bFcTKA2y07dWAh5TKe7ZKl3e5lkmWadMPsLC0zXcqSNnjBA+yWhHvxMqHGhNC0QILlxLk4iRPjJFgmiWXL94usy3f/OD/Jx4pkHVlHOvLzfF4zZ85zfs/l/H4a+aOff8/z/B5FBGZmll011a6AmZlNLQe9mVnGOejNzDLOQW9mlnEOejOzjKurdgVG09bWFsuWLat2NczMzhtbt27dFxHto62bkUG/bNkyOjs7q10NM7PzhqSXx1rnoRszs4xz0JuZZZyD3sws4xz0ZmYZ56A3M8s4B72ZWcY56M3MMi4zQR8RfPbBF/iH57urXRUzsxklM0Evif/98C4eem5vtatiZjajjBv0kpZIekjSM5K2S/qjUbaRpM9K2inpSUlXlqy7VdIL6XVrpRtQqq25wL6jvVP5FWZm551ypkDoB/5jRDwmqRnYKmlzRDxTss1NwIr0ugb4G+AaSS3AXUAHEGnfjRFxoKKtSNqaGhz0ZmYjjNujj4hXIuKxtHwEeBZYNGKzNcBXougRYJ6khcCNwOaI6EnhvhlYXdEWlGhrKrDv6KmpOryZ2XlpQmP0kpYBbwUeHbFqEbC75HNXKhurfLRjr5PUKamzu/vcTqi2NRXY7x69mdkZyg56SU3At4A/jojDla5IRKyPiI6I6GhvH3WmzXG1NjVw4HgffQODFa6dmdn5q6ygl1RPMeT/NiK+Pcome4AlJZ8Xp7KxyqdEW1MBgJ5jHr4xMxtSzlU3Ar4IPBsRfznGZhuBD6arb64FDkXEK8AmYJWk+ZLmA6tS2ZQYCvruIx6+MTMbUs5VN28DPgA8JWlbKvuvwCUAEXEv8ABwM7ATOA58KK3rkfRJYEva7+6I6Klc9c/U3twAwH736M3Mho0b9BHxT4DG2SaAj42xbgOw4ZxqN0GtjcUe/T736M3MhmXmzlgo3jAF+Fp6M7MSmQr6xoZaZtXXOOjNzEpkKuglpWvpPUZvZjYkU0EP0NpUoNs9ejOzYZkL+vamgi+vNDMrkb2gb27wfDdmZiUyF/RtTQV6jvUyMBjVroqZ2YyQyaAfDDhw3L16MzPIaNCDr6U3MxuSwaAvToOw74h79GZmkMWg992xZmZnyF7Qe+jGzOwMmQv6ubPqaKit8U1TZmZJ5oK+OA1Cg8fozcySzAU9FMfp3aM3Mysq5wlTGyTtlfT0GOv/s6Rt6fW0pAFJLWndS5KeSus6K135sbQ1FTwnvZlZUk6P/kvA6rFWRsSnI+KKiLgCuAP4hxFPkXp7Wt8xuaqWr62pwSdjzcyScYM+Ih4Gyn383y3AfZOqUQW0NRXYf+wUg54GwcyscmP0kuZQ7Pl/q6Q4gB9I2ipp3Tj7r5PUKamzu7t7UnVpby4wMBgcPNE3qeOYmWVBJU/Gvhv4yYhhm+sj4krgJuBjkn5rrJ0jYn1EdERER3t7+6Qq4mvpzcxOq2TQr2XEsE1E7Enve4HvAFdX8PvGNBz0PiFrZlaZoJd0AfDbwHdLyholNQ8tA6uAUa/cqbT25uJ8N77E0swM6sbbQNJ9wA1Am6Qu4C6gHiAi7k2b/S7wg4g4VrLrAuA7koa+56sR8feVq/rYTg/d+KYpM7Nxgz4ibiljmy9RvAyztGwXcPm5VmwyLphdT32t/EhBMzMyemesJFobCz4Za2ZGRoMeoK3ZN02ZmUGWg77JPXozM8h60HsGSzOzbAf9/mO9RHgaBDPLtwwHfQN9A8EhT4NgZjmX2aBv97NjzcyADAf90E1T3R6nN7Ocy3zQu0dvZnmX2aBvbSrOd9NzzD16M8u3zAb9/DkNSLDfPXozy7nMBn1tjZg/p4H97tGbWc5lNugBWhobPHRjZrmX6aBvbWxgv6cqNrOcy3bQNzWw/5jH6M0s38YNekkbJO2VNOrToSTdIOmQpG3pdWfJutWSdkjaKen2Sla8HC2NHqM3MyunR/8lYPU42/xjRFyRXncDSKoF7qH4YPCVwC2SVk6mshPV2ljg4PE++gcGp/NrzcxmlHGDPiIeBnrO4dhXAzsjYldEnALuB9acw3HO2fC19Mfdqzez/KrUGP11kp6Q9H1Jb0pli4DdJdt0pbJp09pYvDvWV96YWZ6N+8zYMjwGLI2Io5JuBv4OWDHRg0haB6wDuOSSSypQreIYPUCPr7wxsxybdI8+Ig5HxNG0/ABQL6kN2AMsKdl0cSob6zjrI6IjIjra29snWy2gOFUxwD736M0sxyYd9JIukqS0fHU65n5gC7BC0nJJDcBaYONkv28iTvfofYmlmeXXuEM3ku4DbgDaJHUBdwH1ABFxL/Be4A8l9QMngLVRfKxTv6TbgE1ALbAhIrZPSSvGMG9ovhv36M0sx8YN+oi4ZZz1nwM+N8a6B4AHzq1qk1dbI1o8342Z5Vym74yFdNOUh27MLMcyH/StTZ7YzMzyLftB31jw0I2Z5Vrmg77FM1iaWc5lPuhbmxo4dKKPPs93Y2Y5lf2gT9fSH/DwjZnlVPaDvqk4343H6c0srzIf9MN3xzrozSynMh/0w/Pd+Fp6M8upzAd9i6cqNrOcy3zQz5tdT43wJZZmlluZD/qaGvnZsWaWa5kPevB8N2aWb7kI+tbGgsfozSy3chH0LZ7YzMxyLBdB39bY4MsrzSy3xg16SRsk7ZX09Bjr3y/pSUlPSfqppMtL1r2UyrdJ6qxkxSeipbHA4ZP9nOr3fDdmlj/l9Oi/BKw+y/oXgd+OiDcDnwTWj1j/9oi4IiI6zq2Kk9eSbpo6cNzDN2aWP+MGfUQ8DPScZf1PI+JA+vgIsLhCdauYtjQNgq+lN7M8qvQY/UeA75d8DuAHkrZKWne2HSWtk9QpqbO7u7uilRqa72b/MY/Tm1n+jPtw8HJJejvFoL++pPj6iNgj6UJgs6Tn0v8QXici1pOGfTo6OqJS9YLTM1j6yhszy6OK9OglvQX4ArAmIvYPlUfEnvS+F/gOcHUlvm+iWj10Y2Y5Numgl3QJ8G3gAxHxfEl5o6TmoWVgFTDqlTtT7YLZ9dTWyEM3ZpZL4w7dSLoPuAFok9QF3AXUA0TEvcCdQCvw15IA+tMVNguA76SyOuCrEfH3U9CGcdXUiPlzfNOUmeXTuEEfEbeMs/6jwEdHKd8FXP76PaqjtbGBfR66MbMcysWdsVB8SLh79GaWR7kJes9gaWZ5lZugb2sqeE56M8ul3AR9a2MDR07209s/UO2qmJlNq/wEvW+aMrOcylHQ+6YpM8un/AT98Hw3Dnozy5f8BH0auvGVN2aWNzkKeg/dmFk+5Sbomwt1NNTWsM/z3ZhZzuQm6CXR2tTgHr2Z5U5ugh48DYKZ5VOugr6lseCTsWaWO7kK+jbPYGlmOZSroG9tamD/sV4iKvqkQjOzGa2soJe0QdJeSaM+IUpFn5W0U9KTkq4sWXerpBfS69ZKVfxctDYVONk3yPFTnu/GzPKj3B79l4DVZ1l/E7AivdYBfwMgqYXiE6muofi82LskzT/Xyk6Wnx1rZnlUVtBHxMNAz1k2WQN8JYoeAeZJWgjcCGyOiJ6IOABs5ux/MKbU8E1TvpbezHKkUmP0i4DdJZ+7UtlY5a8jaZ2kTkmd3d3dFarWmVobh6ZBcI/ezPJjxpyMjYj1EdERER3t7e1T8h3u0ZtZHlUq6PcAS0o+L05lY5VXxVCP3pdYmlmeVCroNwIfTFffXAsciohXgE3AKknz00nYVamsKmY31NLYUOuhGzPLlbpyNpJ0H3AD0Capi+KVNPUAEXEv8ABwM7ATOA58KK3rkfRJYEs61N0RcbaTulOuJV1Lb2aWF2UFfUTcMs76AD42xroNwIaJV21qtDYWPN+NmeXKjDkZO13amjwNgpnlS+6CvtUTm5lZzuQv6NNUxYODnu/GzPIhh0FfoH8wOHyyr9pVMTObFvkL+qH5bnxC1sxyIn9B74eEm1nO5C/oh+e78QlZM8uH3AV9W+rR7/PQjZnlRO6Cfv7wnPTu0ZtZPuQu6Otra7hgdr3H6M0sN3IX9HD6WnozszzIZdC3NRbY56EbM8uJXAZ9a1ODr6M3s9zIbdC7R29meZHLoG9rKnDweB+n+gerXRUzsylXVtBLWi1ph6Sdkm4fZf3/lLQtvZ6XdLBk3UDJuo2VrPy5urB5FoB79WaWC+M+eERSLXAP8E6gC9giaWNEPDO0TUT8h5Lt/z3w1pJDnIiIKypX5cm7sLl4d+xrh09y8bzZVa6NmdnUKqdHfzWwMyJ2RcQp4H5gzVm2vwW4rxKVmyoL5hZ79HuPuEdvZtlXTtAvAnaXfO5KZa8jaSmwHPhRSfEsSZ2SHpH0O2N9iaR1abvO7u7uMqp17i6cW+zRO+jNLA8qfTJ2LfDNiBgoKVsaER3A+4DPSLpstB0jYn1EdERER3t7e4WrdabWxgYk6D58ckq/x8xsJign6PcAS0o+L05lo1nLiGGbiNiT3ncBP+bM8fuqqKutoa2pwGuH3aM3s+wrJ+i3ACskLZfUQDHMX3f1jKQ3AvOBn5WUzZdUSMttwNuAZ0buWw0XNhfYe8Q9ejPLvnGvuomIfkm3AZuAWmBDRGyXdDfQGRFDob8WuD8iSh/G+uvA5yUNUvyj8qnSq3WqqRj07tGbWfaNG/QAEfEA8MCIsjtHfP6zUfb7KfDmSdRvylzYPIunf3W42tUwM5tyubwzFmDB3OLEZv0DvjvWzLItt0HfPncWEX5IuJllX26Dfuju2L2+8sbMMs5B7ytvzCzjchv0ngbBzPIit0Hf3lxAglcPuUdvZtmW26Cvr62hvanAK4dOVLsqZmZTKrdBD3DxvNn86qB79GaWbbkO+kXzZvOrg+7Rm1m25TroL543iz0HT3DmrA1mZtmS86CfTW//ID2+acrMMiz3QQ94nN7MMi3XQb9oKOh95Y2ZZViug37hBcWbpnxC1syyLNdB39LYQKGuxkFvZplWVtBLWi1ph6Sdkm4fZf0fSOqWtC29Plqy7lZJL6TXrZWs/GRJSpdYeozezLJr3AePSKoF7gHeCXQBWyRtHOVJUV+LiNtG7NsC3AV0AAFsTfseqEjtK+DiebPpco/ezDKsnB791cDOiNgVEaeA+4E1ZR7/RmBzRPSkcN8MrD63qk6NJS1z2N1zvNrVMDObMuUE/SJgd8nnrlQ20u9JelLSNyUtmeC+SFonqVNSZ3d3dxnVqoxlrXPoOXaKQyf6pu07zcymU6VOxv4/YFlEvIVir/3LEz1ARKyPiI6I6Ghvb69Qtca3tLURgF/ud6/ezLKpnKDfAywp+bw4lQ2LiP0RMTSx+xeAq8rdt9qWtc0B4KX9x6pcEzOzqVFO0G8BVkhaLqkBWAtsLN1A0sKSj+8Bnk3Lm4BVkuZLmg+sSmUzxiUtxaB/2UFvZhk17lU3EdEv6TaKAV0LbIiI7ZLuBjojYiPwcUnvAfqBHuAP0r49kj5J8Y8FwN0R0TMF7ThncxrqWDC3wEseujGzjBo36AEi4gHggRFld5Ys3wHcMca+G4ANk6jjlFva2ugevZllVq7vjB2yvLXRPXozyywHPbC0bQ7dR3o52ttf7aqYmVWcgx64tK0JgF/sPVrlmpiZVZ6DHvi1i5oB2PHakSrXxMys8hz0FC+xLNTV8PyrDnozyx4HPVBbI1YsaHKP3swyyUGfvGFBM8876M0sgxz0yRsWNPPa4V4OHveDws0sWxz0ya8tSCdkPU5vZhnjoE9WXjwXgO2/OlzlmpiZVZaDPlkwdxYL5hZ4sutgtatiZlZRDvoSly+exxNdh6pdDTOzinLQl7h8yTxe3HeMQ8f9tCkzyw4HfYkrlswD4Mk9Hr4xs+xw0Jd48+ILkOCxlx30ZpYdDvoSc2fVs3LhXH62a1+1q2JmVjFlBb2k1ZJ2SNop6fZR1v+JpGckPSnpQUlLS9YNSNqWXhtH7jvT/OZlrTz28kFO9g1UuypmZhUxbtBLqgXuAW4CVgK3SFo5YrPHgY6IeAvwTeAvStadiIgr0us9Far3lPnNy9o4NTDI1pcPVLsqZmYVUU6P/mpgZ0TsiohTwP3AmtINIuKhiBh6RNMjwOLKVnP6/PPlLdTWiJ/s9PCNmWVDOUG/CNhd8rkrlY3lI8D3Sz7PktQp6RFJvzPWTpLWpe06u7u7y6jW1Ggq1HHV0vn86Lm9VauDmVklVfRkrKR/A3QAny4pXhoRHcD7gM9Iumy0fSNifUR0RERHe3t7Jas1YatWLuC5V4/4geFmlgnlBP0eYEnJ58Wp7AyS/iXwCeA9EdE7VB4Re9L7LuDHwFsnUd9pceObLgJg0/ZXq1wTM7PJKyfotwArJC2X1ACsBc64ekbSW4HPUwz5vSXl8yUV0nIb8DbgmUpVfqosaZnDyoVzeeApB72Znf/GDfqI6AduAzYBzwJfj4jtku6WNHQVzaeBJuAbIy6j/HWgU9ITwEPApyJixgc9wL+6chHbdh/0w0jM7LyniKh2HV6no6MjOjs7q1qH/Ud7ufZ/PMgHrl3Gne8eeTWpmdnMImlrOh/6Or4zdgytTQVWvekivv14F8dP9Ve7OmZm58xBfxYfftsyDh7v46uP/rLaVTEzO2cO+rO4amkL113ayvqHd3lKBDM7bznox/Hxd6xg75FevvCPu6pdFTOzc+KgH8d1l7Vy029cxOce2snunuPj72BmNsM46Mvwp+9aSY3Ef/rGE/QPDFa7OmZmE+KgL8PF82Zz95rf4NEXe/jMD1+odnXMzCakrtoVOF+896rF/PzF/XzuoZ1cPG8277vmkmpXycysLA76Cfjvv/tmuo/08om/e4qBwUE+cN2yalfJzGxcHrqZgPraGv76/VfxjjdeyJ9+dzt3ffdpX3ZpZjOeg36CZjfU8vkPdPCR65fz5Z+9zLv/1z+x5aWealfLzGxMDvpzUFsj/vRdK/nKh6/myMl+/vW9P+OjX+7k8V/68YNmNvN4UrNJOnFqgA0/eZF7f/wLjvT2c/niC/i9qxZz45suYsHcWdWunpnlxNkmNXPQV8jR3n6+/VgX//eRl3n+taNIcMWSeVx7aSvXLG/hqqXzaZ5VX+1qmllGOein2QuvHeH7T7/KQzv28lTXIfoHiz/jJS2zeeNFc3njRc0sb2tkScscFs+fzYLmWdTUqMq1NrPz2dmCvqzLKyWtBv4KqAW+EBGfGrG+AHwFuArYD/x+RLyU1t1B8YHhA8DHI2LTObbjvLFiQTMrFjTz8Xes4Pipfh57+SCP//IAz712hB2vHuHBZ19jsOTva32tWHjBbNqaGmhtKhTfGwu0NjUwf04DzbPqaCrU0TSrjuZCffHzrDrqa32KxczGN27QS6oF7gHeCXQBWyRtHPGkqI8AByLin0laC/w58PuSVlJ89OCbgIuBH0p6Q0Tk5prEOQ11XL+ijetXtA2XnewbYM/BE3QdOMHunuN0HTjBrw6eYP+xXnb3HOfxXx6k51jvGX8MRlOoq2F2Qy2z6mop1Ne87r0w4nN9jaitqaG+VtTVpuUaUVsr6mtqqKsVdTWirrYmvYu6mprhshpBjYRUPCE9tFwjpc8gFcuHtq2RqKkpWS4pLz3O6X2LbZOEAAmEYLicVK6houFtVLINnC4bdXv5f1CWH+X06K8GdqaHeyPpfmANZz77dQ3wZ2n5m8DnVPyXtAa4Pz0s/EVJO9PxflaZ6p+fZtXXcll7E5e1N425zeBgcPBEHweOn+JYbz9HT/Zz+GQ/R3v7OXqyj6O9/Rw52c/JvgFO9g1ysn+A3pL3o7397D96avhzb/8AfQPBwGDQNzDIwGAMDynlmU7/DXn9HxeA4T8UI/6YlGzPiGOM9h2vKxu1LqPsO0adyzli+d872naVbcdoRj3eNNSlzB/fmCbaRZhIp6JlTgNf/7fXTfAbxldO0C8Cdpd87gKuGWubiOiXdAhoTeWPjNh30WhfImkdsA7gkks8vUBNjWhpbKClsWHKviOiGPal4d83EPQPDtI/UFzXPzCY3oOguO1gFPcdDBiMYHCwZHnoNTj0ubjtQJTuFwyk9SOPE0Ck7YaXS+rLcFnp8umy0lNOEXHGeji9TaSdTx97xLrS7xk+xvjfOdrP+HVlo243StkoW46+XXnHG23LUY83DXUZ7XhlFk3yZ1re8cYy4a7RBHdonjU1kxXMmCkQImI9sB6KJ2OrXJ1ckER9raivLf4vw8yyqZyzeXuAJSWfF6eyUbeRVAdcQPGkbDn7mpnZFCon6LcAKyQtl9RA8eTqxhHbbARuTcvvBX4Uxf8PbQTWSipIWg6sAH5emaqbmVk5xh26SWPutwGbKF5euSEitku6G+iMiI3AF4H/k0629lD8Y0Da7usUT9z2Ax/L0xU3ZmYzgW+YMjPLgLPdMOU7bszMMs5Bb2aWcQ56M7OMc9CbmWXcjDwZK6kbePkcd28D9lWwOucDtzn78tZecJsnamlEtI+2YkYG/WRI6hzrzHNWuc3Zl7f2gttcSR66MTPLOAe9mVnGZTHo11e7AlXgNmdf3toLbnPFZG6M3szMzpTFHr2ZmZVw0JuZZVxmgl7Sakk7JO2UdHu16zMZkjZI2ivp6ZKyFkmbJb2Q3uenckn6bGr3k5KuLNnn1rT9C5JuHe27ZgpJSyQ9JOkZSdsl/VEqz2y7Jc2S9HNJT6Q2/7dUvlzSo6ltX0vTg5Om+/5aKn9U0rKSY92RyndIurE6LSqPpFpJj0v6Xvqc9fa+JOkpSdskdaay6f29jvQ4t/P5RXH65F8AlwINwBPAymrXaxLt+S3gSuDpkrK/AG5Py7cDf56Wbwa+T/FRltcCj6byFmBXep+fludXu21nafNC4Mq03Aw8D6zMcrtT3ZvScj3waGrL14G1qfxe4A/T8r8D7k3La4GvpeWV6Xe+ACxP/xZqq92+s7T7T4CvAt9Ln7Pe3peAthFl0/p7XfUfQoV+kNcBm0o+3wHcUe16TbJNy0YE/Q5gYVpeCOxIy58Hbhm5HXAL8PmS8jO2m+kv4LvAO/PSbmAO8BjF5zHvA+pS+fDvNsVnQlyXluvSdhr5+1663Ux7UXzK3IPAvwC+l+qf2fam+o0W9NP6e52VoZvRHmA+6kPIz2MLIuKVtPwqsCAtj9X28/Znkv6L/laKPdxMtzsNY2wD9gKbKfZOD0ZEf9qktP7DbUvrDwGtnF9t/gzwX4DB9LmVbLcXio8I/4GkrZLWpbJp/b2eMQ8Ht/JFREjK5HWxkpqAbwF/HBGHJQ2vy2K7o/jEtSskzQO+A7yxylWaMpLeBeyNiK2Sbqh2fabR9RGxR9KFwGZJz5WunI7f66z06PPwEPLXJC0ESO97U/lYbT/vfiaS6imG/N9GxLdTcebbDRARB4GHKA5dzJM01Akrrf9w29L6C4D9nD9tfhvwHkkvAfdTHL75K7LbXgAiYk9630vxj/nVTPPvdVaCvpwHmJ/vSh/AfivFMeyh8g+ms/XXAofSfwk3AaskzU9n9FelshlJxa77F4FnI+IvS1Zltt2S2lNPHkmzKZ6TeJZi4L83bTayzUM/i/cCP4rigO1GYG26SmU5sAL4+fS0onwRcUdELI6IZRT/jf4oIt5PRtsLIKlRUvPQMsXfx6eZ7t/rap+oqOAJj5spXqnxC+AT1a7PJNtyH/AK0EdxLO4jFMcmHwReAH4ItKRtBdyT2v0U0FFynA8DO9PrQ9Vu1zhtvp7iWOaTwLb0ujnL7QbeAjye2vw0cGcqv5RicO0EvgEUUvms9HlnWn9pybE+kX4WO4Cbqt22Mtp+A6evuslse1Pbnkiv7UPZNN2/154Cwcws47IydGNmZmNw0JuZZZyD3sws4xz0ZmYZ56A3M8s4B72ZWcY56M3MMu7/A7pff8eejgM4AAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "plt.plot(loss_history)\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Accuracy: accuracy= 0.6160\n" + ] + } + ], + "source": [ + "test(gcn)" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "First model parameters: 144107\n", + "Second model parameters: 72057\n" + ] + } + ], + "source": [ + "# Snippet taken from: https://stackoverflow.com/questions/49201236/check-the-total-number-of-parameters-in-a-pytorch-model\n", + "net_params = sum(p.numel() for p in net.parameters() if p.requires_grad) \n", + "gcn_params = sum(p.numel() for p in gcn.parameters() if p.requires_grad)\n", + "\n", + "print('First model parameters: ', net_params)\n", + "print('Second model parameters: ', gcn_params)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.7" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/notebooks/mminici-renorm_trick_test-0.1.ipynb b/notebooks/mminici-renorm_trick_test-0.1.ipynb new file mode 100644 index 0000000..138ed4f --- /dev/null +++ b/notebooks/mminici-renorm_trick_test-0.1.ipynb @@ -0,0 +1,618 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This notebook is largely inspired by this [other notebook](https://colab.research.google.com/drive/18EyozusBSgxa5oUBmlzXrp9fEbPyOUoC#scrollTo=_37NFK0iqCFx) originally published by IAML (Italian Association for Machine Learning)." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "# Add the codebase to the path\n", + "import sys\n", + "sys.path.insert(0,'../')" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# Some useful imports\n", + "import numpy as np\n", + "\n", + "import torch\n", + "from torch import nn, optim\n", + "from torch.utils import data" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load the Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading cora dataset...\n" + ] + } + ], + "source": [ + "from pygcn.utils import load_data\n", + "symmetric_norm_flag = False\n", + "\n", + "adj, features, labels, idx_train, idx_val, idx_test = load_data(path='../../pygcn/data/cora/', symm_norm=symmetric_norm_flag)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([2708, 1433])" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "features.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([2708])\n", + "tensor(6)\n" + ] + } + ], + "source": [ + "print(labels.shape)\n", + "print(torch.max(labels))" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([2708, 2708])" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "adj.shape" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Baseline model (without any info on the graph)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "# Feedforward Neural Network with a single hidden layer\n", + "net = nn.Sequential(nn.Linear(1433, 100), nn.ReLU(), nn.Linear(100, 7))" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "from pygcn.utils import accuracy\n", + "def test(model):\n", + " # Test the model on the test set\n", + " y_pred = model(features[idx_test])\n", + " acc_test = accuracy(y_pred, labels[idx_test])\n", + " print(\"Accuracy:\",\n", + " \"accuracy= {:.4f}\".format(acc_test.item()))" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Accuracy: accuracy= 0.1460\n" + ] + } + ], + "source": [ + "# Accuracy without training\n", + "test(net)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "criterion = nn.CrossEntropyLoss()\n", + "optimizer = optim.Adam(net.parameters())" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 1000/1000 [00:01<00:00, 557.19it/s]\n" + ] + } + ], + "source": [ + "import tqdm\n", + "loss_history = np.zeros(1000)\n", + "\n", + "for epoch in tqdm.trange(1000):\n", + "\n", + " optimizer.zero_grad()\n", + " outputs = net(features[idx_train])\n", + " loss = criterion(outputs, labels[idx_train])\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " loss_history[epoch] = loss.detach().numpy()" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[]" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "plt.plot(loss_history)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Accuracy: accuracy= 0.4890\n" + ] + } + ], + "source": [ + "test(net)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Graph Convolutional Network" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "# Taken (and simplified) from:\n", + "# https://github.com/tkipf/pygcn/blob/master/pygcn/layers.py\n", + "\n", + "import math\n", + "\n", + "import torch\n", + "\n", + "from torch.nn.parameter import Parameter\n", + "from torch.nn.modules.module import Module\n", + "\n", + "\n", + "class GraphConvolution(Module):\n", + " \"\"\"\n", + " Simple GCN layer, similar to https://arxiv.org/abs/1609.02907\n", + " \"\"\"\n", + "\n", + " def __init__(self, in_features, out_features):\n", + " super(GraphConvolution, self).__init__()\n", + " self.weight = Parameter(torch.FloatTensor(in_features, out_features))\n", + " self.bias = Parameter(torch.FloatTensor(out_features))\n", + " self.reset_parameters()\n", + "\n", + " def reset_parameters(self):\n", + " stdv = 1. / math.sqrt(self.weight.size(1))\n", + " self.weight.data.uniform_(-stdv, stdv)\n", + " self.bias.data.uniform_(-stdv, stdv)\n", + "\n", + " def forward(self, input, adj):\n", + " support = torch.mm(input, self.weight) \n", + " output = torch.spmm(adj, support) \n", + " return output + self.bias" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "# Taken (and simplified) from:\n", + "# https://github.com/tkipf/pygcn/blob/master/pygcn/models.py\n", + "\n", + "import torch.nn.functional as F\n", + "\n", + "class GCN(nn.Module):\n", + " def __init__(self, nfeat, nhid, nclass):\n", + " super(GCN, self).__init__()\n", + " self.gc1 = GraphConvolution(nfeat, nhid)\n", + " self.gc2 = GraphConvolution(nhid, nclass)\n", + "\n", + " def forward(self, x, adj):\n", + " x = F.relu(self.gc1(x, adj))\n", + " x = self.gc2(x, adj)\n", + " return F.log_softmax(x, dim=1)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "gcn = GCN(1433, 50, 7)\n", + "optimizer_gcn = optim.Adam(gcn.parameters())" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "def test(model):\n", + " y_pred = model(features, adj) # Using the whole dataset\n", + " acc_test = accuracy(y_pred[idx_test], labels[idx_test]) # Masking on the test set\n", + " print(\"Accuracy:\",\n", + " \"accuracy= {:.4f}\".format(acc_test.item()))" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Accuracy: accuracy= 0.1390\n" + ] + } + ], + "source": [ + "# Testing without training\n", + "test(gcn)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 2500/2500 [00:26<00:00, 96.07it/s]\n" + ] + } + ], + "source": [ + "import tqdm\n", + "loss_history = np.zeros(2500) \n", + "\n", + "for epoch in tqdm.trange(2500): \n", + " \n", + " optimizer_gcn.zero_grad()\n", + " outputs = gcn(features, adj) # Usiamo tutto il dataset\n", + " loss = criterion(outputs[idx_train], labels[idx_train]) # Mascheriamo sulla parte di training\n", + " loss.backward()\n", + " optimizer_gcn.step()\n", + "\n", + " loss_history[epoch] = loss.detach().numpy()" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAD4CAYAAADiry33AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3deXSU933v8fd3RhsgIdDGvojF2IJ6wSrCtRtwmmCc2wTnxLe109Ru6hzqxDlJe9ue697eU+c6vfe0yelyk7hJqcuxk+bacbPiNgnGjmPHCw7CAcxiQGDZ7BL7IrTO9/4xj/AgS2gEIz0zz3xe58yZ5/n9nhl9fxrxmeH3PPM85u6IiEh0xcIuQEREhpeCXkQk4hT0IiIRp6AXEYk4Bb2ISMQVhF1Af6qqqnzmzJlhlyEikjM2btx41N2r++vLyqCfOXMmjY2NYZchIpIzzOztgfo0dSMiEnGDBr2ZTTOz581su5ltM7PP97ONmdlXzKzJzLaY2cKUvnvNbHdwuzfTAxARkUtLZ+qmG/hTd3/dzMqAjWa2zt23p2xzOzA3uDUAXwcazKwCeAioBzx47Bp3P5HRUYiIyIAG/UTv7ofc/fVg+QywA5jSZ7MVwDc9aT0wzswmAbcB69z9eBDu64DlGR2BiIhc0pDm6M1sJnAD8FqfrinAvpT1/UHbQO39PfdKM2s0s8bW1tahlCUiIpeQdtCbWSnwPeCP3f10pgtx91XuXu/u9dXV/R4hJCIilyGtoDezQpIh/213/34/mxwApqWsTw3aBmoXEZERks5RNwb8K7DD3f9+gM3WAPcER98sBk65+yFgLbDMzMab2XhgWdCWce1dPax6cQ8vNx0djqcXEclZ6Rx1czPw+8AbZrYpaPsfwHQAd/8G8GPgQ0AT0AZ8Mug7bmZfBDYEj3vY3Y9nrvx3FcVj/PMLe7l5ThU3z6kajh8hIpKTBg16d38JsEG2ceCBAfpWA6svq7ohiMWMW6+u4Zlth+nuSVAQ13fBREQgYt+M/cA1NZxu72bj2zpMX0SkV6SC/pa51RTGjZ+92RJ2KSIiWSNSQV9aXMDiWZU8u+NI2KWIiGSNSAU9wPuvrmFP6zmaj54LuxQRkawQyaAHeE7TNyIiQASDfkblGObWlPLsdk3fiIhABIMe4Lb5E3ntrWMcP9cZdikiIqGLZNAvXzCRhKNP9SIiRDTo508ey9Txo/jJ1kNhlyIiErpIBr2ZsXz+RF5uOsbp9q6wyxERCVUkgx6S0zedPQme19E3IpLnIhv0C6ePp7qsmLXbDoddiohIqCIb9LGYcdv8CTz/ZivtXT1hlyMiEprIBj3A8vmTON/Vwwu7dGlCEclfkQ76hlkVlI8q5KdbNX0jIvkr0kFfGI/xgWsm8NyOI3T1JMIuR0QkFOlcSnC1mbWY2dYB+v/czDYFt61m1mNmFUFfs5m9EfQ1Zrr4dNw2fwKn27tZv/dYGD9eRCR06XyifwxYPlCnu3/Z3a939+uBvwBe6HO5wFuD/vorK/XyvO+qakYVxnX0jYjkrUGD3t1fBNK9zuvdwBNXVFGGlRTGWXJVNc9sO0Ii4WGXIyIy4jI2R29mo0l+8v9eSrMDz5jZRjNbOcjjV5pZo5k1trZm9iiZ2xZMoOVMB5v2n8zo84qI5IJM7oz9MPByn2mbW9x9IXA78ICZvW+gB7v7Knevd/f66urqDJYF7583gYKYafpGRPJSJoP+LvpM27j7geC+BfgBsCiDPy9t5aMLuWl2JWu3HsZd0zcikl8yEvRmVg4sAX6U0jbGzMp6l4FlQL9H7oyE2+ZPpPlYG7uOnA2rBBGRUKRzeOUTwKvAPDPbb2b3mdn9ZnZ/ymYfBZ5x99QLtU4AXjKzzcAvgf90959msvihWFY3ATM0fSMieadgsA3c/e40tnmM5GGYqW17gesut7BMqxlbwg3TxrF222E+91tzwy5HRGTERPqbsX3dNn8i2w6eZt/xtrBLEREZMXkX9ADrdIlBEckjeRX0M6vGMLt6DM/v1MVIRCR/5FXQA9w6r4bX9h6nrbM77FJEREZE/gX91TV09iR4pUknOROR/JB3QV8/czxjiuKavhGRvJF3QV9cEOfmOVX8fGerviUrInkh74IektM3B06eZ3eLviUrItGXl0G/dF7ypGnPv6npGxGJvrwM+knlo7h6Ypnm6UUkL+Rl0AO8/+oaGptPcLq9K+xSRESGVd4G/dJ5NXQnnFeajoZdiojIsMrboL9h+jjGFMV5SUEvIhGXt0FfGI+xeFYlL+1W0ItItOVt0APcPKeK5mNtOpuliERaXgf9b86tAtD0jYhEWl4H/ZyaUiaMLVbQi0ikpXMpwdVm1mJm/V7v1cyWmtkpM9sU3P4qpW+5me00syYzezCThWeCmXHLnGpeaTpKIqHTIYhINKXzif4xYPkg2/zC3a8Pbg8DmFkceAS4HagD7jazuispdjjcMreSE21dbDt4OuxSRESGxaBB7+4vAscv47kXAU3uvtfdO4EngRWX8TzD6uY5yXn6XzS1hlyJiMjwyNQc/U1mttnMfmJm84O2KcC+lG32B239MrOVZtZoZo2trSMXujVlJVw9sYyXNU8vIhGViaB/HZjh7tcBXwV+eDlP4u6r3L3e3eurq6szUFb6bplTxYbmE7R39YzozxURGQlXHPTuftrdzwbLPwYKzawKOABMS9l0atCWdW6eW0Vnd4INzZczQyUikt2uOOjNbKKZWbC8KHjOY8AGYK6Z1ZpZEXAXsOZKf95wWDSzgnjMWL9XlxcUkegpGGwDM3sCWApUmdl+4CGgEMDdvwHcCXzazLqB88Bdnrx0U7eZfRZYC8SB1e6+bVhGcYXGFBdw7dRy1u/VJ3oRiZ5Bg97d7x6k/2vA1wbo+zHw48srbWTdNKuSVS/u5VxHN2OKB/21iIjkjLz+ZmyqxbMq6U44G98+EXYpIiIZpaAP3DhjPAWapxeRCFLQB96dp1fQi0i0KOhTLJ5VyZb9pzjX0R12KSIiGaOgT6F5ehGJIgV9ivqZmqcXkehR0KcYXVTAddPGKehFJFIU9H0snlWheXoRiRQFfR+apxeRqFHQ96Hj6UUkahT0ffTO07+qoBeRiFDQ96OhtoI39p+irVPz9CKS+xT0/WjQPL2IRIiCvh83zhhPPGa8ptMWi0gEKOj7UVpcwIIp5bz2lubpRST3KegHsLi2gs37Tuk6siKS8wYNejNbbWYtZrZ1gP7fM7MtZvaGmb1iZtel9DUH7ZvMrDGThQ+3hlkVdPYkeP0dzdOLSG5L5xP9Y8DyS/S/BSxx918Dvgis6tN/q7tf7+71l1diOOpnVhAzNE8vIjkvnUsJvmhmMy/R/0rK6npg6pWXFb6xJYXUTR6reXoRyXmZnqO/D/hJyroDz5jZRjNbeakHmtlKM2s0s8bW1tYMl3V5Gmor+dU7J+no1jy9iOSujAW9md1KMuj/e0rzLe6+ELgdeMDM3jfQ4919lbvXu3t9dXV1psq6Ig21FXR0J9i871TYpYiIXLaMBL2ZXQs8Cqxw9wtzHe5+ILhvAX4ALMrEzxspi2orMIPXdDoEEclhVxz0ZjYd+D7w++6+K6V9jJmV9S4Dy4B+j9zJVuNGFzFvQhmvvaUdsiKSuwbdGWtmTwBLgSoz2w88BBQCuPs3gL8CKoF/MjOA7uAImwnAD4K2AuD/uftPh2EMw2rxrEq+s2EfXT0JCuP62oGI5J50jrq5e5D+TwGf6qd9L3Ddex+RWxpqK3jslWa27D/FjTPGh12OiMiQ6SPqIBbVVgDoMEsRyVkK+kFUlhYzt6ZUX5wSkZyloE9Dw6wKGpuP092TCLsUEZEhU9CnoaG2knOdPWw7eDrsUkREhkxBn4aGWZqnF5HcpaBPQ01ZCbOqxmieXkRykoI+TQ2zKvhl83F6Eh52KSIiQ6KgT1NDbSVn2rvZcUjz9CKSWxT0aXp3nl7TNyKSWxT0aZpUPorpFaN1gjMRyTkK+iFoqE3O0yc0Ty8iOURBPwQNsyo52dbFrpYzYZciIpI2Bf0QNATnvVm/R9M3IpI7FPRDMK1iNFPGjdIOWRHJKQr6IWqoreCXbx3HXfP0IpIbFPRD1DCrgmPnOmlqORt2KSIiaUkr6M1stZm1mFm/lwK0pK+YWZOZbTGzhSl995rZ7uB2b6YKD0tDbSUA6zV9IyI5It1P9I8Byy/RfzswN7itBL4OYGYVJC892EDywuAPmVlOX6ZpRuVoJowt1vH0IpIz0gp6d38RuNRH2BXANz1pPTDOzCYBtwHr3P24u58A1nHpN4ysZ2Y01FbymubpRSRHZGqOfgqwL2V9f9A2UPt7mNlKM2s0s8bW1tYMlTU8GmZV0Hqmg7eOngu7FBGRQWXNzlh3X+Xu9e5eX11dHXY5l9Q7T6/DLEUkF2Qq6A8A01LWpwZtA7XntNnVY6gq1Ty9iOSGTAX9GuCe4OibxcApdz8ErAWWmdn4YCfssqAtpyXn6Ss0Ty8iOaEgnY3M7AlgKVBlZvtJHklTCODu3wB+DHwIaALagE8GfcfN7IvAhuCpHnb3SMx3NMyq4D/fOMS+4+eZXjk67HJERAaUVtC7+92D9DvwwAB9q4HVQy8tu717PP0xBb2IZLWs2Rmba+bWlDJ+dKGuIysiWU9Bf5liMWNRbQWvvaUdsiKS3RT0V6ChtpL9J85z4OT5sEsRERmQgv4KXLiOrA6zFJEspqC/AldPHMvYkgLN04tIVlPQX4F4zFhUW8nLe47qeHoRyVoK+iu05Koq9p84r/PeiEjWUtBfoSVX1QDwwq7sPhGbiOQvBf0Vml45mtqqMbyooBeRLKWgz4AlV1Xz6t5jtHf1hF2KiMh7KOgzYMlV1bR3JdjQrKNvRCT7KOgzoGFWBUUFMV7YqekbEck+CvoMGF1UQENtBS/uVtCLSPZR0GfI++ZWs+vIWQ7qdAgikmUU9BmyZF7y8oc6+kZEso2CPkPm1pQyubyE595sCbsUEZGLpBX0ZrbczHaaWZOZPdhP/z+Y2abgtsvMTqb09aT0rclk8dnEzPhA3QR+sbuV8506zFJEssegQW9mceAR4HagDrjbzOpSt3H3P3H36939euCrwPdTus/39rn7RzJYe9b5YN0E2rsSvNR0NOxSREQuSOcT/SKgyd33unsn8CSw4hLb3w08kYnick1DbSVlJQWs23447FJERC5IJ+inAPtS1vcHbe9hZjOAWuBnKc0lZtZoZuvN7I7LrjQHFBXEuHVeDc/taKEnobNZikh2yPTO2LuA77p76iT1DHevBz4O/KOZze7vgWa2MnhDaGxtzd0jVz5YN4Fj5zp5/Z0TYZciIgKkF/QHgGkp61ODtv7cRZ9pG3c/ENzvBX4O3NDfA919lbvXu3t9dXV1GmVlp6XzqimMG+u2Hwm7FBERIL2g3wDMNbNaMysiGebvOXrGzK4GxgOvprSNN7PiYLkKuBnYnonCs1VZSSGLZ1XyzLbDuhiJiGSFQYPe3buBzwJrgR3AU+6+zcweNrPUo2juAp70i9PtGqDRzDYDzwN/4+6RDnqAZfMn0nysjV1HzoZdiogIlo2fOuvr672xsTHsMi5by5l2Fv+f53jg1jn86bJ5YZcjInnAzDYG+0PfQ9+MHQY1ZSXcNLuSpzcf1PSNiIROQT9MPnztZJqPtbHt4OmwSxGRPKegHybLF0ykIGY8vflg2KWISJ5T0A+TcaOL+M25VfzHlkOavhGRUCnoh9GHr5vMgZPnef2dk4NvLCIyTBT0w+iDdRMoKojxo00Dfb9MRGT4KeiHUVlJIcvqJrBm80E6unXqYhEJh4J+mP1O/TROtnXx7HZdkEREwqGgH2Y3z6liUnkJTzXuG3xjEZFhoKAfZvGYceeNU/nF7lYOndKFw0Vk5CnoR8CdN04l4fD917VTVkRGnoJ+BMyoHENDbQVPNe4joQuSiMgIU9CPkI83TOftY228uDt3L6oiIrlJQT9Cbl8wieqyYh5/pTnsUkQkzyjoR0hRQYyPL5rOz3e10nz0XNjliEgeUdCPoI83TCduxr+tfzvsUkQkjyjoR9CEsSUsXzCR7zTuo62zO+xyRCRPpBX0ZrbczHaaWZOZPdhP/x+YWauZbQpun0rpu9fMdge3ezNZfC765M0zOdPezXc26AtUIjIyBg16M4sDjwC3A3XA3WZW18+m33H364Pbo8FjK4CHgAZgEfCQmY3PWPU56MYZFfz6zPH8y4t76exOhF2OiOSBdD7RLwKa3H2vu3cCTwIr0nz+24B17n7c3U8A64Dll1dqdHxm6RwOnmrXWS1FZESkE/RTgNR5hv1BW18fM7MtZvZdM5s2xMdiZivNrNHMGltbo32s+dJ51VwzaSzfeGGPvkAlIsMuUztjnwZmuvu1JD+1Pz7UJ3D3Ve5e7+711dXVGSorO5kZn146mz2t53hm++GwyxGRiEsn6A8A01LWpwZtF7j7MXfvCFYfBW5M97H56kMLJlJbNYZ/WLebHn2qF5FhlE7QbwDmmlmtmRUBdwFrUjcws0kpqx8BdgTLa4FlZjY+2Am7LGjLewXxGH/ywavYeeSMLiAuIsNq0KB3927gsyQDegfwlLtvM7OHzewjwWafM7NtZrYZ+BzwB8FjjwNfJPlmsQF4OGgT4Ld/bRJ1k8by9+t26QgcERk25p590wb19fXe2NgYdhkj4vk3W/jkYxv46zsW8InFM8IuR0RylJltdPf6/vr0zdiQLZ1Xza/PHM8/PrubM+1dYZcjIhGkoA+ZmfE//0sdx8518NWfNYVdjohEkII+C1w3bRy/c+M0Vr/0Fk0tZ8MuR0QiRkGfJf58+TxGFcb5X09vIxv3m4hI7lLQZ4mq0mL++INX8YvdR3l6y6GwyxGRCFHQZ5F7b5rBdVPL+cKabRw72zH4A0RE0qCgzyIF8RhfuvM6zrR38YWnt4ddjohEhII+y8ybWMbn3j+Xpzcf5KdbdR4cEblyCvosdP/S2SyYMpYHv7+FQ6fOh12OiOQ4BX0WKozH+MpdN9DZneDzT2yiu0enRxCRy6egz1Kzqkv53x9dwC+bj/OV53aHXY6I5DAFfRb76A1T+djCqXz1+Sae3X4k7HJEJEcp6LPcX9+xgAWTy/n8k79i5+EzYZcjIjlIQZ/lRhXF+Zd76hlTXMB9j2/Q8fUiMmQK+hwwsbyEVffU03Kmgz98vJFzHd1hlyQiOURBnyOunzaOr959A1sPnGLltxrp6O4JuyQRyRFpBb2ZLTeznWbWZGYP9tP/38xsu5ltMbPnzGxGSl+PmW0Kbmv6PlbSd9v8iXzpY9fyctMxPvfEr+jSYZcikoZBg97M4sAjwO1AHXC3mdX12exXQL27Xwt8F/hSSt95d78+uH0EuSIfu3EqD324jrXbjvCZb79Oe5c+2YvIpaXziX4R0OTue929E3gSWJG6gbs/7+5twep6YGpmy5RUn7y5lodXzGfd9iN86vFG2jo1Zy8iA0sn6KcA+1LW9wdtA7kP+EnKeomZNZrZejO7Y6AHmdnKYLvG1tbWNMrKb/fcNJMv33ktr+w5yu89+hqtZ3Q0joj0L6M7Y83sE0A98OWU5hnBBWs/Dvyjmc3u77Huvsrd6929vrq6OpNlRdZ/rZ/GP/3eQnYcOs0dj7zMjkOnwy5JRLJQOkF/AJiWsj41aLuImX0A+EvgI+5+4eOlux8I7vcCPwduuIJ6pY/lCybx1B/dRHciwZ1ff4WfbtVFS0TkYukE/QZgrpnVmlkRcBdw0dEzZnYD8M8kQ74lpX28mRUHy1XAzYBOtJ5h104dx48euIU5NaXc/2+v89CPtmonrYhcMGjQu3s38FlgLbADeMrdt5nZw2bWexTNl4FS4N/7HEZ5DdBoZpuB54G/cXcF/TCYWF7CU/ffxH231PL4q2/z0X96hd1HdMoEEQHLxgtR19fXe2NjY9hl5Kzndhzhz/59M+c6enjg1jl8eulsigr03TiRKDOzjcH+0PfQv/4I+q1rJvDMnyxh2fwJ/MOzu/jwV1+isfl42GWJSEgU9BFVXVbM1z6+kEfvqed0exd3fuNVPvPtjbxzrG3wB4tIpBSEXYAMrw/UTeA35lTyzy/sZdWLe3l2ewufWDyD+5fMomZsSdjlicgI0Bx9Hjl8qp2/e2Yn33t9PwXxGHf/+jT+aMlsJo8bFXZpInKFLjVHr6DPQ81Hz/H1n+/he6/vxyx5LP49N82gfsZ4zCzs8kTkMijopV8HTp7n0V/s5bsb93OmvZtrJo3lE4un89vXTqZ8VGHY5YnIECjo5ZLaOrv54a8O8s1Xm3nz8BmK4jHef3UNd9wwhVuvrqa4IB52iSIyCAW9pMXd2bL/FD/cdICnNx/k6NlOykoKWDqvhg9cU8PSeTX6pC+SpRT0MmTdPQle3nOM/9h8kJ+92cKxc50UxIxFtRX85txqfmN2JQumlBOPaU5fJBtcKuh1eKX0qyAeY8lV1Sy5qpqehLNp30me3XGE53Yc4W9/+iYAZSUFNNRWctPsShZOH0fd5LGa5hHJQvpEL0PWcqad9XuP8+qeo7y65xjNwZewiuIxrpk8lhumjeP6aeP4tanlzKwco0/9IiNAUzcyrA6dOs+md06yad9JfrXvJG/sP8X54OyZxQUx5tSUMm9iGVdPLGPexLHMm1DGhLHFOpRTJIMU9DKiunsS7Dpylu2HTrPz8GnePHyGnYfP0JJyFaxRhXFmVI5mZuUYZlSNprZyDDMqxzCjcjQ1ZcUUxHV2DpGh0By9jKiCeIy6yWOpmzz2ovYT5zp58/AZmlrO0Hysjeaj59jdcoafvdlCZ0/iwnYxgwljS5hUXsLkcaOYPG4Uk8pLmFSevK8qK6aqtEj7A0TSpKCXETN+TBE3zU7uvE3Vk3AOnjzP28faeOd4G4dOnefAyfMcOtnO1gOneGb7ETq7E+95vrKSAqpLi6ksLaKqtPjCrbK0iPGjiygfVUj5qELGjS5k7KhCyooLiGl/geQhBb2ELh4zplWMZlrF6H773Z3j5zo5dKqdw6faOXq2I7h10nq2g6NnOth15Ayv7DnGqfNdA/6cmMHYUYWMC94AyoM3g9LiOGOKChhTXEBpcfJ+zHva4hf6RhfFtX9BckpaQW9my4H/C8SBR939b/r0FwPfBG4EjgG/6+7NQd9fAPcBPcDn3H1txqqXvGBmVJYWU1lazIIp5ZfctrM7wfFznZw638XJtuR96u1kW8ry+S72HW/jbEc35zq6aetM7/KLZlBSEKekMMaowjglhXGKCy9eLymMJbcpil/YtqQwHvTHKC6IU1hgFMZjFMVjFBYk74sKYhTGYxTGjeILy++297bpSCYZikGD3sziwCPAB4H9wAYzW9PnkoD3ASfcfY6Z3QX8LfC7ZlZH8hqz84HJwLNmdpW764KmMiyKCmJMLC9hYvnQT8Hck3DaOrs519FzIfzPdXQnlzu7OdvRc6GtvauH8109tHclaL/ovofT7V20dyU439lDR3ey73xXDz2JzB34EDMuhH9R/N3wL4wb8ZhREEuuF1xYT7alrvfdLtkWS+kz4nGjsHebYD1uyX4zI24QixkxS97iMYJ2IxbjQvul+pLPRdDeu33yf3oXHh9L9lvfvlhyPWaGkXx+u2j93TYjaA/6CfpTHxsLto+adD7RLwKa3H0vgJk9Cazg4ot8rwC+ECx/F/iaJX9bK4An3b0DeMvMmoLnezUz5YtkTjxmlJUUUlYyPKd56Op5902hsydBV3fyvrP73fWuHqezp4fObqcr6OvqSd46gv7U9s6U5e6E05Pw5H1PcJ9Itnf3JPs6unve3SaRfK7U9Qv3wfP1rmfyTSpX9IZ+zMBIvlP0Lg/0ZvKeNt59c0l9I0rdNhZ0GlA5ppin7r8p42NJJ+inAPtS1vcDDQNt4+7dZnYKqAza1/d57JT+foiZrQRWAkyfPj2d2kVySu80TFkOXu/F/eI3goQ7iQQk3OnxPusJx52Udifhl+5Ltvc+F0F7789KPjbRT1/qc7mDB7V68Ji+bU7yOXr7e8fW29bbjzvOu8+RCPro87wXloPtL/zcYPmiWlLbUn9u8Dx48gCD4ZA1O2PdfRWwCpLH0YdcjoikMAumd3REa05K51spB4BpKetTg7Z+tzGzAqCc5E7ZdB4rIiLDKJ2g3wDMNbNaMysiuXN1TZ9t1gD3Bst3Aj/z5Fdu1wB3mVmxmdUCc4FfZqZ0ERFJx6BTN8Gc+2eBtSQPr1zt7tvM7GGg0d3XAP8KfCvY2Xqc5JsBwXZPkdxx2w08oCNuRERGls51IyISAZc6143OHCUiEnEKehGRiFPQi4hEnIJeRCTisnJnrJm1Am9f5sOrgKMZLCcXaMzRl2/jBY15qGa4e3V/HVkZ9FfCzBoH2vMcVRpz9OXbeEFjziRN3YiIRJyCXkQk4qIY9KvCLiAEGnP05dt4QWPOmMjN0YuIyMWi+IleRERSKOhFRCIuMkFvZsvNbKeZNZnZg2HXk0lm1mxmb5jZJjNrDNoqzGydme0O7scH7WZmXwl+D1vMbGG41afHzFabWYuZbU1pG/IYzezeYPvdZnZvfz8rWwww5i+Y2YHgtd5kZh9K6fuLYMw7zey2lPac+ds3s2lm9ryZbTezbWb2+aA9kq/1JcY7sq9z8lJXuX0jefrkPcAsoAjYDNSFXVcGx9cMVPVp+xLwYLD8IPC3wfKHgJ+QvATlYuC1sOtPc4zvAxYCWy93jEAFsDe4Hx8sjw97bEMc8xeAP+tn27rg77oYqA3+3uO59rcPTAIWBstlwK5gbJF8rS8x3hF9naPyif7CBczdvRPovYB5lK0AHg+WHwfuSGn/pietB8aZ2aQwChwKd3+R5LUMUg11jLcB69z9uLufANYBy4e/+sszwJgHsgJ40t073P0toInk331O/e27+yF3fz1YPgPsIHkd6Ui+1pcY70CG5XWOStD3dwHzS/0yc40Dz5jZxuAi6gAT3P1QsHwYmBAsR+l3MdQxRmXsnw2mKVb3TmEQwTGb2UzgBuA18uC17jNeGMHXOSpBH3W3uPtC4HbgATN7X2qnJ//PF+njZPNhjIGvA7OB64FDwN+FW87wMLNS4GmKtJIAAAFwSURBVHvAH7v76dS+KL7W/Yx3RF/nqAR9pC9C7u4HgvsW4Ack/xt3pHdKJrhvCTaP0u9iqGPM+bG7+xF373H3BPAvJF9riNCYzayQZOh9292/HzRH9rXub7wj/TpHJejTuYB5TjKzMWZW1rsMLAO2cvEF2e8FfhQsrwHuCY5WWAycSvkvca4Z6hjXAsvMbHzwX+FlQVvO6LM/5aMkX2tIjvkuMys2s1pgLvBLcuxv38yM5DWmd7j736d0RfK1Hmi8I/46h71XOlM3knvnd5HcM/2XYdeTwXHNIrmHfTOwrXdsQCXwHLAbeBaoCNoNeCT4PbwB1Ic9hjTH+QTJ/8J2kZx/vO9yxgj8IckdWE3AJ8Me12WM+VvBmLYE/5AnpWz/l8GYdwK3p7TnzN8+cAvJaZktwKbg9qGovtaXGO+Ivs46BYKISMRFZepGREQGoKAXEYk4Bb2ISMQp6EVEIk5BLyIScQp6EZGIU9CLiETc/wcnicHPISAdoAAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "plt.plot(loss_history)\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Accuracy: accuracy= 0.7930\n" + ] + } + ], + "source": [ + "test(gcn)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "First model parameters: 144107\n", + "Second model parameters: 72057\n" + ] + } + ], + "source": [ + "# Snippet taken from: https://stackoverflow.com/questions/49201236/check-the-total-number-of-parameters-in-a-pytorch-model\n", + "net_params = sum(p.numel() for p in net.parameters() if p.requires_grad) \n", + "gcn_params = sum(p.numel() for p in gcn.parameters() if p.requires_grad)\n", + "\n", + "print('First model parameters: ', net_params)\n", + "print('Second model parameters: ', gcn_params)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "What if... we make use of the $D^{-\\frac{1}{2}} A D^{-\\frac{1}{2}}$ normalization formula?" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading cora dataset...\n" + ] + } + ], + "source": [ + "from pygcn.utils import load_data\n", + "symmetric_norm_flag = True # change this flag to enable the D^-1/2 A D^-1/2 formula\n", + "\n", + "adj, features, labels, idx_train, idx_val, idx_test = load_data(path='../../pygcn/data/cora/', symm_norm=symmetric_norm_flag)" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ + "gcn = GCN(1433, 50, 7)\n", + "optimizer_gcn = optim.Adam(gcn.parameters())" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Accuracy: accuracy= 0.1590\n" + ] + } + ], + "source": [ + "# Testing without training\n", + "test(gcn)" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 5000/5000 [00:47<00:00, 104.21it/s]\n" + ] + } + ], + "source": [ + "import tqdm\n", + "loss_history = np.zeros(5000) \n", + "\n", + "for epoch in tqdm.trange(5000): \n", + " \n", + " optimizer_gcn.zero_grad()\n", + " outputs = gcn(features, adj) # Usiamo tutto il dataset\n", + " loss = criterion(outputs[idx_train], labels[idx_train]) # Mascheriamo sulla parte di training\n", + " loss.backward()\n", + " optimizer_gcn.step()\n", + "\n", + " loss_history[epoch] = loss.detach().numpy()" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAD4CAYAAADiry33AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAc40lEQVR4nO3de3Cdd33n8fdHt2NbkmPrEsexHdvJmlJTSEi0uZRMG5bFcTKA2y07dWAh5TKe7ZKl3e5lkmWadMPsLC0zXcqSNnjBA+yWhHvxMqHGhNC0QILlxLk4iRPjJFgmiWXL94usy3f/OD/Jx4pkHVlHOvLzfF4zZ85zfs/l/H4a+aOff8/z/B5FBGZmll011a6AmZlNLQe9mVnGOejNzDLOQW9mlnEOejOzjKurdgVG09bWFsuWLat2NczMzhtbt27dFxHto62bkUG/bNkyOjs7q10NM7PzhqSXx1rnoRszs4xz0JuZZZyD3sws4xz0ZmYZ56A3M8s4B72ZWcY56M3MMi4zQR8RfPbBF/iH57urXRUzsxklM0Evif/98C4eem5vtatiZjajjBv0kpZIekjSM5K2S/qjUbaRpM9K2inpSUlXlqy7VdIL6XVrpRtQqq25wL6jvVP5FWZm551ypkDoB/5jRDwmqRnYKmlzRDxTss1NwIr0ugb4G+AaSS3AXUAHEGnfjRFxoKKtSNqaGhz0ZmYjjNujj4hXIuKxtHwEeBZYNGKzNcBXougRYJ6khcCNwOaI6EnhvhlYXdEWlGhrKrDv6KmpOryZ2XlpQmP0kpYBbwUeHbFqEbC75HNXKhurfLRjr5PUKamzu/vcTqi2NRXY7x69mdkZyg56SU3At4A/jojDla5IRKyPiI6I6GhvH3WmzXG1NjVw4HgffQODFa6dmdn5q6ygl1RPMeT/NiK+Pcome4AlJZ8Xp7KxyqdEW1MBgJ5jHr4xMxtSzlU3Ar4IPBsRfznGZhuBD6arb64FDkXEK8AmYJWk+ZLmA6tS2ZQYCvruIx6+MTMbUs5VN28DPgA8JWlbKvuvwCUAEXEv8ABwM7ATOA58KK3rkfRJYEva7+6I6Klc9c/U3twAwH736M3Mho0b9BHxT4DG2SaAj42xbgOw4ZxqN0GtjcUe/T736M3MhmXmzlgo3jAF+Fp6M7MSmQr6xoZaZtXXOOjNzEpkKuglpWvpPUZvZjYkU0EP0NpUoNs9ejOzYZkL+vamgi+vNDMrkb2gb27wfDdmZiUyF/RtTQV6jvUyMBjVroqZ2YyQyaAfDDhw3L16MzPIaNCDr6U3MxuSwaAvToOw74h79GZmkMWg992xZmZnyF7Qe+jGzOwMmQv6ubPqaKit8U1TZmZJ5oK+OA1Cg8fozcySzAU9FMfp3aM3Mysq5wlTGyTtlfT0GOv/s6Rt6fW0pAFJLWndS5KeSus6K135sbQ1FTwnvZlZUk6P/kvA6rFWRsSnI+KKiLgCuAP4hxFPkXp7Wt8xuaqWr62pwSdjzcyScYM+Ih4Gyn383y3AfZOqUQW0NRXYf+wUg54GwcyscmP0kuZQ7Pl/q6Q4gB9I2ipp3Tj7r5PUKamzu7t7UnVpby4wMBgcPNE3qeOYmWVBJU/Gvhv4yYhhm+sj4krgJuBjkn5rrJ0jYn1EdERER3t7+6Qq4mvpzcxOq2TQr2XEsE1E7Enve4HvAFdX8PvGNBz0PiFrZlaZoJd0AfDbwHdLyholNQ8tA6uAUa/cqbT25uJ8N77E0swM6sbbQNJ9wA1Am6Qu4C6gHiAi7k2b/S7wg4g4VrLrAuA7koa+56sR8feVq/rYTg/d+KYpM7Nxgz4ibiljmy9RvAyztGwXcPm5VmwyLphdT32t/EhBMzMyemesJFobCz4Za2ZGRoMeoK3ZN02ZmUGWg77JPXozM8h60HsGSzOzbAf9/mO9RHgaBDPLtwwHfQN9A8EhT4NgZjmX2aBv97NjzcyADAf90E1T3R6nN7Ocy3zQu0dvZnmX2aBvbSrOd9NzzD16M8u3zAb9/DkNSLDfPXozy7nMBn1tjZg/p4H97tGbWc5lNugBWhobPHRjZrmX6aBvbWxgv6cqNrOcy3bQNzWw/5jH6M0s38YNekkbJO2VNOrToSTdIOmQpG3pdWfJutWSdkjaKen2Sla8HC2NHqM3MyunR/8lYPU42/xjRFyRXncDSKoF7qH4YPCVwC2SVk6mshPV2ljg4PE++gcGp/NrzcxmlHGDPiIeBnrO4dhXAzsjYldEnALuB9acw3HO2fC19Mfdqzez/KrUGP11kp6Q9H1Jb0pli4DdJdt0pbJp09pYvDvWV96YWZ6N+8zYMjwGLI2Io5JuBv4OWDHRg0haB6wDuOSSSypQreIYPUCPr7wxsxybdI8+Ig5HxNG0/ABQL6kN2AMsKdl0cSob6zjrI6IjIjra29snWy2gOFUxwD736M0sxyYd9JIukqS0fHU65n5gC7BC0nJJDcBaYONkv28iTvfofYmlmeXXuEM3ku4DbgDaJHUBdwH1ABFxL/Be4A8l9QMngLVRfKxTv6TbgE1ALbAhIrZPSSvGMG9ovhv36M0sx8YN+oi4ZZz1nwM+N8a6B4AHzq1qk1dbI1o8342Z5Vym74yFdNOUh27MLMcyH/StTZ7YzMzyLftB31jw0I2Z5Vrmg77FM1iaWc5lPuhbmxo4dKKPPs93Y2Y5lf2gT9fSH/DwjZnlVPaDvqk4343H6c0srzIf9MN3xzrozSynMh/0w/Pd+Fp6M8upzAd9i6cqNrOcy3zQz5tdT43wJZZmlluZD/qaGvnZsWaWa5kPevB8N2aWb7kI+tbGgsfozSy3chH0LZ7YzMxyLBdB39bY4MsrzSy3xg16SRsk7ZX09Bjr3y/pSUlPSfqppMtL1r2UyrdJ6qxkxSeipbHA4ZP9nOr3fDdmlj/l9Oi/BKw+y/oXgd+OiDcDnwTWj1j/9oi4IiI6zq2Kk9eSbpo6cNzDN2aWP+MGfUQ8DPScZf1PI+JA+vgIsLhCdauYtjQNgq+lN7M8qvQY/UeA75d8DuAHkrZKWne2HSWtk9QpqbO7u7uilRqa72b/MY/Tm1n+jPtw8HJJejvFoL++pPj6iNgj6UJgs6Tn0v8QXici1pOGfTo6OqJS9YLTM1j6yhszy6OK9OglvQX4ArAmIvYPlUfEnvS+F/gOcHUlvm+iWj10Y2Y5Numgl3QJ8G3gAxHxfEl5o6TmoWVgFTDqlTtT7YLZ9dTWyEM3ZpZL4w7dSLoPuAFok9QF3AXUA0TEvcCdQCvw15IA+tMVNguA76SyOuCrEfH3U9CGcdXUiPlzfNOUmeXTuEEfEbeMs/6jwEdHKd8FXP76PaqjtbGBfR66MbMcysWdsVB8SLh79GaWR7kJes9gaWZ5lZugb2sqeE56M8ul3AR9a2MDR07209s/UO2qmJlNq/wEvW+aMrOcylHQ+6YpM8un/AT98Hw3Dnozy5f8BH0auvGVN2aWNzkKeg/dmFk+5Sbomwt1NNTWsM/z3ZhZzuQm6CXR2tTgHr2Z5U5ugh48DYKZ5VOugr6lseCTsWaWO7kK+jbPYGlmOZSroG9tamD/sV4iKvqkQjOzGa2soJe0QdJeSaM+IUpFn5W0U9KTkq4sWXerpBfS69ZKVfxctDYVONk3yPFTnu/GzPKj3B79l4DVZ1l/E7AivdYBfwMgqYXiE6muofi82LskzT/Xyk6Wnx1rZnlUVtBHxMNAz1k2WQN8JYoeAeZJWgjcCGyOiJ6IOABs5ux/MKbU8E1TvpbezHKkUmP0i4DdJZ+7UtlY5a8jaZ2kTkmd3d3dFarWmVobh6ZBcI/ezPJjxpyMjYj1EdERER3t7e1T8h3u0ZtZHlUq6PcAS0o+L05lY5VXxVCP3pdYmlmeVCroNwIfTFffXAsciohXgE3AKknz00nYVamsKmY31NLYUOuhGzPLlbpyNpJ0H3AD0Capi+KVNPUAEXEv8ABwM7ATOA58KK3rkfRJYEs61N0RcbaTulOuJV1Lb2aWF2UFfUTcMs76AD42xroNwIaJV21qtDYWPN+NmeXKjDkZO13amjwNgpnlS+6CvtUTm5lZzuQv6NNUxYODnu/GzPIhh0FfoH8wOHyyr9pVMTObFvkL+qH5bnxC1sxyIn9B74eEm1nO5C/oh+e78QlZM8uH3AV9W+rR7/PQjZnlRO6Cfv7wnPTu0ZtZPuQu6Otra7hgdr3H6M0sN3IX9HD6WnozszzIZdC3NRbY56EbM8uJXAZ9a1ODr6M3s9zIbdC7R29meZHLoG9rKnDweB+n+gerXRUzsylXVtBLWi1ph6Sdkm4fZf3/lLQtvZ6XdLBk3UDJuo2VrPy5urB5FoB79WaWC+M+eERSLXAP8E6gC9giaWNEPDO0TUT8h5Lt/z3w1pJDnIiIKypX5cm7sLl4d+xrh09y8bzZVa6NmdnUKqdHfzWwMyJ2RcQp4H5gzVm2vwW4rxKVmyoL5hZ79HuPuEdvZtlXTtAvAnaXfO5KZa8jaSmwHPhRSfEsSZ2SHpH0O2N9iaR1abvO7u7uMqp17i6cW+zRO+jNLA8qfTJ2LfDNiBgoKVsaER3A+4DPSLpstB0jYn1EdERER3t7e4WrdabWxgYk6D58ckq/x8xsJign6PcAS0o+L05lo1nLiGGbiNiT3ncBP+bM8fuqqKutoa2pwGuH3aM3s+wrJ+i3ACskLZfUQDHMX3f1jKQ3AvOBn5WUzZdUSMttwNuAZ0buWw0XNhfYe8Q9ejPLvnGvuomIfkm3AZuAWmBDRGyXdDfQGRFDob8WuD8iSh/G+uvA5yUNUvyj8qnSq3WqqRj07tGbWfaNG/QAEfEA8MCIsjtHfP6zUfb7KfDmSdRvylzYPIunf3W42tUwM5tyubwzFmDB3OLEZv0DvjvWzLItt0HfPncWEX5IuJllX26Dfuju2L2+8sbMMs5B7ytvzCzjchv0ngbBzPIit0Hf3lxAglcPuUdvZtmW26Cvr62hvanAK4dOVLsqZmZTKrdBD3DxvNn86qB79GaWbbkO+kXzZvOrg+7Rm1m25TroL543iz0HT3DmrA1mZtmS86CfTW//ID2+acrMMiz3QQ94nN7MMi3XQb9oKOh95Y2ZZViug37hBcWbpnxC1syyLNdB39LYQKGuxkFvZplWVtBLWi1ph6Sdkm4fZf0fSOqWtC29Plqy7lZJL6TXrZWs/GRJSpdYeozezLJr3AePSKoF7gHeCXQBWyRtHOVJUV+LiNtG7NsC3AV0AAFsTfseqEjtK+DiebPpco/ezDKsnB791cDOiNgVEaeA+4E1ZR7/RmBzRPSkcN8MrD63qk6NJS1z2N1zvNrVMDObMuUE/SJgd8nnrlQ20u9JelLSNyUtmeC+SFonqVNSZ3d3dxnVqoxlrXPoOXaKQyf6pu07zcymU6VOxv4/YFlEvIVir/3LEz1ARKyPiI6I6Ghvb69Qtca3tLURgF/ud6/ezLKpnKDfAywp+bw4lQ2LiP0RMTSx+xeAq8rdt9qWtc0B4KX9x6pcEzOzqVFO0G8BVkhaLqkBWAtsLN1A0sKSj+8Bnk3Lm4BVkuZLmg+sSmUzxiUtxaB/2UFvZhk17lU3EdEv6TaKAV0LbIiI7ZLuBjojYiPwcUnvAfqBHuAP0r49kj5J8Y8FwN0R0TMF7ThncxrqWDC3wEseujGzjBo36AEi4gHggRFld5Ys3wHcMca+G4ANk6jjlFva2ugevZllVq7vjB2yvLXRPXozyywHPbC0bQ7dR3o52ttf7aqYmVWcgx64tK0JgF/sPVrlmpiZVZ6DHvi1i5oB2PHakSrXxMys8hz0FC+xLNTV8PyrDnozyx4HPVBbI1YsaHKP3swyyUGfvGFBM8876M0sgxz0yRsWNPPa4V4OHveDws0sWxz0ya8tSCdkPU5vZhnjoE9WXjwXgO2/OlzlmpiZVZaDPlkwdxYL5hZ4sutgtatiZlZRDvoSly+exxNdh6pdDTOzinLQl7h8yTxe3HeMQ8f9tCkzyw4HfYkrlswD4Mk9Hr4xs+xw0Jd48+ILkOCxlx30ZpYdDvoSc2fVs3LhXH62a1+1q2JmVjFlBb2k1ZJ2SNop6fZR1v+JpGckPSnpQUlLS9YNSNqWXhtH7jvT/OZlrTz28kFO9g1UuypmZhUxbtBLqgXuAW4CVgK3SFo5YrPHgY6IeAvwTeAvStadiIgr0us9Far3lPnNy9o4NTDI1pcPVLsqZmYVUU6P/mpgZ0TsiohTwP3AmtINIuKhiBh6RNMjwOLKVnP6/PPlLdTWiJ/s9PCNmWVDOUG/CNhd8rkrlY3lI8D3Sz7PktQp6RFJvzPWTpLWpe06u7u7y6jW1Ggq1HHV0vn86Lm9VauDmVklVfRkrKR/A3QAny4pXhoRHcD7gM9Iumy0fSNifUR0RERHe3t7Jas1YatWLuC5V4/4geFmlgnlBP0eYEnJ58Wp7AyS/iXwCeA9EdE7VB4Re9L7LuDHwFsnUd9pceObLgJg0/ZXq1wTM7PJKyfotwArJC2X1ACsBc64ekbSW4HPUwz5vSXl8yUV0nIb8DbgmUpVfqosaZnDyoVzeeApB72Znf/GDfqI6AduAzYBzwJfj4jtku6WNHQVzaeBJuAbIy6j/HWgU9ITwEPApyJixgc9wL+6chHbdh/0w0jM7LyniKh2HV6no6MjOjs7q1qH/Ud7ufZ/PMgHrl3Gne8eeTWpmdnMImlrOh/6Or4zdgytTQVWvekivv14F8dP9Ve7OmZm58xBfxYfftsyDh7v46uP/rLaVTEzO2cO+rO4amkL113ayvqHd3lKBDM7bznox/Hxd6xg75FevvCPu6pdFTOzc+KgH8d1l7Vy029cxOce2snunuPj72BmNsM46Mvwp+9aSY3Ef/rGE/QPDFa7OmZmE+KgL8PF82Zz95rf4NEXe/jMD1+odnXMzCakrtoVOF+896rF/PzF/XzuoZ1cPG8277vmkmpXycysLA76Cfjvv/tmuo/08om/e4qBwUE+cN2yalfJzGxcHrqZgPraGv76/VfxjjdeyJ9+dzt3ffdpX3ZpZjOeg36CZjfU8vkPdPCR65fz5Z+9zLv/1z+x5aWealfLzGxMDvpzUFsj/vRdK/nKh6/myMl+/vW9P+OjX+7k8V/68YNmNvN4UrNJOnFqgA0/eZF7f/wLjvT2c/niC/i9qxZz45suYsHcWdWunpnlxNkmNXPQV8jR3n6+/VgX//eRl3n+taNIcMWSeVx7aSvXLG/hqqXzaZ5VX+1qmllGOein2QuvHeH7T7/KQzv28lTXIfoHiz/jJS2zeeNFc3njRc0sb2tkScscFs+fzYLmWdTUqMq1NrPz2dmCvqzLKyWtBv4KqAW+EBGfGrG+AHwFuArYD/x+RLyU1t1B8YHhA8DHI2LTObbjvLFiQTMrFjTz8Xes4Pipfh57+SCP//IAz712hB2vHuHBZ19jsOTva32tWHjBbNqaGmhtKhTfGwu0NjUwf04DzbPqaCrU0TSrjuZCffHzrDrqa32KxczGN27QS6oF7gHeCXQBWyRtHPGkqI8AByLin0laC/w58PuSVlJ89OCbgIuBH0p6Q0Tk5prEOQ11XL+ijetXtA2XnewbYM/BE3QdOMHunuN0HTjBrw6eYP+xXnb3HOfxXx6k51jvGX8MRlOoq2F2Qy2z6mop1Ne87r0w4nN9jaitqaG+VtTVpuUaUVsr6mtqqKsVdTWirrYmvYu6mprhshpBjYRUPCE9tFwjpc8gFcuHtq2RqKkpWS4pLz3O6X2LbZOEAAmEYLicVK6houFtVLINnC4bdXv5f1CWH+X06K8GdqaHeyPpfmANZz77dQ3wZ2n5m8DnVPyXtAa4Pz0s/EVJO9PxflaZ6p+fZtXXcll7E5e1N425zeBgcPBEHweOn+JYbz9HT/Zz+GQ/R3v7OXqyj6O9/Rw52c/JvgFO9g1ysn+A3pL3o7397D96avhzb/8AfQPBwGDQNzDIwGAMDynlmU7/DXn9HxeA4T8UI/6YlGzPiGOM9h2vKxu1LqPsO0adyzli+d872naVbcdoRj3eNNSlzB/fmCbaRZhIp6JlTgNf/7fXTfAbxldO0C8Cdpd87gKuGWubiOiXdAhoTeWPjNh30WhfImkdsA7gkks8vUBNjWhpbKClsWHKviOiGPal4d83EPQPDtI/UFzXPzCY3oOguO1gFPcdDBiMYHCwZHnoNTj0ubjtQJTuFwyk9SOPE0Ck7YaXS+rLcFnp8umy0lNOEXHGeji9TaSdTx97xLrS7xk+xvjfOdrP+HVlo243StkoW46+XXnHG23LUY83DXUZ7XhlFk3yZ1re8cYy4a7RBHdonjU1kxXMmCkQImI9sB6KJ2OrXJ1ckER9raivLf4vw8yyqZyzeXuAJSWfF6eyUbeRVAdcQPGkbDn7mpnZFCon6LcAKyQtl9RA8eTqxhHbbARuTcvvBX4Uxf8PbQTWSipIWg6sAH5emaqbmVk5xh26SWPutwGbKF5euSEitku6G+iMiI3AF4H/k0629lD8Y0Da7usUT9z2Ax/L0xU3ZmYzgW+YMjPLgLPdMOU7bszMMs5Bb2aWcQ56M7OMc9CbmWXcjDwZK6kbePkcd28D9lWwOucDtzn78tZecJsnamlEtI+2YkYG/WRI6hzrzHNWuc3Zl7f2gttcSR66MTPLOAe9mVnGZTHo11e7AlXgNmdf3toLbnPFZG6M3szMzpTFHr2ZmZVw0JuZZVxmgl7Sakk7JO2UdHu16zMZkjZI2ivp6ZKyFkmbJb2Q3uenckn6bGr3k5KuLNnn1rT9C5JuHe27ZgpJSyQ9JOkZSdsl/VEqz2y7Jc2S9HNJT6Q2/7dUvlzSo6ltX0vTg5Om+/5aKn9U0rKSY92RyndIurE6LSqPpFpJj0v6Xvqc9fa+JOkpSdskdaay6f29jvQ4t/P5RXH65F8AlwINwBPAymrXaxLt+S3gSuDpkrK/AG5Py7cDf56Wbwa+T/FRltcCj6byFmBXep+fludXu21nafNC4Mq03Aw8D6zMcrtT3ZvScj3waGrL14G1qfxe4A/T8r8D7k3La4GvpeWV6Xe+ACxP/xZqq92+s7T7T4CvAt9Ln7Pe3peAthFl0/p7XfUfQoV+kNcBm0o+3wHcUe16TbJNy0YE/Q5gYVpeCOxIy58Hbhm5HXAL8PmS8jO2m+kv4LvAO/PSbmAO8BjF5zHvA+pS+fDvNsVnQlyXluvSdhr5+1663Ux7UXzK3IPAvwC+l+qf2fam+o0W9NP6e52VoZvRHmA+6kPIz2MLIuKVtPwqsCAtj9X28/Znkv6L/laKPdxMtzsNY2wD9gKbKfZOD0ZEf9qktP7DbUvrDwGtnF9t/gzwX4DB9LmVbLcXio8I/4GkrZLWpbJp/b2eMQ8Ht/JFREjK5HWxkpqAbwF/HBGHJQ2vy2K7o/jEtSskzQO+A7yxylWaMpLeBeyNiK2Sbqh2fabR9RGxR9KFwGZJz5WunI7f66z06PPwEPLXJC0ESO97U/lYbT/vfiaS6imG/N9GxLdTcebbDRARB4GHKA5dzJM01Akrrf9w29L6C4D9nD9tfhvwHkkvAfdTHL75K7LbXgAiYk9630vxj/nVTPPvdVaCvpwHmJ/vSh/AfivFMeyh8g+ms/XXAofSfwk3AaskzU9n9FelshlJxa77F4FnI+IvS1Zltt2S2lNPHkmzKZ6TeJZi4L83bTayzUM/i/cCP4rigO1GYG26SmU5sAL4+fS0onwRcUdELI6IZRT/jf4oIt5PRtsLIKlRUvPQMsXfx6eZ7t/rap+oqOAJj5spXqnxC+AT1a7PJNtyH/AK0EdxLO4jFMcmHwReAH4ItKRtBdyT2v0U0FFynA8DO9PrQ9Vu1zhtvp7iWOaTwLb0ujnL7QbeAjye2vw0cGcqv5RicO0EvgEUUvms9HlnWn9pybE+kX4WO4Cbqt22Mtp+A6evuslse1Pbnkiv7UPZNN2/154Cwcws47IydGNmZmNw0JuZZZyD3sws4xz0ZmYZ56A3M8s4B72ZWcY56M3MMu7/A7pff8eejgM4AAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "plt.plot(loss_history)\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Accuracy: accuracy= 0.6160\n" + ] + } + ], + "source": [ + "test(gcn)" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "First model parameters: 144107\n", + "Second model parameters: 72057\n" + ] + } + ], + "source": [ + "# Snippet taken from: https://stackoverflow.com/questions/49201236/check-the-total-number-of-parameters-in-a-pytorch-model\n", + "net_params = sum(p.numel() for p in net.parameters() if p.requires_grad) \n", + "gcn_params = sum(p.numel() for p in gcn.parameters() if p.requires_grad)\n", + "\n", + "print('First model parameters: ', net_params)\n", + "print('Second model parameters: ', gcn_params)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.7" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/pygcn/utils.py b/pygcn/utils.py index 9b53c5b..02db573 100644 --- a/pygcn/utils.py +++ b/pygcn/utils.py @@ -1,6 +1,7 @@ import numpy as np import scipy.sparse as sp import torch +from scipy.linalg import fractional_matrix_power as matrix_frac_power def encode_onehot(labels): @@ -12,7 +13,7 @@ def encode_onehot(labels): return labels_onehot -def load_data(path="../data/cora/", dataset="cora"): +def load_data(path="../data/cora/", dataset="cora", symm_norm = False): """Load citation network dataset (cora only for now)""" print('Loading {} dataset...'.format(dataset)) @@ -32,12 +33,19 @@ def load_data(path="../data/cora/", dataset="cora"): shape=(labels.shape[0], labels.shape[0]), dtype=np.float32) - # build symmetric adjacency matrix - adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj) - + if not symm_norm: + # use the D^-1A formula + adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj) + adj = normalize(adj + sp.eye(adj.shape[0])) + else: + # use the D^-1/2AD^-1/2 formula + adj = adj + sp.eye(adj.shape[0]) # add self loop + D = np.diag(np.array(adj.sum(axis=1)).flatten()) # build degree matrix + D_prime = matrix_frac_power(D,-0.5) + D_prime = sp.coo_matrix(D_prime, shape=(adj.shape[0],adj.shape[0]),dtype=np.float32) # convert to sparse format + adj = D_prime @ adj @ D_prime # compute the normalized symmetric version + features = normalize(features) - adj = normalize(adj + sp.eye(adj.shape[0])) - idx_train = range(140) idx_val = range(200, 500) idx_test = range(500, 1500)