From 9fa2b98d49672ad92e68730f747602a257f7e5d3 Mon Sep 17 00:00:00 2001 From: "sweep-nightly[bot]" <131841235+sweep-nightly[bot]@users.noreply.github.com> Date: Tue, 24 Oct 2023 19:39:27 +0000 Subject: [PATCH 1/2] feat: add tests using mocker to main.py --- src/test_main.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 src/test_main.py diff --git a/src/test_main.py b/src/test_main.py new file mode 100644 index 0000000..d6a94ff --- /dev/null +++ b/src/test_main.py @@ -0,0 +1,33 @@ +import pytest +from pytest_mock import MockerFixture +from main import Net, transform +from api import app +from fastapi.testclient import TestClient +from PIL import Image +import torch +import torch.nn as nn +import io + +def test_net_init(mocker: MockerFixture): + mock_super_init = mocker.patch('torch.nn.Module.__init__') + net = Net() + mock_super_init.assert_called_once() + +def test_net_forward(mocker: MockerFixture): + mock_input = mocker.patch('torch.Tensor') + mock_relu = mocker.patch('torch.nn.functional.relu') + mock_log_softmax = mocker.patch('torch.nn.functional.log_softmax') + net = Net() + net.forward(mock_input) + mock_relu.assert_any_call(net.fc1(mock_input.view(-1, 28 * 28))) + mock_relu.assert_any_call(net.fc2(mock_relu.return_value)) + mock_log_softmax.assert_called_once_with(net.fc3(mock_relu.return_value), dim=1) + +def test_predict(mocker: MockerFixture): + mock_file = mocker.patch('fastapi.UploadFile') + mock_image_open = mocker.patch('PIL.Image.open') + mock_image_open.return_value.convert.return_value = Image.new('L', (28, 28)) + client = TestClient(app) + response = client.post("/predict/", files={"file": ("filename", io.BytesIO(), "image/png")}) + assert response.status_code == 200 + assert 'prediction' in response.json() From 858bceb6e624309d6a9c915651e4cf48f028da7f Mon Sep 17 00:00:00 2001 From: "sweep-nightly[bot]" <131841235+sweep-nightly[bot]@users.noreply.github.com> Date: Tue, 24 Oct 2023 19:44:07 +0000 Subject: [PATCH 2/2] feat: Updated src/test_main.py --- src/test_main.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/test_main.py b/src/test_main.py index d6a94ff..25dadda 100644 --- a/src/test_main.py +++ b/src/test_main.py @@ -29,5 +29,5 @@ def test_predict(mocker: MockerFixture): mock_image_open.return_value.convert.return_value = Image.new('L', (28, 28)) client = TestClient(app) response = client.post("/predict/", files={"file": ("filename", io.BytesIO(), "image/png")}) - assert response.status_code == 200 - assert 'prediction' in response.json() + assert response.status_code == 200, "Expected status code to be 200" + assert 'prediction' in response.json(), "Expected 'prediction' in response"