diff --git a/hstest/__init__.py b/hstest/__init__.py index 59b896fe..28b1228d 100644 --- a/hstest/__init__.py +++ b/hstest/__init__.py @@ -3,6 +3,7 @@ 'DjangoTest', 'FlaskTest', 'PlottingTest', + 'SQLTest', 'TestCase', 'SimpleTestCase', @@ -24,6 +25,7 @@ from hstest.stage import DjangoTest from hstest.stage import FlaskTest from hstest.stage import StageTest +from hstest.stage import SQLTest from hstest.test_case import CheckResult from hstest.test_case import SimpleTestCase from hstest.test_case import TestCase diff --git a/hstest/stage/__init__.py b/hstest/stage/__init__.py index b196aec7..a87bba82 100644 --- a/hstest/stage/__init__.py +++ b/hstest/stage/__init__.py @@ -3,11 +3,13 @@ 'DjangoTest', 'FlaskTest', 'PlottingTest', + 'SQLTest' ] from hstest.stage.django_test import DjangoTest from hstest.stage.flask_test import FlaskTest from hstest.stage.stage_test import StageTest +from hstest.stage.sql_test import SQLTest try: from hstest.stage.plotting_test import PlottingTest diff --git a/hstest/stage/sql_test.py b/hstest/stage/sql_test.py new file mode 100644 index 00000000..90ab3242 --- /dev/null +++ b/hstest/stage/sql_test.py @@ -0,0 +1,39 @@ +import os + +from exception.outcomes import ErrorWithFeedback, WrongAnswer +from hstest.stage.stage_test import StageTest +from hstest.testing.runner.sql_runner import SQLRunner +import re + + +class SQLTest(StageTest): + runner = SQLRunner() + queries = dict() + + def run_tests(self, *, debug=False, is_unittest: bool = False): + self.find_sql_files() + return super(SQLTest, self).run_tests() + + def find_sql_files(self): + root_folder = os.getcwd() + + is_sql_file_found: bool = False + + for file in os.listdir(root_folder): + if file.endswith('.sql'): + is_sql_file_found = True + self.parse_sql_file(file) + if not is_sql_file_found: + raise ErrorWithFeedback("Can't find any SQL file!") + + def parse_sql_file(self, file_name: str) -> None: + file_path = os.path.join(os.getcwd(), file_name) + + with open(file_path, 'r') as file: + lines = file.readlines() + sql_content = " ".join(lines).replace("\n", "") + commands = re.findall("(\\w+)\\s+?=\\s+?\"(.+?)\"", sql_content) + + for (name, query) in commands: + if name in self.queries: + self.queries[name] = query diff --git a/hstest/testing/runner/sql_runner.py b/hstest/testing/runner/sql_runner.py new file mode 100644 index 00000000..3246debd --- /dev/null +++ b/hstest/testing/runner/sql_runner.py @@ -0,0 +1,21 @@ +import typing + +from hstest.test_case.check_result import CheckResult +from hstest.testing.runner.test_runner import TestRunner + +if typing.TYPE_CHECKING: + from hstest.testing.test_run import TestRun + + +class SQLRunner(TestRunner): + + def test(self, test_run: 'TestRun'): + test_case = test_run.test_case + + try: + result = test_case.dynamic_testing() + return result + except BaseException as ex: + test_run.set_error_in_test(ex) + + return CheckResult.from_error(test_run.error_in_test) diff --git a/tests/projects/sql/main.sql b/tests/projects/sql/main.sql new file mode 100644 index 00000000..511d872b --- /dev/null +++ b/tests/projects/sql/main.sql @@ -0,0 +1,5 @@ +test = "UPDATE Customers +SET ContactName = 'Alfred Schmidt', City= 'Frankfurt' +WHERE CustomerID = 1"; + +create_table = "CREATE table"; \ No newline at end of file diff --git a/tests/projects/sql/test.py b/tests/projects/sql/test.py new file mode 100644 index 00000000..46167a1b --- /dev/null +++ b/tests/projects/sql/test.py @@ -0,0 +1,16 @@ +from hstest import SQLTest, dynamic_test, correct, wrong + + +class TestSQLProject(SQLTest): + + queries = { + 'create_table': None, + 'test': None + } + + @dynamic_test() + def simple_test(self): + for query in self.queries: + if self.queries[query] is None: + return wrong(f"Can't find '{query}' query from SQL files!") + return correct()