From 65f82fc4b52cd7fef23808bb2116b50d97ff7118 Mon Sep 17 00:00:00 2001 From: "sweep-nightly[bot]" <131841235+sweep-nightly[bot]@users.noreply.github.com> Date: Sun, 22 Oct 2023 23:31:02 +0000 Subject: [PATCH 1/3] feat: add tests for main.py --- src/test_main.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 src/test_main.py diff --git a/src/test_main.py b/src/test_main.py new file mode 100644 index 0000000..9027c45 --- /dev/null +++ b/src/test_main.py @@ -0,0 +1,31 @@ +import unittest +from unittest.mock import Mock, patch +import torch +from torchvision import datasets, transforms +from torch.utils.data import DataLoader +from main import Net + +class TestMain(unittest.TestCase): + def setUp(self): + self.mock_model = Mock(spec=Net) + self.mock_model.forward.return_value = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]) + self.mock_data_loader = Mock(spec=DataLoader) + self.mock_data_loader.return_value = [torch.randn(64, 1, 28, 28), torch.randint(0, 10, (64,))] + + def test_model_initialization(self): + model = Net() + self.assertIsInstance(model, Net) + + def test_model_forward_pass(self): + input_data = torch.randn(64, 1, 28, 28) + output = self.mock_model.forward(input_data) + self.assertEqual(output.shape, (64, 10)) + + def test_data_loader(self): + batch = next(iter(self.mock_data_loader)) + self.assertEqual(len(batch), 2) + self.assertEqual(batch[0].shape, (64, 1, 28, 28)) + self.assertEqual(batch[1].shape, (64,)) + +if __name__ == "__main__": + unittest.main() From de4594aa2cfb4918dce7097a7fad36b28a9641a4 Mon Sep 17 00:00:00 2001 From: "sweep-nightly[bot]" <131841235+sweep-nightly[bot]@users.noreply.github.com> Date: Sun, 22 Oct 2023 23:32:17 +0000 Subject: [PATCH 2/3] feat: Updated pyproject.toml --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index b650fdb..0530465 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ authors = ["your-name "] python = "^3.9" fastapi = "0.68.0" uvicorn = "0.15.0" -torch = "*" +torch = "1.10.0+cpu" torchvision = "*" pillow = "*" pylint = "*" From e153d497d5aad1631dcd5d1b1a664be54b429b60 Mon Sep 17 00:00:00 2001 From: "sweep-nightly[bot]" <131841235+sweep-nightly[bot]@users.noreply.github.com> Date: Sun, 22 Oct 2023 23:32:54 +0000 Subject: [PATCH 3/3] Sandbox run pyproject.toml --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 0530465..f96d61f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,4 +14,3 @@ pillow = "*" pylint = "*" [tool.poetry.group.dev.dependencies] pytest = "^7.4.2" -