diff --git a/Attack_baselines/Membership_Inference/ml_leaks/ml-leaks_adversary_1-SST.ipynb b/Attack_baselines/Membership_Inference/ml_leaks/ml-leaks_adversary_1-SST.ipynb new file mode 100644 index 0000000..772943a --- /dev/null +++ b/Attack_baselines/Membership_Inference/ml_leaks/ml-leaks_adversary_1-SST.ipynb @@ -0,0 +1,670 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Membership inference on text\n", + "### Stanford Sentiment Treebank (SST) movie review dataset for sentiment analysis" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Python: 3.7.0 (default, Jun 28 2018, 13:15:42) \n", + "[GCC 7.2.0]\n", + "Pytorch: 1.0.0\n" + ] + } + ], + "source": [ + "import torch\n", + "\n", + "from torchtext import data\n", + "from torchtext import datasets \n", + "import sys\n", + "import seaborn as sns\n", + "from sklearn.metrics import roc_curve, auc\n", + "\n", + "sys.path.insert(0, '../../../Utils/')\n", + "\n", + "import matplotlib.pyplot as plt\n", + "%matplotlib inline \n", + "\n", + "import models\n", + "from train import *\n", + "from metrics import * \n", + "\n", + "print(\"Python: %s\" % sys.version)\n", + "print(\"Pytorch: %s\" % torch.__version__)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Load SST using Torchtext" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# To fix the following error: OSError: [E050] Can't find model 'en'. It doesn't seem to be a shortcut link, a Python package or a valid path to a data directory.\n", + "# Run: \n", + "# python -m spacy download en\n", + "\n", + "\n", + "TEXT = data.Field(tokenize='spacy')\n", + "LABEL = data.LabelField(tensor_type=torch.LongTensor)\n", + "\n", + "target, val, shadow = datasets.SST.splits(TEXT, LABEL, root='../../../Datasets/SST_data', fine_grained=True)\n", + "\n", + "target_in, target_out = target.split(split_ratio=0.5)\n", + "shadow_in, shadow_out = shadow.split(split_ratio=0.5)\n", + "\n", + "TEXT.build_vocab(target, max_size=25000, vectors=\"glove.6B.100d\", vectors_cache='../../../Datasets/SST_data/vector_cache')\n", + "LABEL.build_vocab(target)\n", + "\n", + "BATCH_SIZE = 32\n", + "\n", + "shadow_in_itr, shadow_out_itr, target_in_itr, target_out_itr, val_itr = data.BucketIterator.splits(\n", + " (shadow_in, shadow_out, target_in, target_out, val), \n", + " batch_size = BATCH_SIZE, \n", + " sort_key= lambda x: len(x.text), \n", + " repeat=False\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Create bidirectional LSTM model for sentiment analysis" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "vocab_size = len(TEXT.vocab)\n", + "embedding_size = 100\n", + "hidden_size = 256\n", + "output_size = 5\n", + "\n", + "\n", + "target_model = models.RNN(vocab_size, embedding_size, hidden_size, output_size)\n", + "shadow_model = models.RNN(vocab_size, embedding_size, hidden_size, output_size)\n", + "\n", + "pretrained_embeddings = TEXT.vocab.vectors\n", + "target_model.embedding.weight.data.copy_(pretrained_embeddings)\n", + "shadow_model.embedding.weight.data.copy_(pretrained_embeddings)\n", + "print(\"\")\n", + "\n", + "target_optimizer = torch.optim.Adam(target_model.parameters())\n", + 
"shadow_optimizer = torch.optim.Adam(shadow_model.parameters())\n", + "criterion = torch.nn.CrossEntropyLoss()\n", + "\n", + "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", + "\n", + "target_model = target_model.to(device)\n", + "shadow_model = shadow_model.to(device)\n", + "criterion = criterion.to(device)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Utility functions" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "def classification_accuracy(preds, y):\n", + "\n", + " correct = (preds == y).float() #convert into float for division \n", + " acc = correct.sum()/len(correct)\n", + " return acc\n", + "\n", + "def binary_accuracy(preds, y):\n", + "\n", + " rounded_preds = torch.round(preds)\n", + "\n", + " correct = (rounded_preds == y).float() #convert into float for division \n", + " acc = correct.sum()/len(correct)\n", + " return acc" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Evaluation functions" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "def evaluate(model, iterator, criterion):\n", + " \n", + " epoch_loss = 0\n", + " epoch_acc = 0\n", + " \n", + " model.eval()\n", + " \n", + " with torch.no_grad():\n", + " \n", + " for batch in iterator:\n", + "\n", + " predictions = model(batch.text).squeeze(1)\n", + " loss = criterion(predictions, batch.label)\n", + " acc = classification_accuracy(predictions.argmax(dim=1), batch.label)\n", + "\n", + " epoch_loss += loss.item()\n", + " epoch_acc += acc.item()\n", + " \n", + " return epoch_loss / len(iterator), epoch_acc / len(iterator)\n", + "\n", + "def evaluate_inference(target_model, attack_model, in_iterator, out_iterator, criterion, k):\n", + " \n", + " epoch_loss = 0\n", + " epoch_acc = 0\n", + " \n", + " shadow_model.eval()\n", + " attack_model.eval()\n", + " \n", + " predictions = np.array([])\n", + " labels = np.array([])\n", + " \n", + " with torch.no_grad():\n", + " \n", + " for in_batch, out_batch in zip(in_iterator, out_iterator):\n", + "\n", + " in_size = len(in_batch.label)\n", + " out_size = len(out_batch.label)\n", + " in_lbl = torch.ones(in_size).to(device)\n", + " out_lbl = torch.zeros(out_size).to(device)\n", + " \n", + " \n", + " in_predictions = torch.nn.functional.softmax(target_model(in_batch.text.detach()), dim=1).detach()\n", + " out_predictions = torch.nn.functional.softmax(target_model(out_batch.text.detach()), dim=1).detach()\n", + " \n", + " in_sort, _ = torch.sort(in_predictions, descending=True)\n", + " in_top_k = in_sort[:,:k].clone().to(device)\n", + "\n", + " out_sort, _ = torch.sort(out_predictions, descending=True)\n", + " out_top_k = out_sort[:,:k].clone().to(device)\n", + " \n", + " \n", + " in_inference = attack_model(in_top_k).squeeze(1)\n", + " out_inference = attack_model(out_top_k).squeeze(1)\n", + " \n", + " in_probability = torch.nn.functional.sigmoid(in_inference).detach().cpu().numpy()\n", + " out_probability = torch.nn.functional.sigmoid(out_inference).detach().cpu().numpy()\n", + " \n", + " loss = (criterion(in_inference, in_lbl) + criterion(out_inference, out_lbl)) / 2\n", + " acc = (binary_accuracy(in_inference, in_lbl) + binary_accuracy(out_inference, out_lbl)) / 2\n", + " \n", + " predictions = np.concatenate((predictions, in_probability), axis=0)\n", + " labels = np.concatenate((labels, np.ones(in_size)), axis=0)\n", + " predictions = np.concatenate((predictions, 
out_probability), axis=0)\n", + " labels = np.concatenate((labels, np.zeros(out_size)), axis=0)\n", + "\n", + " epoch_loss += loss.item()\n", + " epoch_acc += acc.item()\n", + " \n", + " fpr, tpr, thresholds = roc_curve(labels, predictions, pos_label=1)\n", + " roc_auc = auc(fpr, tpr)\n", + " \n", + " return epoch_loss / len(in_iterator), epoch_acc / len(in_iterator), fpr, tpr, roc_auc" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Train functions" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "def train(model, iterator, optimizer, criterion):\n", + " \n", + " epoch_loss = 0\n", + " epoch_acc = 0\n", + " \n", + " model.train()\n", + " \n", + " for batch in iterator:\n", + " \n", + " optimizer.zero_grad()\n", + " \n", + " predictions = model(batch.text).squeeze(1)\n", + " \n", + " loss = criterion(predictions, batch.label)\n", + "\n", + " acc = classification_accuracy(predictions.argmax(dim=1), batch.label)\n", + " \n", + " loss.backward()\n", + " \n", + " optimizer.step()\n", + " \n", + " epoch_loss += loss.item()\n", + " epoch_acc += acc.item()\n", + " \n", + " return epoch_loss / len(iterator), epoch_acc / len(iterator)\n", + "\n", + "def train_nlp_attack(shadow_model, attack_model, in_iterator, out_iterator, optimizer, criterion, k):\n", + " \n", + " epoch_loss = 0\n", + " epoch_acc = 0\n", + " \n", + " shadow_model.eval()\n", + " attack_model.train()\n", + " \n", + " in_input = np.empty((0,2))\n", + " out_input = np.empty((0,2))\n", + " \n", + " for in_batch, out_batch in zip(in_iterator, out_iterator):\n", + "\n", + " optimizer.zero_grad()\n", + " \n", + " in_predictions = torch.nn.functional.softmax(shadow_model(in_batch.text.detach()), dim=1).detach()\n", + " out_predictions = torch.nn.functional.softmax(shadow_model(out_batch.text.detach()), dim=1).detach()\n", + "\n", + " in_lbl = torch.ones(len(in_batch.label)).to(device)\n", + " out_lbl = torch.zeros(len(out_batch.label)).to(device)\n", + " \n", + " in_sort, _ = torch.sort(in_predictions, descending=True)\n", + " in_top_k = in_sort[:,:k].clone().to(device)\n", + "\n", + " out_sort, _ = torch.sort(out_predictions, descending=True)\n", + " out_top_k = out_sort[:,:k].clone().to(device)\n", + "\n", + " in_inference = attack_model(in_top_k).squeeze(1)\n", + " out_inference = attack_model(out_top_k).squeeze(1)\n", + " \n", + " in_input = np.vstack((in_input, torch.cat((torch.max(in_predictions, dim=1, keepdim=True)[0], in_batch.label.view(-1,1).type(torch.cuda.FloatTensor)), dim=1).cpu().numpy() ))\n", + " out_input = np.vstack((out_input, torch.cat((torch.max(out_predictions, dim=1, keepdim=True)[0], out_batch.label.view(-1,1).type(torch.cuda.FloatTensor)), dim=1).cpu().numpy() ))\n", + " \n", + " loss = (criterion(in_inference, in_lbl) + criterion(out_inference, out_lbl)) / 2\n", + " acc = (binary_accuracy(in_inference, in_lbl) + binary_accuracy(out_inference, out_lbl)) / 2\n", + "\n", + " loss.backward()\n", + " \n", + " optimizer.step()\n", + " \n", + " epoch_loss += loss.item()\n", + " epoch_acc += acc.item()\n", + " \n", + " \n", + " return epoch_loss / len(in_iterator), epoch_acc / len(in_iterator), in_input, out_input" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Train target model" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + 
"/home/ljt/cyphercat/venv/lib/python3.7/site-packages/torchtext/data/field.py:322: UserWarning: volatile was removed and now has no effect. Use `with torch.no_grad():` instead.\n", + " return Variable(arr, volatile=not train)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch: 01, Train Loss: 1.534, Train Acc: 31.37%, Val. Loss: 1.494, Val. Acc: 29.00%\n", + "Epoch: 02, Train Loss: 1.428, Train Acc: 36.94%, Val. Loss: 1.358, Val. Acc: 39.76%\n", + "Epoch: 03, Train Loss: 1.320, Train Acc: 42.89%, Val. Loss: 1.346, Val. Acc: 40.12%\n", + "Epoch: 04, Train Loss: 1.227, Train Acc: 46.99%, Val. Loss: 1.411, Val. Acc: 40.08%\n", + "Epoch: 05, Train Loss: 1.132, Train Acc: 50.07%, Val. Loss: 1.420, Val. Acc: 41.05%\n", + "Epoch: 06, Train Loss: 1.034, Train Acc: 55.13%, Val. Loss: 1.440, Val. Acc: 40.20%\n", + "Epoch: 07, Train Loss: 0.916, Train Acc: 61.05%, Val. Loss: 1.526, Val. Acc: 42.53%\n", + "Epoch: 08, Train Loss: 0.801, Train Acc: 67.23%, Val. Loss: 1.693, Val. Acc: 40.83%\n", + "Epoch: 09, Train Loss: 0.714, Train Acc: 71.64%, Val. Loss: 1.942, Val. Acc: 37.84%\n", + "Epoch: 10, Train Loss: 0.619, Train Acc: 76.31%, Val. Loss: 2.082, Val. Acc: 37.26%\n", + "Epoch: 11, Train Loss: 0.557, Train Acc: 77.94%, Val. Loss: 2.376, Val. Acc: 38.37%\n", + "Epoch: 12, Train Loss: 0.485, Train Acc: 81.09%, Val. Loss: 2.239, Val. Acc: 37.80%\n", + "Epoch: 13, Train Loss: 0.436, Train Acc: 83.21%, Val. Loss: 2.412, Val. Acc: 37.66%\n", + "Epoch: 14, Train Loss: 0.385, Train Acc: 85.56%, Val. Loss: 2.344, Val. Acc: 38.51%\n", + "Epoch: 15, Train Loss: 0.332, Train Acc: 87.48%, Val. Loss: 2.693, Val. Acc: 36.37%\n", + "Epoch: 16, Train Loss: 0.298, Train Acc: 89.16%, Val. Loss: 2.764, Val. Acc: 37.17%\n", + "Epoch: 17, Train Loss: 0.286, Train Acc: 89.53%, Val. Loss: 2.771, Val. Acc: 38.42%\n", + "Epoch: 18, Train Loss: 0.244, Train Acc: 90.88%, Val. Loss: 2.959, Val. Acc: 36.68%\n", + "Epoch: 19, Train Loss: 0.231, Train Acc: 90.88%, Val. Loss: 2.934, Val. Acc: 37.53%\n", + "Epoch: 20, Train Loss: 0.212, Train Acc: 92.14%, Val. Loss: 2.950, Val. Acc: 36.33%\n", + "Epoch: 21, Train Loss: 0.204, Train Acc: 92.72%, Val. Loss: 2.958, Val. Acc: 36.32%\n", + "Epoch: 22, Train Loss: 0.164, Train Acc: 94.05%, Val. Loss: 3.125, Val. Acc: 35.30%\n", + "Epoch: 23, Train Loss: 0.159, Train Acc: 93.98%, Val. Loss: 3.307, Val. Acc: 37.30%\n", + "Epoch: 24, Train Loss: 0.144, Train Acc: 94.68%, Val. Loss: 3.269, Val. Acc: 37.12%\n", + "Epoch: 25, Train Loss: 0.132, Train Acc: 94.85%, Val. Loss: 3.395, Val. Acc: 37.48%\n", + "Epoch: 26, Train Loss: 0.130, Train Acc: 95.31%, Val. Loss: 3.426, Val. Acc: 37.44%\n", + "Epoch: 27, Train Loss: 0.116, Train Acc: 95.85%, Val. Loss: 3.488, Val. Acc: 36.90%\n", + "Epoch: 28, Train Loss: 0.119, Train Acc: 95.69%, Val. Loss: 3.380, Val. Acc: 38.24%\n", + "Epoch: 29, Train Loss: 0.109, Train Acc: 96.39%, Val. Loss: 3.585, Val. Acc: 38.10%\n", + "Epoch: 30, Train Loss: 0.091, Train Acc: 96.67%, Val. Loss: 3.765, Val. Acc: 36.77%\n" + ] + } + ], + "source": [ + "n_epochs_classification = 30\n", + "\n", + "for epoch in range(n_epochs_classification):\n", + "\n", + " train_loss, train_acc = train(target_model, target_in_itr, target_optimizer, criterion)\n", + " valid_loss, valid_acc = evaluate(target_model, val_itr, criterion)\n", + " \n", + " print('Epoch: %02d, Train Loss: %.3f, Train Acc: %.2f%%, Val. Loss: %.3f, Val. 
Acc: %.2f%%' % (epoch+1, train_loss, train_acc*100, valid_loss, valid_acc*100))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Train shadow model" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch: 01, Train Loss: 1.583, Train Acc: 25.24%, Val. Loss: 1.581, Val. Acc: 29.59%\n", + "Epoch: 02, Train Loss: 1.535, Train Acc: 31.43%, Val. Loss: 1.527, Val. Acc: 34.46%\n", + "Epoch: 03, Train Loss: 1.516, Train Acc: 32.16%, Val. Loss: 1.524, Val. Acc: 30.57%\n", + "Epoch: 04, Train Loss: 1.443, Train Acc: 38.14%, Val. Loss: 1.444, Val. Acc: 36.02%\n", + "Epoch: 05, Train Loss: 1.389, Train Acc: 41.27%, Val. Loss: 1.427, Val. Acc: 36.68%\n", + "Epoch: 06, Train Loss: 1.325, Train Acc: 43.11%, Val. Loss: 1.422, Val. Acc: 34.77%\n", + "Epoch: 07, Train Loss: 1.264, Train Acc: 45.48%, Val. Loss: 1.404, Val. Acc: 37.04%\n", + "Epoch: 08, Train Loss: 1.210, Train Acc: 50.19%, Val. Loss: 1.474, Val. Acc: 35.79%\n", + "Epoch: 09, Train Loss: 1.113, Train Acc: 51.00%, Val. Loss: 1.475, Val. Acc: 34.85%\n", + "Epoch: 10, Train Loss: 1.027, Train Acc: 58.37%, Val. Loss: 1.677, Val. Acc: 35.08%\n", + "Epoch: 11, Train Loss: 0.961, Train Acc: 60.06%, Val. Loss: 1.499, Val. Acc: 35.70%\n", + "Epoch: 12, Train Loss: 0.849, Train Acc: 66.19%, Val. Loss: 1.735, Val. Acc: 37.49%\n", + "Epoch: 13, Train Loss: 0.793, Train Acc: 68.87%, Val. Loss: 1.742, Val. Acc: 35.92%\n", + "Epoch: 14, Train Loss: 0.730, Train Acc: 71.60%, Val. Loss: 1.662, Val. Acc: 37.40%\n", + "Epoch: 15, Train Loss: 0.643, Train Acc: 76.29%, Val. Loss: 1.827, Val. Acc: 36.24%\n", + "Epoch: 16, Train Loss: 0.603, Train Acc: 78.06%, Val. Loss: 1.831, Val. Acc: 38.06%\n", + "Epoch: 17, Train Loss: 0.576, Train Acc: 76.81%, Val. Loss: 1.795, Val. Acc: 36.06%\n", + "Epoch: 18, Train Loss: 0.520, Train Acc: 80.04%, Val. Loss: 1.995, Val. Acc: 35.12%\n", + "Epoch: 19, Train Loss: 0.474, Train Acc: 83.31%, Val. Loss: 2.028, Val. Acc: 39.31%\n", + "Epoch: 20, Train Loss: 0.409, Train Acc: 84.67%, Val. Loss: 2.139, Val. Acc: 36.68%\n", + "Epoch: 21, Train Loss: 0.340, Train Acc: 87.61%, Val. Loss: 2.279, Val. Acc: 35.70%\n", + "Epoch: 22, Train Loss: 0.366, Train Acc: 85.84%, Val. Loss: 2.266, Val. Acc: 36.86%\n", + "Epoch: 23, Train Loss: 0.332, Train Acc: 87.98%, Val. Loss: 2.368, Val. Acc: 36.01%\n", + "Epoch: 24, Train Loss: 0.326, Train Acc: 88.69%, Val. Loss: 2.164, Val. Acc: 35.43%\n", + "Epoch: 25, Train Loss: 0.249, Train Acc: 90.66%, Val. Loss: 2.454, Val. Acc: 37.04%\n", + "Epoch: 26, Train Loss: 0.237, Train Acc: 91.28%, Val. Loss: 2.623, Val. Acc: 35.08%\n", + "Epoch: 27, Train Loss: 0.229, Train Acc: 90.99%, Val. Loss: 2.451, Val. Acc: 35.21%\n", + "Epoch: 28, Train Loss: 0.216, Train Acc: 91.99%, Val. Loss: 2.821, Val. Acc: 35.34%\n", + "Epoch: 29, Train Loss: 0.209, Train Acc: 92.69%, Val. Loss: 2.715, Val. Acc: 35.79%\n", + "Epoch: 30, Train Loss: 0.177, Train Acc: 94.20%, Val. Loss: 2.807, Val. Acc: 35.34%\n" + ] + } + ], + "source": [ + "for epoch in range(n_epochs_classification):\n", + "\n", + " train_loss, train_acc = train(shadow_model, shadow_in_itr, shadow_optimizer, criterion)\n", + " valid_loss, valid_acc = evaluate(shadow_model, val_itr, criterion)\n", + " \n", + " print('Epoch: %02d, Train Loss: %.3f, Train Acc: %.2f%%, Val. Loss: %.3f, Val. 
Acc: %.2f%%' % (epoch+1, train_loss, train_acc*100, valid_loss, valid_acc*100))\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Evaluate models on test set" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Target net test accuracy: 0.37 , Shadow net test accuracy: 0.35\n" + ] + } + ], + "source": [ + "target_test_loss, target_test_acc = evaluate(target_model, val_itr, criterion)\n", + "shadow_test_loss, shadow_test_acc = evaluate(shadow_model, val_itr, criterion)\n", + "\n", + "print('Target net test accuracy: %.2f , Shadow net test accuracy: %.2f' % (target_test_acc, shadow_test_acc))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Create ML-leaks adversary 1 model (multi layer perceptron)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "attack_model = models.mlp(3,1,32).to(device)\n", + "\n", + "attack_loss = torch.nn.BCELoss()\n", + "attack_optim = torch.optim.Adam(attack_model.parameters(), lr=0.01)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Train attack model" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/ljt/cyphercat/venv/lib/python3.7/site-packages/torch/nn/functional.py:1332: UserWarning: nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.\n", + " warnings.warn(\"nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.\")\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch: 01, Train Loss: 0.572, Train Acc: 71.05%, Val. Loss: 0.488, Val. Acc: 76.08%\n", + "Epoch: 02, Train Loss: 0.501, Train Acc: 77.53%, Val. Loss: 0.480, Val. Acc: 77.65%\n", + "Epoch: 03, Train Loss: 0.504, Train Acc: 77.65%, Val. Loss: 0.490, Val. Acc: 78.70%\n", + "Epoch: 04, Train Loss: 0.494, Train Acc: 77.71%, Val. Loss: 0.472, Val. Acc: 77.88%\n", + "Epoch: 05, Train Loss: 0.487, Train Acc: 78.25%, Val. Loss: 0.467, Val. Acc: 78.70%\n", + "Epoch: 06, Train Loss: 0.491, Train Acc: 78.27%, Val. Loss: 0.476, Val. Acc: 78.49%\n", + "Epoch: 07, Train Loss: 0.486, Train Acc: 78.20%, Val. Loss: 0.467, Val. Acc: 79.91%\n", + "Epoch: 08, Train Loss: 0.489, Train Acc: 78.12%, Val. Loss: 0.466, Val. Acc: 78.52%\n", + "Epoch: 09, Train Loss: 0.482, Train Acc: 78.34%, Val. Loss: 0.471, Val. Acc: 78.66%\n", + "Epoch: 10, Train Loss: 0.484, Train Acc: 78.38%, Val. Loss: 0.477, Val. Acc: 78.59%\n", + "Epoch: 11, Train Loss: 0.482, Train Acc: 78.74%, Val. Loss: 0.469, Val. Acc: 78.85%\n", + "Epoch: 12, Train Loss: 0.480, Train Acc: 79.07%, Val. Loss: 0.479, Val. Acc: 78.65%\n", + "Epoch: 13, Train Loss: 0.482, Train Acc: 78.56%, Val. Loss: 0.473, Val. Acc: 78.66%\n", + "Epoch: 14, Train Loss: 0.478, Train Acc: 78.75%, Val. Loss: 0.482, Val. Acc: 78.64%\n", + "Epoch: 15, Train Loss: 0.481, Train Acc: 78.79%, Val. Loss: 0.490, Val. Acc: 78.40%\n", + "Epoch: 16, Train Loss: 0.478, Train Acc: 78.91%, Val. Loss: 0.488, Val. Acc: 78.44%\n", + "Epoch: 17, Train Loss: 0.474, Train Acc: 78.84%, Val. Loss: 0.482, Val. Acc: 78.65%\n", + "Epoch: 18, Train Loss: 0.477, Train Acc: 78.78%, Val. Loss: 0.494, Val. Acc: 78.42%\n", + "Epoch: 19, Train Loss: 0.478, Train Acc: 78.72%, Val. Loss: 0.476, Val. Acc: 78.85%\n", + "Epoch: 20, Train Loss: 0.477, Train Acc: 79.09%, Val. Loss: 0.476, Val. 
Acc: 78.68%\n", + "Epoch: 21, Train Loss: 0.477, Train Acc: 78.84%, Val. Loss: 0.499, Val. Acc: 78.32%\n", + "Epoch: 22, Train Loss: 0.477, Train Acc: 79.24%, Val. Loss: 0.499, Val. Acc: 78.24%\n", + "Epoch: 23, Train Loss: 0.475, Train Acc: 79.14%, Val. Loss: 0.512, Val. Acc: 78.16%\n", + "Epoch: 24, Train Loss: 0.477, Train Acc: 78.78%, Val. Loss: 0.512, Val. Acc: 78.17%\n", + "Epoch: 25, Train Loss: 0.477, Train Acc: 78.78%, Val. Loss: 0.513, Val. Acc: 78.14%\n", + "Epoch: 26, Train Loss: 0.478, Train Acc: 78.71%, Val. Loss: 0.493, Val. Acc: 78.42%\n", + "Epoch: 27, Train Loss: 0.473, Train Acc: 79.00%, Val. Loss: 0.500, Val. Acc: 78.36%\n", + "Epoch: 28, Train Loss: 0.475, Train Acc: 78.88%, Val. Loss: 0.512, Val. Acc: 78.10%\n", + "Epoch: 29, Train Loss: 0.477, Train Acc: 78.88%, Val. Loss: 0.505, Val. Acc: 78.33%\n", + "Epoch: 30, Train Loss: 0.472, Train Acc: 78.46%, Val. Loss: 0.515, Val. Acc: 78.00%\n" + ] + } + ], + "source": [ + "n_epochs_attack = 30\n", + "\n", + "for epoch in range(n_epochs_attack):\n", + "\n", + " train_loss, train_acc, in_input, out_input = train_nlp_attack(shadow_model, attack_model, shadow_in_itr, shadow_out_itr, attack_optim, attack_loss, 3)\n", + " valid_loss, valid_acc, fpr, tpr, roc_auc = evaluate_inference(target_model, attack_model, target_in_itr, target_out_itr, attack_loss, 3)\n", + "\n", + " print('Epoch: %02d, Train Loss: %.3f, Train Acc: %.2f%%, Val. Loss: %.3f, Val. Acc: %.2f%%' % (epoch+1, train_loss, train_acc*100, valid_loss, valid_acc*100))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Attack Results" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYoAAAEWCAYAAAB42tAoAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzs3Xd4VGX2wPHvSW+0JDTpTQhNQKQIAqIgir0slsXGT0XsBVFRcVF3FRVFaequshbWXduqCLiABRWRovQmSAu9pdeZOb8/7mQIIYQAmUwmOZ/nmYdb3rn3zGS45977vvd9RVUxxhhjjiUk0AEYY4yp2CxRGGOMKZElCmOMMSWyRGGMMaZEliiMMcaUyBKFMcaYElmiMH4nIv1EJLmc9rVFRM4/xrpzRGR9OcVxp4jsEZEMEUkoj30a4y+WKCoIEektIgtEJFVEDorITyJylnddhIi8LCLJ3gPPFhF51bsuo9DLIyLZheZvKGY/00REReSyIstf8S6/uVw+cACo6g+q2vpk3isiN4vIj6UsGw6MBwaqapyqHjiZfQaa93dW8Hva7f3txBUpc7aIfCMi6d7f7pci0rZImeoi8qqIbPNua5N3PrF8P5E5WZYoKgARqQ7MAF4H4oEGwF+AXG+Rx4CuQDegGtAP+BXAeyCKU9U4YBtwSaFlHxxjlxuAGwvtPwz4E7CpjD9amfPGWtHVBaKA1SfzZhEJLdtwTskl3t9WJ6Azzm8RABHpCfwP+Bw4DWgGLAd+EpHm3jIRwDygHTAIqA70BA7g/J79Ikh+J0HDEkXFcDqAqv5LVd2qmq2q/1PVFd71ZwGfqepOdWxR1XdPYX9fAr1FpJZ3fhCwAthduJCI3Coia0XkkIh8LSJNCq1TERkhIr97zyafEZEW3quiNBH5j/cgUXh7j4vIfu+Z6g2FlkeKyEveM849IjJVRKK96/p5r6RGichu4B0RSRSRGSKS4r36+kFECv+WO4nICu8Z7r9FJKrwtgrtd4uIPCYia7yf8Z2Cssfjfe/DRfcjIqcDBbe3UkTkG2/5NiIyxxvvehH5U6FtTRORKSIyU0QygXNL+Z08JCJ7RWSXiNxSaHvR3ivQrd7Yfiz03h7ev1GKiCwXkX6l+byquhv4GidhFBgHvKuqE1Q1XVUPquoTwELgaW+ZG4HGwBWqukZVPaq6V1WfUdWZx/hu2xX6rvaIyOOFvqdnC5Ur7u85SkRWAJne6Y+LbHuCiLzmna4hIv/wfn87ROTZCpakKwxLFBXDBsAtIv8UkQsLHcALLAQe9B6YO4iInOL+cnDOAq/1zt8IHJF4xLk19ThwJVAb+AH4V5HtXACcCfQAHgHeBP4MNALaA9cVKlsPSMS5WroJeFNECm4DPY+TLDsBLb1lniry3nigCXA78BCQ7I2rrjfOwn3R/Akn+TUDOgI3l/Bd3OD9HC28MTxRQtmijtqPqm7AOXsGqKmq/UUkFpgDTAfq4Hzvk+XIWzTXA8/hXDH+SOm+kxre5cOASYV+Ny/h/F3OxvneHgE8ItIA+Ap41rv8YeATEal9vA8qIg2BC4GN3vkY7/Y/Kqb4f4AB3unzgdmqmnG8fXi3Ww2YC8zGuUppiXNFUlrXAYOBmsCHwEXebRZcqf0J5+8AMA1weffRGRgI/N8J7KvqUFV7VYAXkITzw03G+fF+AdT1rgsF7gJ+wrkdtRO4qZhtbAHOP85+puEcKHoDP+P8h9oDROMcoG72lpsFDCv0vhAgC2jinVegV6H1S4FRheZfBl71TvfzfqbYQuv/AzwJCJAJtCi0ri
ewudB784CoQuvH4iS6lsf4Dv5caH4cMLXQtpKLlB1eaP4iYNMxvrebgR9LuZ+m3u8nzDs/BPihyPbeAMYU+pu8W2hdab6T7ILte5ftxUnYId51ZxTzGUYB7xVZ9nVxv6VCnzEDSPd+nnk4yQ+goXdZm2LeNwjI907PAZ4/gf8H1wG/lfTbLTRf3N/z1iLv+RG40Ts9oODvi3OCkQtEF9n3t6f6f7kyvuyKooJQ1bWqerOqNsQ5Gz8NeNW7zq2qk1S1F86B/TngbRFJOoX9/YhzRj4amKGq2UWKNAEmeG9RpAAHcQ5gDQqV2VNoOruY+cIVn4dUNbPQ/FbvZ6wNxABLC+1rtnd5gX2qmlNo/kWcM9v/icgfIvJokdgL30LLKhJHUduLiam0SrufJkD3gs/n/Yw34FwVFBdHab6TA6rqKmb/iTj1I8XVNzUBrikSR2+gfgmf8XJVLagXa+PdPsAhwHOM99YH9hfEeZztF9XoGLGX1vYi89M5fGV7PYevJpoA4cCuQt/FGzhXfKYISxQVkKquwzl7al/MumxVnYTzH7Vt0fUn6H2c2zjF1XdsB+5Q1ZqFXtGquuAk91XLewumQGOcK6P9OEmlXaH91FCnArXAEV0cq3M//CFVbQ5cinNb7ryTjKtRMTGVte3A90W+yzhVvbNQmcKfsTTfybHsx7m12OIYcbxXJI5YVX3+eBtV1e9xfpMveeczca5Irymm+J84fLtoLnBBkb99SbYDzY+xLhMngRaoV0yZot1hfwT08946u4LDiWI7zhVFYqHvorqqtsMcxRJFBeCt6HzI+2NGRBrhnAUt9M7f7624ixaRMBG5Cede9m+nuOvXcC7H5xezbirwmIi088ZQQ0SKOyiciL+I09T3HOBi4CNV9QBvAa+ISB3vvhqIyAXH2oiIXCwiLb11NamAG+fs9mTcJSINRSQe5+rq3ye5nZLMAE4XkaEiEu59nXWsK8KT+U6KvPdtYLyInCYioSLSU0QicU4MLhGRC7zLo7y/q4al/ByvAgNE5Azv/KPATSJyr4hUE5Fa3srmnjit9gDewzkof+L9nYeISII4DRsuOsZ3Vd/7m4/0bre7d90ynDqHeBGpB9xfiu9jH/Ad8A7Orbu13uW7cFpsvSxO890QcRpj9C3ld1GlWKKoGNKB7sAv4rR6WQiswjnbB+e2wss4tzr249RXXKWqf5zKTtVppTJPVY8alERVPwNeAD4UkTRvPBeewu5241wF7QQ+wKkbWOddNwrnVtJC777mAiU979DKWyYD56x2sqp+e5JxTcc5YPyBc8vj2ZKLnzhVTcepKL0W5/PvxvluI0t424l+J4U9DKwEFuPcMnwBCFHV7UBBI4V9OAfwkZTyOOA96L6Lt1Lde/vyApwGD7twbt11Bnqr6u/eMrk4FdrrcOor0oBFOLewfilmH+k4Jy+X4HxPvwPnele/h9P8dgvO36y0SX26N4bpRZbfCEQAa3B+mx9zYrfJqgwp5hhhTJUgIluA/1PVuYGOxZiKzK4ojDHGlMgShTHGmBLZrSdjjDElsisKY4wxJQq6jrMSExO1adOmgQ7DGGOCytKlS/er6nG7aylO0CWKpk2bsmTJkkCHYYwxQUVEtp7se+3WkzHGmBJZojDGGFMiSxTGGGNKZInCGGNMiSxRGGOMKZElCmOMMSXyW6IQkbfFGc931THWi4i8JiIbxRl3uIu/YjHGGHPy/PkcxTRgIsUPigNOl9WtvK/uwBTvv8YYUyWoKh4Ft0fxaMHLmVdV3B4lK89NvtuDRw+XLyirvmm88870/vRcoiJCffvJz3OfUpx+SxSqOl9EmpZQ5DKccYIVp8/9miJS3zugiDGmEil8QHR7FLcqbreSmp1PjsvN/oxc8t2K2+PB5XbKuDzKzpRswkNDUHDWeRSPRzmYmc/+jFxqxYTj8hwu7/Yov247RLPEWO/BFt/+Cg68boVNezOICg8hLtI5BCqgCuodIE/VeRXEfswyHC4HWmhZkfcUmsdbJiO38Ei2/nPo2+3k7ck6pW0E8snsBhw5vm2yd9lRiUJEbgduB2jcuHG5BGdMVaKq5Lo85OS7OZCZR77bOWDnuz2+A/betFznQK3O8kOZecxbt5f42Ag8qrg94ClIAt6Ddmaui9TsfFye8u18dOuB4x8YM3Jhf0ZeOURzfGEhQkiIECpCiOBMhwghIuS5PGTkumheO5aQgvUiSKHpEAERQbzzmbkuMvNcNE+MY9O+XH7+tdgagNLHV0af069U9U3gTYCuXbtad7fGHIPL7WHbwSzW7Epjw+50NuzJICYilHyP4nJ7yHN5WLMrjRARXJ7DySAt5+TPbpMPZZeqnIj3gCjiOzCm57ioHhVGo/gYNu3LoEfzBMJCQggLEUJDnXJbDmTRuVHNI5aFhoSQk+8mKjyU2tUivcvE929Ovpu61aN8B9sQKTjw4jsge1SpER2OM6IuiDdGQXzxFvxbUhkpWIGz3LcM78G7yHuQw++LDg8lLLRsq4rXrNnHr7/u4s9/7giA3nIWW5/sT7NmY056m4FMFDs4cmD7ht5lxlRqqkpatosdKdmk5+SzOy2H3HwPbu+Z+vrd6VSPDmfzvkxyXW7CQ0NIPpTNgcxcEmIjffennfva+KbdHmVXas4pxVYtKoycfDciQsvacYSHOgdYgOx8D/Gx4STVq05MRCjhoSGEhAhR4aF0bVKLUO9B2vfyHpwT4yKJDHPKGv/Jysrn2Wfn8+KLCwgNFXr0aEjLlvGICE2b1jylbQcyUXwB3C0iH+JUYqda/YQJRm6Psv1gFnvTc/l+w17CQkLIcblZsuUQiXER/LL5IPVrRLN2VxpxkWGndG96T1puqcpFh4fSLDGWFnXiaJoQQ8s6cc6ZeqgQHirUrR5FQmykMx8SQmioEBMeagfzIDVr1u/cdddMNm9OAWDYsDNJSIgus+37LVGIyL+AfkCiiCQDY4BwAFWdCswELsIZQD4LuMVfsRhzMjJyXazYnkJ2vptcl4eNezNIyconIzefbQez2Lw/k1yXh5Ss/ONuq6BM0SQRItCrZSJ703JpU78a1aLCiAgNZX9GLh0b1sDlUerXiKJ6dDgoxMdGEB4actRtlBARQkIgNESoWy3KDvhVxI4dadx//9d8/PEaADp2rMvUqYPp2bPRcd55YvzZ6um646xX4C5/7d+Y0sjMdbF06yH2peeS6/Lw+950UrLymbd2z0ndt+/VMoE/9mUytGcTosNDcXuUZomxVIsKJyEughrR4dSIDie8jO9Lm6rprrtm8vnn64mJCWfs2H7cd18PwsLK/rcVFJXZxpSFrDwXXy7fycdLkwkNEX7dlkKey3Pc9yXERtCtWTwRYSEcysqnZe04GsdHExcVTtOEGJolxlIrJsLO4k25cLk8vmTwwgvnEx4eyssvD6Rx4xp+26clClNp7U3L4evVu5m5cjc7U7NLbDKZGBdB39PrUC0qjJSsPLo1S6BDgxp0aOi//3zGnIjU1ByeeOIbNmw4yOzZNyAitG6dyEcfXeP3fVuiMJXCnrQc0rLz+d+aP
Szdeohv1u09ZtneLRO5onMDWtSJ47SaUdSpFlWOkRpzYlSVjz5aw/33z2bXrgxCQ4Vly3bTuXP9covBEoUJKi63h69X7+HHjfvJyHWxdMtBdh6nSehZTWtxdotErj6zIY3iY8opUmNO3aZNB7n77lnMnr0RgJ49GzJ16sV07Fi3XOOwRGEqvHy3h2k/bWH6om1s3p95zHJxkWE0rBVNnsvDzb2a0r9NHRrWssRggtNLLy3gySe/JSfHRc2aUbzwwvn83/91CUhdmCUKU+Hkuz2s2pHKU5+vZuWO1GLLdGpUky6Na3FGoxqc0bAmTRJifE/PGlMZZGXlk5PjYujQjrz00kDq1IkNWCyWKEyFsXZXGpe8/uMx+wVqUDOaxy5qw/lJdYkKDy22jDHBat++TNavP0Dv3k5/dqNG9aJfv6b06dMkwJFZojAB4vEov2w+yNy1e/h69e5i+wtqmhDD0J5NubzTaSTERQYgSmP8z+NR3n77Nx55ZA5hYSGsW3c38fHRREaGVYgkAZYoTDlRVVbtSCP5UBZ3fvBriWXfHHomA9vVK6fIjAmcVav2Mnz4DH76yelIe8CA5mRl5RMfX3bdb5QFSxTGL9Jz8pmzZg9Lth7is193kJ1/7IFTbj67Kc1rx3LZGQ2oERNejlEaExiZmXmMHfs948cvxOXyULduLK++OoghQ9pVyLo2SxSmzKRm53Pbu0tYtPlgieUuOeM0PB5l/JAziAyzugZT9Vx99UfMnr0RERgxoivPPXceNWtW3Od5LFGYk6aqfLQkmc0HMpny3aZiy9SrHkW/1rUZclYjOjWqWSHPlowpb6NG9WLPngymTBlM9+4NAx3OcVmiMCfE7VEmfbuRfy7YwoHM4kcHa5YYyyd3nk18bEQ5R2dMxeNyeXj99V/YsiWFCRMuBKBfv6YsWXJ70PQPZonCHJfHo6zamcr4ORv4bv2+o9ZHhIZw3/mt6NYsnq5NatlVgzFeixbt4I47ZrBs2W4Abr/9TNq1qwMQNEkCLFGYEqzakcrbP27m09+OHniwQ4MavHhNR06vUy2ofvDGlIeUlBwef3weU6cuQRWaNKnBxIkX+ZJEsLFEYY7g9ihfr97NiGKasPZvU4dzW9fmzz2a2FWDMcfw4YeruP/+2ezZk0lYWAgPPdSTJ5/sQ2wQ34q1RGF8/jZrLW98/8cRy+JjI7iwfT1GXtCamjHB+0M3prz873+b2LMnk169GjFlymA6dCjfDvz8wRKF4cNF23j005VHLGt3WnXu6d+SQe3LrytjY4JRbq6LHTvSad68FgDjxg3gnHMac9NNnSrNbVlLFFXYP37czDMz1hyxLD42gq/v70PtatZlhjHH8803m7nzzq8ICRGWLx9OREQoiYkx3HJL50CHVqYsUVQhqsqSrYf4+w9/8PXqPUesqxkTzrOXt+fijqcFKDpjgseePRk8/PAc3n9/BQBt2iSSnJzmu6qobCxRVBG/bTvEFZMXHLX89LpxTPnzmbSoHReAqIwJLh6P8tZbS3n00XmkpOQQFRXGE0+cw8iRvYiIqLy9DFiiqAL++9sO7v/3Mt98VHgIvVokMurCNpxet1oAIzMmuFxxxb/54ov1AFxwQQsmTbqIFi3iAxyV/1miqOS+WL7ziCQx/bbunN0iMYARGRO8rryyDYsW7WDChEFcc03bKtNM3BJFJaWqvDH/D56ftc637NuH+9EsMXCjZBkTbL74Yj3JyWmMGHEWADfeeAZXXplEtSrW2MMSRSVUXHPXabecZUnCmFLati2Ve++dxeefrycyMpRBg1rSvLnTPU1VSxJgiaJSUVWGvLGQRVsOd/Ndu1ok7w3rRpt61QMYmTHBIT/fzWuv/cKYMd+RmZlPtWoRPPtsf5o0qRHo0ALKEkUl4fYoLR6fecSyXx4/j7rVK24f98ZUJAsXJnPHHTNYscJpOn7NNW155ZULaNDATrIsUVQC/1u9m9vfW+qbv6zTabzyp8rzVKgx5eHJJ79lxYo9NGtWk4kTL+Kii1oFOqQKwxJFEFu3O41r31xISla+b9l5beow4drK9VSoMf6gqqSn51G9ulPnMHHihbz77nJGj+5DjA3JewRLFEFq4950Br36g28+RODTEb3o1KhmAKMyJjisX7+fESNmIgJz5gxFRGjdOpHnnjsv0KFVSJYogtDt7y7hf2sOd8Ex/k9ncGWXij+cojGBlpPj4m9/+4Hnn/+JvDw3CQnRbNmSQrNmlbPrjbJiiSLIFE0SX97dmw4Nq3aLDGNKY86cTYwYMZONG51Wgbfe2olx4waQkBAT4MgqPr8mChEZBEwAQoG/q+rzRdY3Bv4J1PSWeVRVZx61IYOqMuaL1UckiZ8e7U+DmtEBjMqYik9VGTbsC955x+mhoG3b2kydOphzzmkS4MiCh98ShYiEApOAAUAysFhEvlDVwv1aPwH8R1WniEhbYCbQ1F8xBav5G/Zx49uLjlj2x18vslZNxpSCiNC0aU2io8N46qm+PPhgz0rdgZ8/+POKohuwUVX/ABCRD4HLgMKJQoGCRso1gJ1+jCcoTftpM09/eeSYEQse7W9JwpgSLFu2m1270rnwQqeJ66hRvRg6tKPVRZwkfyaKBsD2QvPJQPciZZ4G/ici9wCxwPnFbUhEbgduB2jcuHGZB1pRzVix84gkMe+hvtYduDElSE/PZcyY75gw4RcSEqJZt+5u4uOjiYwMsyRxCkICvP/rgGmq2hC4CHhPRI6KSVXfVNWuqtq1du3a5R5kIGTkurh7+m+++fkjz7UkYcwxqCqffbaWtm0n88orCwG4/voOhIcH+hBXOfjzimIH0KjQfEPvssKGAYMAVPVnEYkCEoG9foyrQst3e+j34nfsSMn2LbOuOIw5tq1bU7j77lnMmLEBgK5dT+ONNy6mSxcb772s+DNRLAZaiUgznARxLXB9kTLbgPOAaSKSBEQB+/wYU4XXavSsI+afGJxkScKYY1BVrrrqPyxduovq1SP561/7M3x4V0JD7UqiLPktUaiqS0TuBr7Gafr6tqquFpGxwBJV/QJ4CHhLRB7Aqdi+WVXVXzFVZC63h6umHB6qtHF8DN+P7FdlBkYx5kR4PEpIiCAivPTSQKZOXcIrr1xA/fo2YqM/SLAdl7t27apLliwJdBhlSlVp/eRs8lweANrUq8bs+/sEOCpjKp4DB7J49NG5ALz11qUBjia4iMhSVe16Mu+167MAc3uUv85c60sSfU6vzaz7zglwVMZULKrKP/+5jDZtJvH3v//Gu++uIDk5LdBhVRnWhUeADf3HLyzYdACAalFhvHtrtwBHZEzFsnbtPu688yu+/34rAP36NWXKlME0bGjjRJQXSxQB9Pq8331JonpUGF/da1cSxhRQVZ566lteeOEn8vM9JCbG8PLLAxk6tKPV3ZUzSxQBkpKVx8tzNvjmf31yAGHWUsMYHxFhx4508vM93HZbF55//nzi461vs0CwRBEAeS4PncbO8c0vHn2+JQljgJ0709m/P4uOHesCMG7cAIYN60yvXlWnR4aKyI5OAfD0l6t901Nu
6ELtapEBjMaYwHO7PUycuIikpElce+3H5OW5AUhMjLEkUQHYFUU523ogk+m/bAOga5NaXNjBnh41Vduvv+7ijjtmsGSJ0ydonz5NSEvLJTHRxomoKEqVKEQkAmisqhv9HE+ltu1AFn1f/M43/+4wa+Fkqq60tFyefPIbJk5cjMejNGxYnddeG8Tll7exyuoK5riJQkQGA+OBCKCZiHQCxqjqFf4OrjLxeJQ+L37rm590fRdiIuyCzlRNqkqfPu+wfPkeQkOFBx/swdNP96Oa3YatkEpTRzEWp3vwFABVXQa09GdQlVHzxw8P3Df9tu4M7mi3nEzVJSI88EAPunVrwJIlt/PyyxdYkqjASnNKm6+qKUUuBYOr348AKzw6XbvTqnN2i8QARmNM+cvLczN+/M+EhgojR/YC4MYbz+DPf+5oHfgFgdIkirUi8icgxNsT7L3AQv+GVXks2LSf+RsOd4hrD9WZquaHH7YyfPhXrFmzj8jIUG688Qzq1o1DRAgNtbqIYFCaVH43cCbgAT4FcoH7/BlUZbEiOYXr3/rFN7927KAARmNM+dq/P4tbb/2cPn2msWbNPlq1imfGjOupW9cG4Ao2pbmiuEBVRwGjChaIyJU4ScMcw970HC6d+JNv/rMRZxNtA7qbKkBVmTZtGSNHzuHAgWwiIkJ57LHePPpob6KirAFHMCrNX+0Jjk4Ko4tZZgrp9tw83/TcB/vSso6dRZmq4/33V3LgQDb9+zdj8uSLaN3a6uWC2TEThYhcgDNMaQMRGV9oVXWc21DmGDJyXb7pUYPaWJIwlV5WVj6pqTnUr18NEWHy5ItYvHgnN9zQwZ6JqARKuqLYC6wCcoDVhZanA4/6M6hglu/20H7M1775O/o0D2A0xvjfrFm/c9ddM2nevBZz5gxFRGjdOtGuIiqRYyYKVf0N+E1EPlDVnHKMKWgt357CZZMO10tc160xISF2NmUqpx070rj//q/5+OM1AFSrFsmBA9nW9UYlVJo6igYi8hzQFogqWKiqp/stqiB1w98Pt3AaNagNd/ZrEcBojPEPt9vDpEmLeeKJb0hPzyM2NpyxY8/l3nu7ExZmz0RURqVJFNOAZ4GXgAuBW7AH7o4ybNpiX93ElBu6WGd/plLyeJS+fafx00/bAbj88jZMmDCIxo1rBDgy40+lSf8xqvo1gKpuUtUncBKG8UrJymPeur2++UHt6wUwGmP8JyREGDiwBY0aVefzz6/ls8+GWJKoAkpzRZErIiHAJhEZDuwAqvk3rOAy8uMVvul1zwyyVh6m0lBV/vOf1YSFhXDVVW0BGDWqFw8+2JO4uIgAR2fKS2kSxQNALE7XHc8BNYBb/RlUMNmdmsOcNXsAuLJzA6LC7aE6Uzls2nSQESNm8r//baJ27Rj6929GrVrRREaGEWn991Upx00UqlpQQ5sODAUQkQb+DCpY5LrcnPvSd775F685I3DBGFNGcnNdvPjiAp577gdyclzUqhXFc8/1p0aNqOO/2VRKJSYKETkLaAD8qKr7RaQdTlce/YGG5RBfhXbey9+Tne8M2Tjlhi6EWlNYE+S++24Ld975FevW7Qdg6NCOvPTSQOrUiQ1wZCaQjlmZLSJ/Az4AbgBmi8jTwLfAcqDKN439v38uIflQNgD/vLWbtXIyQc/t9jBihJMkWrdO4JtvbuTdd6+wJGFKvKK4DDhDVbNFJB7YDnRQ1T/KJ7SKa/aq3cxd69RLnNemDn1Prx3giIw5OR6PkpPjIiYmnNDQEKZMGcz8+Vt55JFeREZaB37GUdIvIUdVswFU9aCIbLAk4Xh17gbf9N9v6hrASIw5eStX7mH48K9o0yaBf/zjMgD69m1K375NAxuYqXBKShTNRaSgh1jBGS/b12Osql7p18gqsHW70wF47EIbBN4En8zMPMaO/Z7x4xficnnYvPkQhw5lU6tWdKBDMxVUSYniqiLzE/0ZSLBYtSPVN32R1UuYIPPll+u5++5ZbNuWigiMGNGV5547j5o1rUWTObaSOgWcd6x1VdmdHyz1TTeKt87PTHBwuTwMGfIxn366FoBOnerxxhsX062btXQ3x2e1VSfgw0Xb2H7Qael033mtAhyNMaUXFhZCjRqRxMVF8Mwz53L33d2sAz9Tan79pYjIIBFZLyIbRaTYMSxE5E8iskZEVovIdH/GcypUlUc/Xeke4bZAAAAgAElEQVSbv9cShangfvklmV9+SfbNv/jiANauvYv77+9hScKckFJfUYhIpKrmnkD5UGASMABIBhaLyBequqZQmVbAY0AvVT0kInVKH3r5+suXvrBZ+Nh59nCdqbBSUnJ47LG5vPHGUtq0SWTZsuFERISSkGC3Ss3JOe5phYh0E5GVwO/e+TNE5PVSbLsbsFFV/1DVPOBDnGczCrsNmKSqhwBUdS8VkMvtYdqCLQBEh4dSz7oyMBWQqjJ9+kratJnI1KlLCQ0N4dJLW+N228jF5tSU5oriNeBi4L8AqrpcRM4txfsa4DykVyAZ6F6kzOkAIvITEAo8raqzS7HtcvXlip2+6VV/uSCAkRhTvN9/P8CIETOZO9d51KlXr0ZMnXox7dtX2It0E0RKkyhCVHVrkecF3GW4/1ZAP5y+o+aLSAdVTSlcSERuB24HaNy4cRntunRUlQf+vRyAetWj7JaTqXDy89307/8uyclpxMdHM27c+dxyS2cbhteUmdIkiu0i0g1Qb73DPcCG47wHnHErGhWab+hdVlgy8Iuq5gObRWQDTuJYXLiQqr4JvAnQtWvXch1dr9ljM33TYy9rV567NqZEqoqIEB4eynPP9efbb7cwbtz51K5tfTOZslWapg93Ag8CjYE9QA/vsuNZDLQSkWYiEgFcC3xRpMx/ca4mEJFEnFtRFaabkPcXbvVNx0aEMrCdjVxnAm/PngyGDv2MZ5+d71t2441n8M47l1mSMH5RmisKl6pee6IbVlWXiNwNfI1T//C2qq4WkbHAElX9wrtuoIiswbmdNVJVD5zovvylYEAigNVjBwUwEmOcDvzeemspjz46j5SUHGrWjOL++3tQrZqNImT8qzSJYrGIrAf+DXyqquml3biqzgRmFln2VKFpxblaebC02ywv+W4P32/YB8CcB/oEOBpT1S1fvpvhw79i4ULnuYhBg1oyadJFliRMuSjNCHctRORsnFtHfxGRZcCHqvqh36MLoNOfmOWbblknLoCRmKosP9/NY4/N49VXF+J2K/XrxzFhwiCuvrqtdUhpyk2pHs9U1QWqei/QBUjDGdCo0srOc6PeKvNeLRPsP6QJmLCwEH77bTcej3LPPd1Yu/Yurrmmnf0mTbk67hWFiMThPCh3LZAEfA6c7ee4AqrXC9/4pt8fVvTRD2P8a9u2VNxuD82a1UJEmDp1MKmpuXTtelqgQzNVVGnqKFYBXwLjVPUHP8cTcL9uO8TBzDwAbunV1M7cTLnJz3czYcIvjBnzHT17NmTOnKGICK1aJQQ6NFPFlSZRNFfVKtMHwF0f/OqbHnOJPTdhysfPP29n+PCvWLHCaWkXHx9NVlY+sbERAY7
MmBIShYi8rKoPAZ+IyFEPuVXGEe7ScvLZlZoDwNs32xCnxv8OHcrm0Ufn8uabzglKs2Y1mTTpIi680HonNhVHSVcU//b+W2VGtrv3X7/5ps9tbX3kGP/KzXXRqdMbbNuWSnh4CCNHns3o0X2IiQkPdGjGHKGkEe4WeSeTVPWIZOF9kK7SjYBXUDdxccf6Vjdh/C4yMoxhwzozb95mpkwZTNu2tQMdkjHFKk3z2FuLWTasrAMJtKw8FyuSnfGwr+na6DiljTlxOTkuxoz5lunTDw+A9fjj5/DddzdZkjAVWkl1FENwmsQ2E5FPC62qBqQU/67g9d7Ph/t16tE8PoCRmMpozpxNjBgxk40bD1KnTixXXNGG6OhwG2nOBIWS6igWAQdwen2dVGh5OvBbse8IYn+btQ6AC9vXIzIsNMDRmMpi9+4MHnzwa/71r1UAtGtXm6lTLyY62uohTPAoqY5iM7AZmFt+4QTG16t3+6aHnGW3ncypc7s9vPHGUh5/fB6pqblER4cxZkxfHnigJxERdiJigktJt56+V9W+InIIKNw8VnD686s092fueG+pb7qftXYyZcDtVl5/fRGpqblcdFErJk68kGbNagU6LGNOSkm3ngqGO00sj0ACJflQlm/6b1d2CGAkJtilp+fidis1a0YRERHKW29dwp49GVx5ZZK1ojNB7Zg1aYWexm4EhKqqG+gJ3AFUmtFRXpnzu296iLV2MidBVfn007UkJU3ioYe+9i3v3bsxV11lvbya4FeaJhf/xRkGtQXwDs5QpdP9GlU5+uRXp3//Xi0TbIxhc8K2bEnh0ks/5Kqr/sOOHemsWrWPnBxXoMMypkyVJlF4vGNaXwm8rqoPAA38G1b5WLr1oG/63v7WZYIpvfx8Ny+88CNt205ixowNVK8eycSJF7Jgwa1ERZWmCzVjgkephkIVkWuAocDl3mWVom3fvxZt9013a1Zp6uaNn2Vl5dOjx99ZuXIvANde257x4wdSv361AEdmjH+UJlHcCozA6Wb8DxFpBvzLv2GVj5SsfACG921h95FNqcXEhNO162lkZeUzefJgBg5sEeiQjPGr0gyFukpE7gVaikgbYKOqPuf/0Pxv7lqnS+ckOxM0JVBV3n13OS1axNO7d2MAXnnlAiIiQu3BOVMllGaEu3OA94AdOM9Q1BORoar6k7+D86fCzWKbJlSaRlymjK1du4877/yK77/fSlJSIsuWDSciIpQaNaICHZox5aY0t55eAS5S1TUAIpKEkziCesCGgi7Fa8WE06FBjQBHYyqa7Ox8nnvuB8aN+4n8fA+1a8fw2GO9CQ+3vplM1VOaRBFRkCQAVHWtiAT1sFt70nL4dZvTr+HTl7azZrHmCLNnb+Suu2byxx+HALjtti48//z5xMdHBzgyYwKjNIniVxGZCrzvnb+BIO8UcObKXQBEhIZwWadK0dLXlJGMjDyGDv2M/fuzaN++DlOnDqZXr8aBDsuYgCpNohgO3As84p3/AXjdbxGVg2e/WgvAPf1bBjgSUxG43R48HiU8PJS4uAgmTBhEcnIaDzzQg/Bw68DPmBIThYh0AFoAn6nquPIJyb+Wbj2E2+P0cTiofb0AR2MCbenSndxxxwwuu6w1Tz7ZF4Drr7c+v4wp7Jg1cyLyOE73HTcAc0SkuJHugs5HSw4/ZNeqrjWLrarS0nK5775ZdOv2d5Yu3cV7760gP98d6LCMqZBKuqK4AeioqpkiUhuYCbxdPmH5z7frnadpbzunWYAjMYGgqnz88Rruu282u3ZlEBoqPPhgD/7yl3PtNpMxx1BSoshV1UwAVd0nIkHfLlBV2ZOWC8AN3ZsEOBpT3tLTcxky5GNmzdoIQPfuDZg69WI6dbJbkMaUpKRE0bzQWNkCtCg8draqXunXyPzg7z9s9k03SYgJYCQmEOLiIsjNdVOjRiTPP38+t99+pjWNNqYUSkoUVxWZn+jPQMrD2t1pAHRoUMP6dqoi5s/fSv36cbRqlYCI8PbblxIVFUbdunGBDs2YoFHSmNnzyjOQ8vDTxv2A0wmgqdz278/ikUfm8M47yzjvvGbMmTMUEaFJk5qBDs2YoFNlOs5PzcrnQEYeANVsvIBKy+NRpk1bxsiRczh4MJuIiFDOOacxbrcSFmZXkcacDL9WUIvIIBFZLyIbReTREspdJSIqIn7rP2ro27/g8igJsRH0bJHgr92YAFq9ei/9+k1j2LAvOHgwm/POa8bKlXcyZkw/wsKCvi2GMQFT6lNrEYlU1dwTKB8KTAIGAMnAYhH5onC/Ud5y1YD7gF9Ku+0TtTcthxXJqQBMu6Ub4aF20KhsUlNz6NHjH2Rk5FGnTizjxw/k+us7WF2UMWXguEdMEekmIiuB373zZ4hIabrw6IYzdsUfqpoHfAhcVky5Z4AXgJzSh31i1uxK8013aGg9xVYmqs5T9jVqRDFqVC+GDz+Tdevu4oYbOlqSMKaMlObU+jXgYuAAgKouB84txfsaANsLzSdTZKxtEekCNFLVr0rakIjcLiJLRGTJvn37SrHrI23cmwHA4A71T/i9pmLasSONq6/+D++/v8K3bPToc5gy5WJq1bJeXo0pS6VJFCGqurXIslPu68D7AN944KHjlVXVN1W1q6p2rV279gnvK/lQNgBt6lmXHcHO5fIwYcJC2rSZxCefrGXMmO9wuz0AdgVhjJ+Upo5iu4h0A9Rb73APsKEU79sBNCo039C7rEA1oD3wnfc/eD3gCxG5VFWXlCb40vp1mzOuQP2adqYZzBYv3sHw4V/x669ON/GXX96G114bRKjVORnjV6VJFHfi3H5qDOwB5nqXHc9ioJWINMNJENcC1xesVNVUILFgXkS+Ax4u6yQB+CqyE+KCerylKiszM49Ro+YyefJiVKFx4xq8/vqFXHpp60CHZkyVcNxEoap7cQ7yJ0RVXSJyN/A1EAq8raqrRWQssERVvzjhaE/C+t3pvuneLRNLKGkqqrCwEObO/YOQEOHBB3syZkxfYmMt6RtTXo6bKETkLUCLLlfV24/3XlWdidPrbOFlTx2jbL/jbe9kPDPDaY3b9/Ta1iw2iGzadJCaNaNISIghMjKM9967gqioMDp0qBvo0Iypckpz5JwLzPO+fgLqAKV+niLQth7MBKwTwGCRm+vi2Wfn0779FEaNmutbftZZDSxJGBMgpbn19O/C8yLyHvCj3yIqQ9+s28P2g06Lpzusf6cK77vvtnDnnV+xbp3TJ5fL5cHt9lhltTEBdjKdHjUDguLUbsLc333TDazFU4W1d28mI0fO4d13lwPQunUCU6YM5txzbXApYyqC0tRRHOJwHUUIcBA4Zr9NFcX63eks97Z2en9Y9wBHY45l//4skpImcfBgNpGRoYwefQ6PPNKLyEjruNGYiqLE/43iPOBwBoeff/BoQZ8JFdx7C7cAUCsmnN6trLVTRZWYGMNll7UmOTmNyZMH07JlfKBDMsYUUWKiUFUVkZmq2r68AiorBU9jn58UFHfJqozMzDzGjv2ewYNPp08fZzjayZMHExkZak9WG1NBlaaWcJ
mIdPZ7JGVsrbcjwM6NawU4ElPgyy/X07btZMaNW8CIEV/h8TgXp1FRYZYkjKnAjnlFISJhquoCOuN0Eb4JyMQZP1tVtUs5xXhS9qU7LXib144NcCRm+/ZU7rtvNp99tg6Azp3r8cYbF9t41cYEiZJuPS0CugCXllMsZcp7smqJIoBcLg+vvfYLTz31LZmZ+cTFRfDss+dy113dbCAhY4JISYlCAFR1UznFUmbyXB7fdJy1ngmYtLRc/va3H8nMzOeqq5J49dVBNGxYPdBhGWNOUElH0doi8uCxVqrqeD/EUyZ2pWb7pmMiLFGUp5SUHKKjw4iMDCM+Ppo33riYyMhQBg8+PdChGWNOUknX/6FAHE534MW9KqzFW5xuxc9sYhXZ5UVVmT59Ja1bT2TcuJ98y6+8MsmShDFBrqTT7V2qOrbcIilDy7enAFArxnoYLQ8bNhxgxIivmDdvMwDz529DVa0lkzGVxHHrKILRTxudvoLObpEQ4Egqt5wcFy+88CN//euP5OW5iY+P5sUXB3DzzZ0sSRhTiZSUKM4rtyjKUJ7Lwx/7nR5j255mFaf+snt3Bn36vMPvvx8E4OabO/HiiwNITLReeo2pbI6ZKFT1YHkGUlb+8aNz+yMiNITuzaw7CH+pWzeWRo1qEBYWwpQpg+nbt2mgQzLG+EmlaxJ0IMN50K5dg+p2+6MMeTzKW28t5dxzm3H66QmICNOnX0mtWtFERIQGOjxjjB9Vuqee/r14OwBDezQJcCSVx/Llu+nV622GD/+KESO+oqBfyLp14yxJGFMFVKorClUlI88FQNNEeyL7VGVk5PH009/x6qsLcbuV006rxvDhXQMdljGmnFWqRJHr8lDQCXrnRjUDG0yQ++9/13HPPbNITk4jJES4555uPPtsf6pXjwx0aMaYclapEsV36/cCTrcdVj9x8nbsSOPaaz8mN9fNmWfWZ+rUi+na9bRAh2WMCZBKlSgKmsX2amnPT5yo/Hw3YWEhiAgNGlTnuef6ExERyogRZ9mY1cZUcZXqCLAzxenjqf1pNQIcSXBZsGA7Z575Ju+/v8K37KGHzuaee7pbkjDGVK5E8f7CbQC0qBMX4EiCw8GD2dxxx5f06vU2K1fuZfLkJQTJSLfGmHJUaW49HczM8013tc4AS6SqvP/+Ch566H/s25dFeHgIjzzSi9Gjz7G6HWPMUSpNovhy+U4A6laPpE71qABHU3Ht2ZPBddd9wrffbgGgb98mTJkymKSk2oENzBhTYVWaRLF6ZyoAw3o3C3AkFVvNmlHs2pVBYmIML700gBtvPMOuIowxJao0ieKPfU6Lpxa1rX6iqDlzNtGlS30SEmKIjAzjo4+uoX79OBISrAM/Y8zxVYrK7NSsfH7ddojQEKF9A2vxVGDXrnSuu+4TBg58n1Gj5vqWt29fx5KEMabUKsUVxeItB/EodGtai7pWP4Hb7eGNN5by2GPzSEvLJTo6jNatE2wwIWPMSakUieL3vRkAtLRmsfz66y6GD5/B4sVO5f7gwa2YOPEimja1Lk2MMSenUiSKNbvSAOhYxW87bdmSQrdub+F2Kw0aVOO11y7kiiva2FWEMeaU+DVRiMggYAIQCvxdVZ8vsv5B4P8AF7APuFVVt57oftZ5E0WrulX7iqJp05rccksnqlWL5C9/6Ue1ataBnzHm1PmtMltEQoFJwIVAW+A6EWlbpNhvQFdV7Qh8DIw70f2kZuf7bj01S6xaiWLLlhQuueRffP/9Ft+yN9+8hPHjL7AkYYwpM/68ougGbFTVPwBE5EPgMmBNQQFV/bZQ+YXAn090J7tTc3zT8bERJxtrUMnPdzN+/M/85S/fk53tYv/+LH7+eRiA3WYyxpQ5fyaKBsD2QvPJQPcSyg8DZhW3QkRuB24HaNy48RHr1uxyHrQ7p1XiyUcaRH78cRvDh89g9ep9AFx7bXvGjx8Y4KiMMZVZhajMFpE/A12BvsWtV9U3gTcBunbtekSvdbtTnTGy29av7t8gA+zQoWxGjpzDP/7xGwAtWtRi8uTBDBzYIsCRGWMqO38mih1Ao0LzDb3LjiAi5wOjgb6qmnuiO8l1uQGIDK/cYzd7PMrnn68nPDyERx/tzWOP9SY6OjzQYRljqgB/JorFQCsRaYaTIK4Fri9cQEQ6A28Ag1R178nsJC3bGSM7NqLyJYp16/bTrFlNIiPDSEiI4YMPrqRx4xq0aVM1brMZYyoGv7V6UlUXcDfwNbAW+I+qrhaRsSJyqbfYi0Ac8JGILBORL050P9sPZQFQr0bleSI7Kyuf0aPn0bHjFMaN+8m3fODAFpYkjDHlzq91FKo6E5hZZNlThabPP5Xtu9wefvnjAACdG1WOMShmz97IiBFfsXlzCgD792cFOCJjTFVXISqzT9a2g1mk5bg4rUYUjYO8k7udO9O5//7ZfPSR03q4Q4c6TJ16MWef3eg47zTGGP8K6kRR8KBdfFxwPz+xYcMBunZ9k/T0PGJiwnn66b7cf38Pwit5Bb0xJjgEdaJYkezcnqkdF9xPIbdqFc9ZZzUgNjac11+/kCZNrAM/Y0zFEdSJYv1u54ri8s4NAhzJiUlLy+Wpp75lxIizOP30BESEL764ltgq8mS5MSa4BHWi2LTPSRRJQfKwnary8cdruO++2ezalcG6dfuZPdvptcSShDGmograRJGd52bLAWf407rVKn7T2D/+OMTdd89k1qyNAPTo0ZAXXjilRl/GGFMugjZRHMjMRRUS4yKpEVNxn1DOy3Pz0ksLeOaZ+eTkuKhZM4rnnz+P2247k5AQ68DPGFPxBW+iyMgDoFYFThIA27enMnbs9+Tmurnhhg68/PJA6lbxcTOMMcElaBNFwW2nijhY0aFD2dSsGYWI0KJFPBMmDKJly3jOO695oEMzxpgT5rcuPPxt8ZaDANSrHh3gSA7zeJS33/6Nli1f5/33V/iW33FHV0sSxpigFbSJYtNe54ri7BYJAY7EsXr1Xvr1m8awYV9w8GC2r9LaGGOCXdDeetqd5oxs1yTAXXdkZeXzzDPf89JLP+NyeahTJ5ZXXrmA665rH9C4jDGmrARtojiQ4QxdkRjAp7I3bDjABRe8z5YtKYjA8OFn8te/nketWhXndpgxxpyqoE0UOS4PANEBHIeiSZMaREWFccYZdZk69WJ69GgYsFhMcMjPzyc5OZmcnJzjFzbmJERFRdGwYUPCw8uuRWjQJgqX20kU4aHlV83icnmYOnUJ113XnoSEGCIjw5g9+wYaNKhOWFjQVveYcpScnEy1atVo2rQpIvYcjSlbqsqBAwdITk6mWbNmZbbdoDy6uT2KR0EEQsvpobVFi3bQrdtb3HPPLEaNmutb3qRJTUsSptRycnJISEiwJGH8QkRISEgo8yvWoLyiSM3OByC6HLrhTk3NYfTob5g8eTGq0LhxDS67rLXf92sqL0sSxp/88fsKykSx41A2AI3j/dfiSVX5979X88ADX7N7dwZhYSE8+GAPnnqqr3XgZ4ypUoLynkmet34iyo9XFMuX7+G66z5h9+4Mzj67Eb/+ejsvv
DDAkoQJeqGhoXTq1In27dtzySWXkJKS4lu3evVq+vfvT+vWrWnVqhXPPPMMqupbP2vWLLp27Urbtm3p3LkzDz30UCA+Qol+++03hg0bFugwfDZv3kz37t1p2bIlQ4YMIS8v76gy+fn53HTTTXTo0IGkpCT+9re/+dZNmDCB9u3b065dO1599VXf8ocffphvvvmmXD4DqhpUrzPPPFN/3rRfm4yaoddMXaBlyeVyHzH/wAOz9a23lqrb7SnT/Ziqa82aNYEOQWNjY33TN954oz777LOqqpqVlaXNmzfXr7/+WlVVMzMzddCgQTpx4kRVVV25cqU2b95c165dq6qqLpdLJ0+eXKax5efnn/I2rr76al22bFkZRFM2rrnmGv3Xv/6lqqp33HFHsd/ZBx98oEOGDFFV53tv0qSJbt68WVeuXKnt2rXTzMxMzc/P1/POO09///13VVXdsmWLDhgwoNh9Fvc7A5boSR53g/LW07505xmKmtFl1/zr2283M2LETN5442L69GkCwPjxF5TZ9o0pqumjX/llu1ueH1zqsj179mTFCqe7menTp9OrVy8GDhwIQExMDBMnTqRfv37cddddjBs3jtGjR9OmTRvAuTK58847j9pmRkYG99xzD0uWLEFEGDNmDFdddRVxcXFkZDhjyHz88cfMmDGDadOmcfPNNxMVFcVvv/1Gr169+PTTT1m2bBk1azojPbZq1Yoff/yRkJAQhg8fzrZt2wB49dVX6dWr1xH7Tk9PZ8WKFZxxxhkALFq0iPvuu4+cnByio6N55513aN26NdOmTWPJkiVMnDgRgIsvvpiHH36Yfv36MXv2bB5//HHcbjeJiYnMmzev1N9nUarKN998w/Tp0wG46aabePrpp4/63kSEzMxMXC4X2dnZREREUL16dRYvXkz37t2JiXFus/ft25dPP/2URx55hCZNmnDgwAF2795NvXr1TjrG0gjKRLHRO1Z2izqn3iHg3r2ZjBw5h3ffXQ7A+PE/+xKFMZWZ2+1m3rx5vts0q1ev5swzzzyiTIsWLcjIyCAtLY1Vq1aV6lbTM888Q40aNVi5ciUAhw4dOu57kpOTWbBgAaGhobjdbj777DNuueUWfvnlF5o0aULdunW5/vrreeCBB+jduzfbtm3jggsuYO3atUdsZ8mSJbRvf7hXhDZt2vDDDz8QFhbG3Llzefzxx/nkk0+OGce+ffu47bbbmD9/Ps2aNePgwYNHlVm/fj1Dhgwp9v3fffedL8EBHDhwgJo1axIW5hxqGzZsyI4dO45639VXX83nn39O/fr1ycrK4pVXXiE+Pp727dszevRoDhw4QHR0NDNnzqRr166+93Xp0oWffvqJq6666pifqSwEZaLYvN/p56l5YuxJb8PjUf7xj18ZNWouhw7lEBkZyhNP9GHkyLPLKkxjSnQiZ/5lKTs7m06dOrFjxw6SkpIYMGBAmW5/7ty5fPjhh775WrVqHfc911xzDaGhTp3jkCFDGDt2LLfccgsffvih76A8d+5c1qxZ43tPWloaGRkZxMUdPmHctWsXtWvX9s2npqZy00038fvvvyMi5OfnlxjHwoUL6dOnj+8ZhPj4+KPKtG7dmmXLlh33M52IRYsWERoays6dOzl06BDnnHMO559/PklJSYwaNYqBAwcSGxtLp06dfN8TQJ06ddi5c2eZxlKcoKzMLrj1VL/GyXWVsXnzIc455x1uv30Ghw7lMHBgC1atGsETT/QhMjIoc6cxpRYdHc2yZcvYunUrqsqkSZMAaNu2LUuXLj2i7B9//EFcXBzVq1enXbt2R60/EYWbbRZt5x8be/ikr2fPnmzcuJF9+/bx3//+lyuvvBIAj8fDwoULWbZsGcuWLWPHjh1HJImCz1Z4208++STnnnsuq1at4ssvv/StCwsLw+PxHDOekqxfv55OnToV+yrcMAAgISGBlJQUXC4X4Fw5NWjQ4KhtTp8+nUGDBhEeHk6dOnXo1asXS5YsAWDYsGEsXbqU+fPnU6tWLU4//fQj4o6O9n+XQUGZKDZ6x8quHn1yB/Xq1SPZsOEA9erF8eGHVzF79g20bHn0mYMxlVlMTAyvvfYaL7/8Mi6XixtuuIEff/yRuXOdB0qzs7O59957eeSRRwAYOXIkf/3rX9mwYQPgHLinTp161HYHDBjgSz5w+NZT3bp1Wbt2LR6Ph88+++yYcYkIV1xxBQ8++CBJSUkkJDg9RA8cOJDXX3/dV664s/qkpCQ2bjzcc3NqaqrvwDxt2jTf8qZNm7Js2TI8Hg/bt29n0aJFAPTo0YP58+ezefNmgGJvPRVcURT3KnzbqeCznHvuuXz88ccA/POf/+Syyy47apuNGzf2tWDKzMxk4cKFvrqgvXv3ArBt2zY+/fRTrr/+et/7NmzYcMStNr852VrwQL06d+miTUbN0KaPztD8Iq2USjJ79u+ak3O4RcWCBds0JSW71IRFoccAAAt1SURBVO83pixUtFZPqqoXX3yxvvvuu6qqumLFCu3bt6+efvrp2qJFC3366afV4znc6u/LL7/ULl26aJs2bTQpKUlHjhx51PbT09P1xhtv1Hbt2mnHjh31k08+UVXVjz76SJs3b67du3fXu+66S2+66SZVVb3pppv0o48+OmIbixcvVkCnTZvmW7Zv3z7905/+pB06dNCkpCS94447iv187du317S0NFVVXbBggbZq1Uo7deqko0eP1iZNmqiqqsfj0euvv15bt26tl19+ufbt21e//fZbVVWdOXOmdurUSTt27Kjnn39+Kb/VY9u0aZOeddZZ2qJFC7366qs1JydHVVU///xzffLJJ33f2dVXX61t27bVpKQkHTdunO/9vXv31qSkJO3YsaPOnTvXtzwvL0/btGlTbEuxsm71JFqojXQw6NT5TE25YCyJcZEseeL845bfvj2Ve++dzX//u45nnjmXJ57oUw5RGlO8tWvXkpSUFOgwKrVXXnnl/9u7+xipyiuO49+fyO6CbFEkWgXtagQFFUEpUk20ii+orcaKIFUUI7VgbSPWJm2kqZZGrS8kWLQrVYM0vlVaLEFbayyWalgEERCpAiKx2xpFaolRXrq7p388zzrTdXbm7nbnzszu+SSbzNzXsyczc+beO/c81NbWMm3atFKHUlSLFy9mzZo1zJ49+3Pzcr3OJL1qZqM/t3ACFXfqaW9zMwAD++W/8a2pqYU5c1YwbNh9PP30m/TrV8WAAd7+27nubsaMGVRXl274gbQ0NTWldsNjxV25/XRvKBSjDm//lxQNDY1Mn76UdeveB+CSS4Yxd+54Bg36QioxOudKp6amhilTppQ6jKK79NJLU9tXxRWKFguHQQf0zX2z3cqVjZxyykOYQV3d/sybdx4XXDA057LOlYKZeWNAVzTFuJxQcYWiqbmFKmC/dn7GOmbMIM499yhGjfois2adRt92CopzpVBTU8OOHTu81bgrCovjUdTU1HTpdiuuUHy8u4kDgZOPCD9n3bx5BzNnPsecOecydGh48z3zzDfZJ6VxKpzriMGDB9PY2Mj27dtLHYrrplpH
uOtKFVco9ja3MGC/KoYfXMutt77I7be/xJ49zdTU7MuiRRMBvEi4stW7d+8uHXnMuTQUtVBIGg/MBXoBD5rZHW3mVwMLgZOAHcAkM9tWaLt1e4xRIx9g06YdAFx99UjuvLNr2xA455wLinYfhaRewCbgbKARWAVMNrONWctcB4wws+mSLgMuNrPc3baiXn0Ot5bdoYnZsGEDqa//mjfxc865Asr1PooxwBYz22pme4EngLb3rl8EPBIfLwLGqcAVvpbdzVRX9+K2285k7drpXiScc67IinlEMQEYb2bT4vMpwMlmdn3WMhviMo3x+dtxmQ/bbOta4Nr49DhgQ1GCrjwDgQ8LLtUzeC4yPBcZnouMo82stjMrVsTFbDObD8wHkLS6s4dP3Y3nIsNzkeG5yPBcZEha3dl1i3nq6R/AYVnPB8dpOZeRtC/Qn3BR2znnXJkoZqFYBQyRdISkKuAyYEmbZZYAV8XHE4A/W6V1KXTOuW6uaKeezKxJ0vXAc4Sfxz5sZm9I+imh3e0S4CHg15K2AP8iFJNC5hcr5grkucjwXGR4LjI8FxmdzkXFtRl3zjmXroprM+6ccy5dXiicc87lVbaFQtJ4SW9J2iLphznmV0t6Ms5fKaku/SjTkSAXN0raKGm9pBckddu7EAvlImu5SySZpG7708gkuZA0Mb423pD0WNoxpiXBe+RwScskvRbfJ+eXIs5ik/SwpA/iPWq55kvSvTFP6yWdmGjDnR1DtZh/hIvfbwNHAlXAOmB4m2WuA+rj48uAJ0sddwlzcQbQNz6e0ZNzEZerBZYDDcDoUsddwtfFEOA14ID4/KBSx13CXMwHZsTHw4FtpY67SLk4DTgR2NDO/POBPwACxgIrk2y3XI8oitL+o0IVzIWZLTOzT+PTBsI9K91RktcFwGzg58DuNINLWZJcfAu4z8w+AjCzD1KOMS1JcmFA6xCX/YF/phhfasxsOeEXpO25CFhoQQOwv6RDCm23XAvFIODvWc8b47Scy5hZE7ATODCV6NKVJBfZriF8Y+iOCuYiHkofZmbPpBlYCSR5XQwFhkp6WVJD7ObcHSXJxS3AFZIagWeB76YTWtnp6OcJUCEtPFwykq4ARgOnlzqWUpC0DzAHmFriUMrFvoTTT18lHGUul3S8mf27pFGVxmRggZndI+krhPu3jjOzllIHVgnK9YjC239kJMkFks4CbgYuNLM9KcWWtkK5qCU0jXxR0jbCOdgl3fSCdpLXRSOwxMz+Y2bvENr+D0kpvjQlycU1wG8AzGwFUENoGNjTJPo8aatcC4W3/8gomAtJo4AHCEWiu56HhgK5MLOdZjbQzOrMrI5wveZCM+t0M7QyluQ98jThaAJJAwmnoramGWRKkuTiXWAcgKRhhELRE8ejXQJcGX/9NBbYaWbvFVqpLE89WfHaf1SchLm4C+gHPBWv579rZheWLOgiSZiLHiFhLp4DzpG0EWgGfmBm3e6oO2Euvg/8StJMwoXtqd3xi6WkxwlfDgbG6zE/AXoDmFk94frM+cAW4FPg6kTb7Ya5cs4514XK9dSTc865MuGFwjnnXF5eKJxzzuXlhcI551xeXiicc87l5YXClR1JzZLWZv3V5Vm2rr1OmR3c54ux++i62PLi6E5sY7qkK+PjqZIOzZr3oKThXRznKkkjE6xzg6S+/+++Xc/lhcKVo11mNjLrb1tK+73czE4gNJu8q6Mrm1m9mS2MT6cCh2bNm2ZmG7skykyc95MszhsALxSu07xQuIoQjxz+KmlN/DslxzLHSnolHoWslzQkTr8ia/oDknoV2N1y4Ki47rg4hsHrsdd/dZx+hzJjgNwdp90i6SZJEwg9tx6N++wTjwRGx6OOzz7c45HHvE7GuYKshm6SfilptcLYE7fGad8jFKxlkpbFaedIWhHz+JSkfgX243o4LxSuHPXJOu20OE77ADjbzE4EJgH35lhvOjDXzEYSPqgbY7uGScCpcXozcHmB/X8deF1SDbAAmGRmxxM6GcyQdCBwMXCsmY0Afpa9spktAlYTvvmPNLNdWbN/G9dtNQl4opNxjie06Wh1s5mNBkYAp0saYWb3Elpqn2FmZ8RWHrOAs2IuVwM3FtiP6+HKsoWH6/F2xQ/LbL2BefGcfDOhb1FbK4CbJQ0GfmdmmyWNA04CVsX2Jn0IRSeXRyXtArYR2lAfDbxjZpvi/EeA7wDzCGNdPCRpKbA06T9mZtslbY19djYDxwAvx+12JM4qQtuW7DxNlHQt4X19CGGAnvVt1h0bp78c91NFyJtz7fJC4SrFTOB94ATCkfDnBiUys8ckrQQuAJ6V9G3CSF6PmNmPEuzj8uwGgpIG5Foo9hYaQ2gyNwG4HjizA//LE8BE4E1gsZmZwqd24jiBVwnXJ34BfEPSEcBNwJfN7CNJCwiN79oS8LyZTe5AvK6H81NPrlL0B96L4wdMITR/+x+SjgS2xtMtvyecgnkBmCDpoLjMACUfU/wtoE7SUfH5FOAv8Zx+fzN7llDATsix7seEtue5LCaMNDaZUDToaJyxod2PgbGSjiGM3vYJsFPSwcB57cTSAJza+j9J2k9SrqMz5z7jhcJVivuBqyStI5yu+STHMhOBDZLWEsalWBh/aTQL+JOk9cDzhNMyBZnZbkJ3zackvQ60APWED92lcXsvkfsc/wKgvvVidpvtfgT8DfiSmb0Sp3U4znjt4x5CV9h1hPGx3wQeI5zOajUf+KOkZWa2nfCLrMfjflYQ8ulcu7x7rHPOubz8iMI551xeXiicc87l5YXCOedcXl4onHPO5eWFwjnnXF5eKJxzzuXlhcI551xe/wXoLJFGIxS2KgAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "lw = 2\n", + "plt.figure()\n", + "plt.plot(fpr, tpr, lw=lw, \n", + " label='ROC curve (auc = %.2f)' % (roc_auc))\n", + "plt.plot([0,1],[0,1], color='navy', lw=lw, linestyle='--')\n", + "plt.ylabel('True Positive Rate')\n", + "plt.xlabel('False Positive Rate')\n", + "plt.xlim([0.0,1.0])\n", + "plt.ylim([0.0, 1.05])\n", + "plt.title('SST Membership Inference ROC curve')\n", + "plt.legend(loc=\"lower right\")\n", + "plt.show()\n", + "\n", + "import pickle\n", + "results = (fpr, tpr, roc_auc)\n", + "pickle.dump(results, open(\"sst_results\", \"wb\"))" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "# in_positive = in_input[in_input[:,1] == 1][:,0]\n", + "# in_negative = in_input[in_input[:,1] == 0][:,0]\n", + "\n", + "# out_positive = out_input[out_input[:,1] == 1][:,0]\n", + "# out_negative = out_input[out_input[:,1] == 0][:,0]\n", + "\n", + "\n", + "# plt.figure()\n", + "# sns.distplot(in_positive[:100],label='in positive', kde=True, hist=True, norm_hist=True)\n", + "# sns.distplot(out_positive[:100],label='out positive', kde=True, hist=True, norm_hist=True)\n", + "# plt.legend()\n", + "# plt.xlim([0.4, 1.05])\n", + "# plt.ylim([0.0, 20])\n", + "# plt.title(\"Positive predictions\")\n", + "# plt.show()\n", + "\n", + "# plt.figure()\n", + "# sns.distplot(in_negative,label='in negative')\n", + "# sns.distplot(out_negative,label='out negative')\n", + "# plt.legend()\n", + "# plt.title(\"Negative predictions\")\n", + "# plt.show()\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "venv" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.0" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/Classification_baselines/SST/SST.ipynb b/Classification_baselines/SST/SST.ipynb new file mode 100644 index 0000000..fc2d712 --- /dev/null +++ b/Classification_baselines/SST/SST.ipynb @@ -0,0 +1,335 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Classification baseline for Stanford Sentiment Treebank (SST)" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Python: 3.7.0 (default, Jun 28 2018, 13:15:42) \n", + "[GCC 7.2.0]\n", + "Pytorch: 1.0.0\n" + ] + } + ], + "source": [ + "import torch\n", + "\n", + "from torchtext import data\n", + "from torchtext import datasets \n", + "import sys\n", + "import seaborn as sns\n", + "from sklearn.metrics import roc_curve, auc\n", + "\n", + "sys.path.insert(0, '../../Utils/')\n", + "\n", + "import matplotlib.pyplot as plt\n", + "%matplotlib inline \n", + "\n", + "import models\n", + "from train import *\n", + "from metrics import * \n", + "\n", + "print(\"Python: %s\" % sys.version)\n", + "print(\"Pytorch: %s\" % torch.__version__)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Load SST using Torchtext" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# To fix the following error: OSError: [E050] Can't find model 'en'. 
It doesn't seem to be a shortcut link, a Python package or a valid path to a data directory.\n", + "# Run: \n", + "# python -m spacy download en\n", + "\n", + "\n", + "TEXT = data.Field(tokenize='spacy')\n", + "LABEL = data.LabelField(tensor_type=torch.LongTensor)\n", + "\n", + "train, val, test = datasets.SST.splits(TEXT, LABEL, root='../../Datasets/SST_data', fine_grained=True)\n", + "\n", + "\n", + "TEXT.build_vocab(train, max_size=25000, vectors=\"glove.6B.100d\", vectors_cache='../../Datasets/SST_data/vector_cache')\n", + "LABEL.build_vocab(train)\n", + "\n", + "BATCH_SIZE = 32\n", + "\n", + "train_itr, val_itr, test_itr = data.BucketIterator.splits(\n", + " (train, val, test), \n", + " batch_size = BATCH_SIZE, \n", + " sort_key= lambda x: len(x.text), \n", + " repeat=False\n", + ")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Create bidirectional LSTM model" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "vocab_size = len(TEXT.vocab)\n", + "embedding_size = 100\n", + "hidden_size = 256\n", + "output_size = 5\n", + "\n", + "\n", + "RNN_model = models.RNN(vocab_size, embedding_size, hidden_size, output_size)\n", + "\n", + "pretrained_embeddings = TEXT.vocab.vectors\n", + "RNN_model.embedding.weight.data.copy_(pretrained_embeddings)\n", + "print(\"\")\n", + "\n", + "\n", + "optimizer = torch.optim.Adam(RNN_model.parameters())\n", + "criterion = torch.nn.CrossEntropyLoss()\n", + "\n", + "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", + "\n", + "RNN_model = RNN_model.to(device)\n", + "criterion = criterion.to(device)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Utility functions" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "def classification_accuracy(preds, y):\n", + "\n", + " correct = (preds == y).float() #convert into float for division \n", + " acc = correct.sum()/len(correct)\n", + " return acc\n", + "\n", + "def binary_accuracy(preds, y):\n", + "\n", + " rounded_preds = torch.round(preds)\n", + "\n", + " correct = (rounded_preds == y).float() #convert into float for division \n", + " acc = correct.sum()/len(correct)\n", + " return acc" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Train function" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "def train(model, iterator, optimizer, criterion):\n", + " \n", + " epoch_loss = 0\n", + " epoch_acc = 0\n", + " \n", + " model.train()\n", + " \n", + " for batch in iterator:\n", + " \n", + " optimizer.zero_grad()\n", + " \n", + " predictions = model(batch.text).squeeze(1)\n", + " \n", + " loss = criterion(predictions, batch.label)\n", + "\n", + " acc = classification_accuracy(predictions.argmax(dim=1), batch.label)\n", + " \n", + " loss.backward()\n", + " \n", + " optimizer.step()\n", + " \n", + " epoch_loss += loss.item()\n", + " epoch_acc += acc.item()\n", + " \n", + " return epoch_loss / len(iterator), epoch_acc / len(iterator)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Evaluation function" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "def evaluate(model, iterator, criterion):\n", + " \n", + " epoch_loss = 0\n", + " epoch_acc = 0\n", + " 
\n", + " model.eval()\n", + " \n", + " with torch.no_grad():\n", + " \n", + " for batch in iterator:\n", + "\n", + " predictions = model(batch.text).squeeze(1)\n", + " loss = criterion(predictions, batch.label)\n", + " acc = classification_accuracy(predictions.argmax(dim=1), batch.label)\n", + "\n", + " epoch_loss += loss.item()\n", + " epoch_acc += acc.item()\n", + " \n", + " return epoch_loss / len(iterator), epoch_acc / len(iterator)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Train classification model" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/ljt/cyphercat/venv/lib/python3.7/site-packages/torchtext/data/field.py:322: UserWarning: volatile was removed and now has no effect. Use `with torch.no_grad():` instead.\n", + " return Variable(arr, volatile=not train)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch: 01, Train Loss: 1.514, Train Acc: 31.87%, Val. Loss: 1.387, Val. Acc: 39.22%\n", + "Epoch: 02, Train Loss: 1.368, Train Acc: 39.68%, Val. Loss: 1.362, Val. Acc: 39.36%\n", + "Epoch: 03, Train Loss: 1.272, Train Acc: 43.81%, Val. Loss: 1.362, Val. Acc: 38.16%\n", + "Epoch: 04, Train Loss: 1.173, Train Acc: 48.30%, Val. Loss: 1.334, Val. Acc: 39.90%\n", + "Epoch: 05, Train Loss: 1.073, Train Acc: 53.16%, Val. Loss: 1.417, Val. Acc: 42.22%\n", + "Epoch: 06, Train Loss: 0.982, Train Acc: 58.08%, Val. Loss: 1.394, Val. Acc: 40.80%\n", + "Epoch: 07, Train Loss: 0.891, Train Acc: 62.64%, Val. Loss: 1.480, Val. Acc: 41.24%\n", + "Epoch: 08, Train Loss: 0.805, Train Acc: 67.31%, Val. Loss: 1.611, Val. Acc: 40.03%\n", + "Epoch: 09, Train Loss: 0.723, Train Acc: 71.15%, Val. Loss: 1.709, Val. Acc: 39.27%\n", + "Epoch: 10, Train Loss: 0.661, Train Acc: 73.13%, Val. Loss: 1.826, Val. Acc: 40.12%\n", + "Epoch: 11, Train Loss: 0.606, Train Acc: 75.96%, Val. Loss: 1.836, Val. Acc: 40.61%\n", + "Epoch: 12, Train Loss: 0.554, Train Acc: 77.66%, Val. Loss: 1.943, Val. Acc: 40.56%\n", + "Epoch: 13, Train Loss: 0.508, Train Acc: 79.78%, Val. Loss: 2.183, Val. Acc: 37.97%\n", + "Epoch: 14, Train Loss: 0.466, Train Acc: 81.98%, Val. Loss: 2.082, Val. Acc: 39.58%\n", + "Epoch: 15, Train Loss: 0.444, Train Acc: 82.72%, Val. Loss: 2.173, Val. Acc: 39.84%\n", + "Epoch: 16, Train Loss: 0.401, Train Acc: 84.55%, Val. Loss: 2.419, Val. Acc: 39.48%\n", + "Epoch: 17, Train Loss: 0.372, Train Acc: 85.87%, Val. Loss: 2.403, Val. Acc: 38.91%\n", + "Epoch: 18, Train Loss: 0.355, Train Acc: 86.61%, Val. Loss: 2.613, Val. Acc: 40.20%\n", + "Epoch: 19, Train Loss: 0.328, Train Acc: 87.35%, Val. Loss: 2.622, Val. Acc: 39.48%\n", + "Epoch: 20, Train Loss: 0.305, Train Acc: 88.32%, Val. Loss: 2.826, Val. Acc: 39.26%\n" + ] + } + ], + "source": [ + "n_epochs = 20\n", + "\n", + "for epoch in range(n_epochs):\n", + "\n", + " train_loss, train_acc = train(RNN_model, train_itr, optimizer, criterion)\n", + " valid_loss, valid_acc = evaluate(RNN_model, val_itr, criterion)\n", + " \n", + " print('Epoch: %02d, Train Loss: %.3f, Train Acc: %.2f%%, Val. Loss: %.3f, Val. 
Acc: %.2f%%' % (epoch+1, train_loss, train_acc*100, valid_loss, valid_acc*100))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Evaluate model on test set" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "RNN test accuracy: 0.41\n" + ] + } + ], + "source": [ + "test_loss, test_acc = evaluate(RNN_model, test_itr, criterion)\n", + "\n", + "print('RNN test accuracy: %.2f' % (test_acc))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "venv" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.0" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/Utils/models.py b/Utils/models.py index 9597b72..348dcb3 100644 --- a/Utils/models.py +++ b/Utils/models.py @@ -379,6 +379,27 @@ def ft_cnn_classifer(n_classes): kernel_size=7, n_classes=125) +class RNN(torch.nn.Module): + ''' + Bidirectional LSTM for sentiment analysis + ''' + def __init__(self, vocab_size, embedding_size, hidden_size, output_size, n_layers=2, bidirectional=True, dropout=0.5): + super(RNN, self).__init__() + + self.embedding = torch.nn.Embedding(vocab_size, embedding_size) + self.rnn = torch.nn.LSTM(embedding_size, hidden_size, num_layers=n_layers, bidirectional=bidirectional, dropout=dropout) + self.fc = torch.nn.Linear(hidden_size*2, output_size) + self.dropout = torch.nn.Dropout(dropout) + + def forward(self, x): + + embedded = self.dropout(self.embedding(x)) + output, (hidden, cell) = self.rnn(embedded) + hidden = self.dropout(torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim=1)) + + return self.fc(hidden.squeeze(0)) + + def weights_init(m): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') diff --git a/requirements.txt b/requirements.txt index 33375fa..a7b6b87 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,11 +16,13 @@ pyyaml==4.2b1 requests==2.20.0 scikit-image==0.14.0 scipy==1.1.0 +seaborn==0.9.0 six==1.11.0 sklearn==0.0 soundfile==0.10.02 toolz==0.9.0 -torch==0.4.0 +torch==1.0.0 +torchtext==0.2.3 torchvision==0.2.1 pandas==0.23.2 tqdm==4.23.4
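
Note on the attack cells in ml-leaks_adversary_1-SST.ipynb: the ML-Leaks adversary-1 recipe they implement is easy to lose inside the notebook JSON, so the following is a minimal, self-contained sketch of that recipe: take the sorted top-k softmax posteriors of a shadow model as features, then fit a small binary classifier to separate members (label 1) from non-members (label 0). Everything below is a stand-in rather than the repo's code: the shadow model is a plain linear layer over random data instead of the bidirectional-LSTM RNN added in Utils/models.py, the two-layer attack head replaces models.mlp(3, 1, 32), and BCEWithLogitsLoss is an assumption made so that raw attack scores are valid loss inputs.

# Sketch of the ML-Leaks "adversary 1" step: top-k shadow posteriors -> binary attack model.
# All models and data here are stand-ins; see the notebook for the real SST/LSTM pipeline.
import torch

torch.manual_seed(0)
k = 3                      # number of top posteriors kept per example, as in the notebook
n_classes = 5              # fine-grained SST uses 5 sentiment classes

shadow_model = torch.nn.Linear(20, n_classes)          # stand-in for the trained shadow RNN
attack_model = torch.nn.Sequential(                    # stand-in for models.mlp(3, 1, 32)
    torch.nn.Linear(k, 32), torch.nn.ReLU(), torch.nn.Linear(32, 1)
)
optimizer = torch.optim.Adam(attack_model.parameters(), lr=0.01)
criterion = torch.nn.BCEWithLogitsLoss()               # assumed; pairs with raw attack scores

def top_k_posteriors(model, x, k):
    # Sorted top-k softmax outputs, detached so only the attack model is trained.
    with torch.no_grad():
        probs = torch.softmax(model(x), dim=1)
    return torch.sort(probs, dim=1, descending=True)[0][:, :k]

# Synthetic "member" (seen by the shadow model) and "non-member" batches.
x_in, x_out = torch.randn(32, 20), torch.randn(32, 20)

for _ in range(100):
    optimizer.zero_grad()
    feats = torch.cat([top_k_posteriors(shadow_model, x_in, k),
                       top_k_posteriors(shadow_model, x_out, k)])
    labels = torch.cat([torch.ones(32), torch.zeros(32)])
    loss = criterion(attack_model(feats).squeeze(1), labels)
    loss.backward()
    optimizer.step()

# Membership probability for new examples is a sigmoid over the attack score,
# mirroring how the notebook builds its ROC inputs.
with torch.no_grad():
    p_member = torch.sigmoid(attack_model(top_k_posteriors(shadow_model, x_in, k)))

One design point the sketch deliberately sidesteps: the notebook feeds the raw attack_model outputs to torch.nn.BCELoss and applies the deprecated torch.nn.functional.sigmoid only when collecting ROC probabilities (the UserWarning in the training output points at torch.sigmoid as the current spelling). Whether BCELoss on raw outputs is correct there depends on whether models.mlp already ends in a sigmoid, and its definition is not part of this diff, which is why the BCEWithLogitsLoss pairing above is flagged as an assumption rather than a drop-in replacement.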