diff --git a/.github/workflow/ci.yml b/.github/workflows/ci.yml
similarity index 91%
rename from .github/workflow/ci.yml
rename to .github/workflows/ci.yml
index 2c72e25..5d8e62a 100644
--- a/.github/workflow/ci.yml
+++ b/.github/workflows/ci.yml
@@ -28,7 +28,7 @@ jobs:
         run: sudo apt-get update && sudo apt-get install -y pandoc
       - name: Install tectonic
-        run: sudo snap update && sudo snap install tectonic
+        run: sudo snap refresh && sudo snap install tectonic
       - name: Test with pytest
         run: poetry run pytest tests
diff --git a/poetry.lock b/poetry.lock
index d3f6288..8467de9 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -573,6 +573,17 @@ traitlets = ">=4"
 [package.extras]
 test = ["pytest"]
 
+[[package]]
+name = "contextlib2"
+version = "21.6.0"
+description = "Backports and enhancements for the contextlib module"
+optional = false
+python-versions = ">=3.6"
+files = [
+    {file = "contextlib2-21.6.0-py2.py3-none-any.whl", hash = "sha256:3fbdb64466afd23abaf6c977627b75b6139a5a3e8ce38405c5b413aed7a0471f"},
+    {file = "contextlib2-21.6.0.tar.gz", hash = "sha256:ab1e2bfe1d01d968e1b7e8d9023bc51ef3509bba217bb730cee3827e1ee82869"},
+]
+
 [[package]]
 name = "dataclasses-json"
 version = "0.6.6"
@@ -663,6 +674,20 @@ files = [
     {file = "docutils-0.20.1.tar.gz", hash = "sha256:f08a4e276c3a1583a86dce3e34aba3fe04d02bba2dd51ed16106244e8a923e3b"},
 ]
 
+[[package]]
+name = "execnet"
+version = "2.1.1"
+description = "execnet: rapid multi-Python deployment"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "execnet-2.1.1-py3-none-any.whl", hash = "sha256:26dee51f1b80cebd6d0ca8e74dd8745419761d3bef34163928cbebbdc4749fdc"},
+    {file = "execnet-2.1.1.tar.gz", hash = "sha256:5189b52c6121c24feae288166ab41b32549c7e2348652736540b9e6e7d4e72e3"},
+]
+
+[package.extras]
+testing = ["hatch", "pre-commit", "pytest", "tox"]
+
 [[package]]
 name = "executing"
 version = "2.0.1"
@@ -802,6 +827,38 @@ files = [
     {file = "frozenlist-1.4.1.tar.gz", hash = "sha256:c037a86e8513059a2613aaba4d817bb90b9d9b6b69aace3ce9c877e8c8ed402b"},
 ]
 
+[[package]]
+name = "gitdb"
+version = "4.0.11"
+description = "Git Object Database"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "gitdb-4.0.11-py3-none-any.whl", hash = "sha256:81a3407ddd2ee8df444cbacea00e2d038e40150acfa3001696fe0dcf1d3adfa4"},
+    {file = "gitdb-4.0.11.tar.gz", hash = "sha256:bf5421126136d6d0af55bc1e7c1af1c397a34f5b7bd79e776cd3e89785c2b04b"},
+]
+
+[package.dependencies]
+smmap = ">=3.0.1,<6"
+
+[[package]]
+name = "gitpython"
+version = "3.1.43"
+description = "GitPython is a Python library used to interact with Git repositories"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "GitPython-3.1.43-py3-none-any.whl", hash = "sha256:eec7ec56b92aad751f9912a73404bc02ba212a23adb2c7098ee668417051a1ff"},
+    {file = "GitPython-3.1.43.tar.gz", hash = "sha256:35f314a9f878467f5453cc1fee295c3e18e52f1b99f10f6cf5b1682e968a9e7c"},
+]
+
+[package.dependencies]
+gitdb = ">=4.0.1,<5"
+
+[package.extras]
+doc = ["sphinx (==4.3.2)", "sphinx-autodoc-typehints", "sphinx-rtd-theme", "sphinxcontrib-applehelp (>=1.0.2,<=1.0.4)", "sphinxcontrib-devhelp (==1.0.2)", "sphinxcontrib-htmlhelp (>=2.0.0,<=2.0.1)", "sphinxcontrib-qthelp (==1.0.3)", "sphinxcontrib-serializinghtml (==1.1.5)"]
+test = ["coverage[toml]", "ddt (>=1.1.1,!=1.4.3)", "mock", "mypy", "pre-commit", "pytest (>=7.3.1)", "pytest-cov", "pytest-instafail", "pytest-mock", "pytest-sugar", "typing-extensions"]
+
 [[package]]
 name = "greenlet"
 version = "3.0.3"
"mistune-3.0.2.tar.gz", hash = "sha256:fc7f93ded930c92394ef2cb6f04a8aabab4117a91449e72dcc8dfa646a508be8"}, ] +[[package]] +name = "mock" +version = "5.1.0" +description = "Rolling backport of unittest.mock for all Pythons" +optional = false +python-versions = ">=3.6" +files = [ + {file = "mock-5.1.0-py3-none-any.whl", hash = "sha256:18c694e5ae8a208cdb3d2c20a993ca1a7b0efa258c247a1e565150f477f83744"}, + {file = "mock-5.1.0.tar.gz", hash = "sha256:5e96aad5ccda4718e0a229ed94b2024df75cc2d55575ba5762d31f5767b8767d"}, +] + +[package.extras] +build = ["blurb", "twine", "wheel"] +docs = ["sphinx"] +test = ["pytest", "pytest-cov"] + [[package]] name = "multidict" version = "6.0.5" @@ -2391,6 +2464,39 @@ files = [ qa = ["flake8 (==5.0.4)", "mypy (==0.971)", "types-setuptools (==67.2.0.1)"] testing = ["docopt", "pytest"] +[[package]] +name = "path" +version = "16.14.0" +description = "A module wrapper for os.path" +optional = false +python-versions = ">=3.8" +files = [ + {file = "path-16.14.0-py3-none-any.whl", hash = "sha256:8ee37703cbdc7cc83835ed4ecc6b638226fb2b43b7b45f26b620589981a109a5"}, + {file = "path-16.14.0.tar.gz", hash = "sha256:dbaaa7efd4602fd6ba8d82890dc7823d69e5de740a6e842d9919b0faaf2b6a8e"}, +] + +[package.extras] +docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] +testing = ["appdirs", "more-itertools", "packaging", "pygments", "pytest (>=6,!=8.1.1)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-ruff (>=0.2.1)", "pywin32"] + +[[package]] +name = "path-py" +version = "12.5.0" +description = "A module wrapper for os.path" +optional = false +python-versions = ">=3.5" +files = [ + {file = "path.py-12.5.0-py3-none-any.whl", hash = "sha256:a43e82eb2c344c3fd0b9d6352f6b856f40b8b7d3d65cc05978b42c3715668496"}, + {file = "path.py-12.5.0.tar.gz", hash = "sha256:8d885e8b2497aed005703d94e0fd97943401f035e42a136810308bff034529a8"}, +] + +[package.dependencies] +path = "*" + +[package.extras] +docs = ["jaraco.packaging (>=3.2)", "rst.linker (>=1.9)", "sphinx"] +testing = ["appdirs", "packaging", "pygments", "pytest (>=3.5,!=3.7.3)", "pytest-black-multipy", "pytest-checkdocs (>=1.2.3)", "pytest-cov", "pytest-flake8"] + [[package]] name = "pexpect" version = "4.9.0" @@ -2745,6 +2851,45 @@ pluggy = ">=1.5,<2.0" [package.extras] dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] +[[package]] +name = "pytest-git" +version = "1.7.0" +description = "Git repository fixture for py.test" +optional = false +python-versions = "*" +files = [ + {file = "pytest-git-1.7.0.tar.gz", hash = "sha256:356fef84eb0d663d2a5eceafb3ff6b2c3043b2b964b1872b67e51979dbbb43f8"}, + {file = "pytest_git-1.7.0-py2.py3-none-any.whl", hash = "sha256:f0737e688bb6d53b4a501d9eba340885e63522ee57e17c24137525c7d9a17361"}, +] + +[package.dependencies] +gitpython = "*" +pytest = "*" +pytest-shutil = "*" + +[[package]] +name = "pytest-shutil" +version = "1.7.0" +description = "A goodie-bag of unix shell and environment tools for py.test" +optional = false +python-versions = "*" +files = [ + {file = "pytest-shutil-1.7.0.tar.gz", hash = "sha256:d8165261de76e7508505c341d94c02b113dc963f274543abca74dbfabd021261"}, + {file = "pytest_shutil-1.7.0-py2.py3-none-any.whl", hash = "sha256:b3568a675cb092c9b15c789ebd3046b79cfaca476868939748729d14557a98ff"}, +] + +[package.dependencies] +contextlib2 = "*" +execnet = "*" +mock = "*" +"path.py" = "*" +pytest = "*" 
+six = "*" +termcolor = "*" + +[package.extras] +tests = ["pytest"] + [[package]] name = "python-dateutil" version = "2.9.0.post0" @@ -3410,6 +3555,17 @@ files = [ {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"}, ] +[[package]] +name = "smmap" +version = "5.0.1" +description = "A pure Python implementation of a sliding window memory map manager" +optional = false +python-versions = ">=3.7" +files = [ + {file = "smmap-5.0.1-py3-none-any.whl", hash = "sha256:e6d8668fa5f93e706934a62d7b4db19c8d9eb8cf2adbb75ef1b675aa332b69da"}, + {file = "smmap-5.0.1.tar.gz", hash = "sha256:dceeb6c0028fdb6734471eb07c0cd2aae706ccaecab45965ee83f11c8d3b1f62"}, +] + [[package]] name = "sniffio" version = "1.3.1" @@ -4356,4 +4512,4 @@ test = ["big-O", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-it [metadata] lock-version = "2.0" python-versions = "^3.12" -content-hash = "3d29f42d5d54d85cf30292bf7fdbff588c68f9be3c873a24c82c48e13ee9d1db" +content-hash = "c29a8c64d644a7d809b5f4b5e910a4fe50dc0bd19859d60b94a727a270b929a7" diff --git a/pyproject.toml b/pyproject.toml index 6ca0046..a681543 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,12 +18,14 @@ langchain = "^0.2.1" langchain-openai = "^0.1.7" langchain-community = "^0.2.1" langchain-core = "^0.2.1" +gitpython = "^3.1.43" [tool.poetry.group.dev.dependencies] jupyter = "^1.0.0" jupyter-book = "^1.0.0" jupyterlab = "^4.2.1" pytest = "^8.2.1" +pytest-git = "^1.7.0" [tool.poetry.scripts] test-creation = "test_creation:cli_main" diff --git a/src/test_creation/demo.ipynb b/src/test_creation/demo.ipynb index fe97609..e70f971 100644 --- a/src/test_creation/demo.ipynb +++ b/src/test_creation/demo.ipynb @@ -33,7 +33,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:32<00:00, 10.71s/it]\n" + "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:33<00:00, 11.16s/it]\n" ] } ], @@ -55,16 +55,6 @@ "output_type": "stream", "text": [ "Report:\n", - " Requirement \\\n", - "ID Title \n", - "2.1 Ensure Data File Loads as Expected Ensure that data-loading functions correctly l... \n", - "3.2 Data in the Expected Format Verify that the data to be ingested matches th... \n", - "3.5 Check for Duplicate Records in Data Check for duplicate records in the dataset and... \n", - "4.2 Verify Data Split Proportion Check that the data is split into training and... \n", - "5.3 Ensure Model Output Shape Aligns with Expectation Ensure the shape of the model's output aligns ... \n", - "6.1 Verify Evaluation Metrics Implementation Verify that the evaluation metrics are correct... \n", - "6.2 Evaluate Model's Performance Against Thresholds Compute evaluation metrics for both the traini... 
\n", - "\n", " is_Satisfied \\\n", "ID Title \n", "2.1 Ensure Data File Loads as Expected 0.0 \n", @@ -72,47 +62,27 @@ "3.5 Check for Duplicate Records in Data 0.0 \n", "4.2 Verify Data Split Proportion 0.5 \n", "5.3 Ensure Model Output Shape Aligns with Expectation 0.0 \n", - "6.1 Verify Evaluation Metrics Implementation 1.0 \n", + "6.1 Verify Evaluation Metrics Implementation 0.5 \n", "6.2 Evaluate Model's Performance Against Thresholds 0.5 \n", "\n", - " n_files_tested \\\n", - "ID Title \n", - "2.1 Ensure Data File Loads as Expected 3 \n", - "3.2 Data in the Expected Format 3 \n", - "3.5 Check for Duplicate Records in Data 3 \n", - "4.2 Verify Data Split Proportion 3 \n", - "5.3 Ensure Model Output Shape Aligns with Expectation 3 \n", - "6.1 Verify Evaluation Metrics Implementation 3 \n", - "6.2 Evaluate Model's Performance Against Thresholds 3 \n", - "\n", - " Observations \\\n", - "ID Title \n", - "2.1 Ensure Data File Loads as Expected [(test_cross_validation.py) The code does not ... \n", - "3.2 Data in the Expected Format [(test_cross_validation.py) The code does not ... \n", - "3.5 Check for Duplicate Records in Data [(test_cross_validation.py) The code does not ... \n", - "4.2 Verify Data Split Proportion [(test_cross_validation.py) The code tests the... \n", - "5.3 Ensure Model Output Shape Aligns with Expectation [(test_cross_validation.py) The code does not ... \n", - "6.1 Verify Evaluation Metrics Implementation [(test_cross_validation.py) The code does not ... \n", - "6.2 Evaluate Model's Performance Against Thresholds [(test_cross_validation.py) The code does not ... \n", + " n_files_tested \n", + "ID Title \n", + "2.1 Ensure Data File Loads as Expected 3 \n", + "3.2 Data in the Expected Format 3 \n", + "3.5 Check for Duplicate Records in Data 3 \n", + "4.2 Verify Data Split Proportion 3 \n", + "5.3 Ensure Model Output Shape Aligns with Expectation 3 \n", + "6.1 Verify Evaluation Metrics Implementation 3 \n", + "6.2 Evaluate Model's Performance Against Thresholds 3 \n", "\n", - " Function References \n", - "ID Title \n", - "2.1 Ensure Data File Loads as Expected [{'File Path': '../../data/raw/openja/lightfm_... \n", - "3.2 Data in the Expected Format [{'File Path': '../../data/raw/openja/lightfm_... \n", - "3.5 Check for Duplicate Records in Data [{'File Path': '../../data/raw/openja/lightfm_... \n", - "4.2 Verify Data Split Proportion [{'File Path': '../../data/raw/openja/lightfm_... \n", - "5.3 Ensure Model Output Shape Aligns with Expectation [{'File Path': '../../data/raw/openja/lightfm_... \n", - "6.1 Verify Evaluation Metrics Implementation [{'File Path': '../../data/raw/openja/lightfm_... \n", - "6.2 Evaluate Model's Performance Against Thresholds [{'File Path': '../../data/raw/openja/lightfm_... 
\n", - "\n", - "Score: 2.0/7\n", + "Score: 1.5/7\n", "\n" ] }, { "data": { "text/plain": [ - "'2.0/7'" + "'1.5/7'" ] }, "execution_count": 4, @@ -121,7 +91,7 @@ } ], "source": [ - "parser = ResponseParser(response)\n", + "parser = ResponseParser(response, repo)\n", "parser.get_completeness_score(verbose=True)" ] }, @@ -199,7 +169,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:35<00:00, 11.94s/it]\n" + "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:31<00:00, 10.40s/it]\n" ] }, { @@ -213,7 +183,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:32<00:00, 10.70s/it]\n" + "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:31<00:00, 10.36s/it]\n" ] }, { @@ -228,7 +198,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:38<00:00, 12.83s/it]\n" + "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:39<00:00, 13.16s/it]\n" ] }, { @@ -242,7 +212,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:40<00:00, 13.34s/it]\n" + "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:42<00:00, 14.24s/it]\n" ] } ], @@ -306,14 +276,14 @@ " \n", " \n", " 1\n", - " 0.142857\n", + " 0.214286\n", " ID ...\n", " gpt-3.5-turbo\n", " 2\n", " \n", " \n", " 2\n", - " 0.714286\n", + " 0.642857\n", " ID ...\n", " gpt-4o\n", " 1\n", @@ -332,8 +302,8 @@ "text/plain": [ " score report model_name \\\n", "0 0.214286 ID ... gpt-3.5-turbo \n", - "1 0.142857 ID ... gpt-3.5-turbo \n", - "2 0.714286 ID ... gpt-4o \n", + "1 0.214286 ID ... gpt-3.5-turbo \n", + "2 0.642857 ID ... gpt-4o \n", "3 0.714286 ID ... gpt-4o \n", "\n", " test_no \n", @@ -428,7 +398,7 @@ " Check that the data is split into training and...\n", " 0.5\n", " 3\n", - " [(test_cross_validation.py) The code does spli...\n", + " [(test_cross_validation.py) The code includes ...\n", " [{'File Path': '../../data/raw/openja/lightfm_...\n", " \n", " \n", @@ -488,7 +458,7 @@ "0 3 [(test_cross_validation.py) The code does not ... \n", "1 3 [(test_cross_validation.py) The code does not ... 
\n", "2 3 [(test_cross_validation.py) The code does not ... \n", - "3 3 [(test_cross_validation.py) The code does spli... \n", + "3 3 [(test_cross_validation.py) The code includes ... \n", "4 3 [(test_cross_validation.py) The code does not ... \n", "5 3 [(test_cross_validation.py) The code does not ... \n", "6 3 [(test_cross_validation.py) The code does not ... \n", @@ -559,12 +529,12 @@ " \n", " \n", " gpt-3.5-turbo\n", - " 0.002551\n", + " 0.000000\n", " 2\n", " \n", " \n", " gpt-4o\n", - " 0.000000\n", + " 0.002551\n", " 2\n", " \n", " \n", @@ -575,8 +545,8 @@ " score \n", " var count\n", "model_name \n", - "gpt-3.5-turbo 0.002551 2\n", - "gpt-4o 0.000000 2" + "gpt-3.5-turbo 0.000000 2\n", + "gpt-4o 0.002551 2" ] }, "execution_count": 11, @@ -597,19 +567,11 @@ "id": "5b1f94c8-1883-4435-84c7-b0687a6e6387", "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/var/folders/vd/r3dvzdx10pxf47gvdqf81r9h0000gn/T/ipykernel_42405/1426530661.py:5: RuntimeWarning: divide by zero encountered in scalar divide\n", - " f_score = score_var[('score', 'var')]['gpt-3.5-turbo'] / score_var[('score', 'var')]['gpt-4o'] # var(prev) / var(curr)\n" - ] - }, { "name": "stdout", "output_type": "stream", "text": [ - "p-value: 0.0\n", + "p-value: 1.0\n", "\n", "2-tail test:\n", " Successfully reject the null hypothesis: Var(Completeness_Score(Current Version)) == Var(Completeness_Score(Last Week Version))\n" diff --git a/src/test_creation/modules/code_analyzer/analyzers/__init__.py b/src/test_creation/modules/code_analyzer/analyzers/__init__.py index 7f54637..e69de29 100644 --- a/src/test_creation/modules/code_analyzer/analyzers/__init__.py +++ b/src/test_creation/modules/code_analyzer/analyzers/__init__.py @@ -1,7 +0,0 @@ -from abc import ABC - - -class CodeAnalyzer(ABC): - def __init__(self): - pass - diff --git a/src/test_creation/modules/code_analyzer/analyzers/python.py b/src/test_creation/modules/code_analyzer/analyzers/python.py index 43f2385..9db3c5d 100644 --- a/src/test_creation/modules/code_analyzer/analyzers/python.py +++ b/src/test_creation/modules/code_analyzer/analyzers/python.py @@ -1,18 +1,40 @@ +from abc import ABC, abstractmethod import ast +from typing import Union +from pathlib import Path from functools import wraps - -from . 
diff --git a/src/test_creation/modules/code_analyzer/analyzers/python.py b/src/test_creation/modules/code_analyzer/analyzers/python.py
index 43f2385..9db3c5d 100644
--- a/src/test_creation/modules/code_analyzer/analyzers/python.py
+++ b/src/test_creation/modules/code_analyzer/analyzers/python.py
@@ -1,18 +1,40 @@
+from abc import ABC, abstractmethod
 import ast
+from typing import Union
+from pathlib import Path
 from functools import wraps
-
-from . import CodeAnalyzer
+from collections import defaultdict
 
 
 def assert_have_read_content(f):
     @wraps(f)
-    def decorator(*args, **kwargs):
-        if args[0].content is None:
+    def decorator(self, *args, **kwargs):
+        if self.content is None:
             raise RuntimeError("No content has been read yet.")
-        return f(*args, **kwargs)
+        return f(self, *args, **kwargs)
+
     return decorator
 
 
+class CodeAnalyzer(ABC):
+
+    @abstractmethod
+    def read(self, file_path: Union[str, Path]) -> None:
+        pass
+
+    @abstractmethod
+    def list_imported_packages(self):
+        pass
+
+    @abstractmethod
+    def list_all_functions(self):
+        pass
+
+    @abstractmethod
+    def contains_test(self):
+        pass
+
+
 class PythonASTCodeAnalyzer(CodeAnalyzer):
     def __init__(self):
         super().__init__()
@@ -24,6 +46,14 @@ def read(self, file_path: str):
             self.content = f.read()
         self._tree = ast.parse(self.content)
 
+    @assert_have_read_content
+    def _get_function_lineno_map(self):
+        function_lineno_map = defaultdict(int)
+        for node in ast.walk(self._tree):
+            if isinstance(node, ast.FunctionDef):
+                function_lineno_map[node.name] = node.lineno
+        return function_lineno_map
+
     @assert_have_read_content
     def list_imported_packages(self):
         packages = set()
@@ -36,7 +66,7 @@ def list_imported_packages(self):
 
     @assert_have_read_content
     def list_all_functions(self):
-        raise NotImplementedError()
+        return self._get_function_lineno_map().keys()
 
     @assert_have_read_content
     def contains_test(self):
@@ -60,6 +90,15 @@ def read(self, file_path: str):
         with open(file_path, 'r') as f:
             self.content = f.readlines()
 
+    @assert_have_read_content
+    def _get_function_lineno_map(self):
+        function_lineno_map = defaultdict(int)
+        for line_num, line in enumerate(self.content):
+            if line.lstrip().startswith('def '):
+                func_name = line.lstrip().split('(')[0].split(' ')[1]
+                function_lineno_map[func_name] = line_num + 1  # line starts with 1
+        return function_lineno_map
+
     @assert_have_read_content
     def list_imported_packages(self):
         packages = set()
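For orientation, both analyzers now expose `_get_function_lineno_map`, mapping each function name to its 1-based line number. A self-contained sketch of the AST variant's walk over an inline snippet:

    import ast
    from collections import defaultdict

    source = 'def a():\n    pass\n\ndef b():\n    pass\n'
    function_lineno_map = defaultdict(int)
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.FunctionDef):
            function_lineno_map[node.name] = node.lineno  # 1-based

    assert dict(function_lineno_map) == {'a': 1, 'b': 4}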
diff --git a/src/test_creation/modules/code_analyzer/git.py b/src/test_creation/modules/code_analyzer/git.py
new file mode 100644
index 0000000..8a43010
--- /dev/null
+++ b/src/test_creation/modules/code_analyzer/git.py
@@ -0,0 +1,76 @@
+import re
+from typing import Union, Optional
+from pathlib import Path
+from copy import copy
+
+from git import Repo
+
+
+class GitContext:
+    def __init__(self, git_dir: Union[str, Path]):
+        self.git_dir = Path(git_dir)
+        self.git_repo = Repo(self.git_dir)
+
+        self.branch = self._get_current_branch()
+        self.host, self.org, self.repo_name = self._get_remote_info()
+
+        self.remote_link_format_map = {
+            "github": "{host}/{org}/{repo}/blob/{branch}/{path}#L{line_num}",
+            "gitlab": "{host}/{org}/{repo}/blob/{branch}/{path}#L{line_num}",
+            "bitbucket": "{host}/{org}/{repo}/src/{branch}/{path}#lines-{line_num}",
+            "gitee": "{host}/{org}/{repo}/blob/{branch}/{path}#L{line_num}"
+        }
+        self.remote_protocol = "https"
+        self.remote_service_family = self.__get_remote_service_family()
+
+    def __get_remote_service_family(self):
+        result = None
+        if self.host:
+            hits = [key for key in self.remote_link_format_map.keys() if
+                    key in self.host]
+            if hits:
+                result = hits[0]
+        return result
+
+    def _get_current_branch(self):
+        if self.git_repo.head.is_detached:
+            return self.git_repo.head.commit.hexsha
+        else:
+            return self.git_repo.active_branch.name
+
+    def _get_remote_info(self) -> tuple[Optional[str], Optional[str], str]:
+        if self.git_repo.remotes:
+            if 'origin' in [r.name for r in self.git_repo.remotes]:
+                remote = self.git_repo.remote()
+            else:
+                remote = self.git_repo.remotes[0]
+            remote_url = remote.url
+            # git urls:
+            # https://git-scm.com/docs/git-clone#URLS
+            pattern = r"(?:\w+:\/\/)?(?:\w+@)?(.+)[\/:](.+)\/([^\.]+)(?:\.git)?"
+            host, org, repo_name = re.search(pattern, remote_url).groups()
+            return host, org, repo_name
+        else:
+            print("This git repository has no remote")
+            return None, None, "."
+
+    def construct_remote_link_to_file(self, file_path: Union[str, Path],
+                                      line_num: Optional[int] = None) -> str:
+        path = Path(file_path)
+        if path.is_absolute():
+            rel_path = path.relative_to(self.git_dir)
+        else:
+            rel_path = path
+        if self.remote_service_family:
+            f_str = copy(self.remote_link_format_map[self.remote_service_family])
+            if line_num is None:
+                f_str = f_str.split("#")[0]
+            injected_str = f"{self.remote_protocol}://" + \
+                f_str.format(host=self.host, org=self.org, repo=self.repo_name,
+                             branch=self.branch, path=rel_path,
+                             line_num=line_num)
+            return injected_str
+        else:
+            print("No matching service. Using local link instead...")
+            return f"file://{str(self.git_dir)}/{rel_path}"
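A usage sketch for the new `GitContext` (the checkout path is hypothetical; the printed link assumes an `origin` remote hosted on GitHub):

    from test_creation.modules.code_analyzer.git import GitContext

    ctx = GitContext('/path/to/local/clone')
    print(ctx.construct_remote_link_to_file('tests/test_repo.py', line_num=10))
    # e.g. https://github.com/<org>/<repo>/blob/<branch>/tests/test_repo.py#L10
    # with no recognized service, it falls back to a file:// link instead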
diff --git a/src/test_creation/modules/code_analyzer/repo.py b/src/test_creation/modules/code_analyzer/repo.py
index c060f25..bf9c50e 100644
--- a/src/test_creation/modules/code_analyzer/repo.py
+++ b/src/test_creation/modules/code_analyzer/repo.py
@@ -1,16 +1,30 @@
 import os
 import logging
+from functools import wraps
 from pathlib import Path
 from collections import defaultdict
-from typing import Dict, List
+from typing import Optional
 
 from .analyzers.python import PythonNaiveCodeAnalyzer, PythonASTCodeAnalyzer
+from .git import GitContext
 
 logger = logging.getLogger("test-creation.repo")
 
 
+def requires_git_context(func):
+    @wraps(func)
+    def wrapper(self, *args, **kwargs):
+        """wrapper function to check if we have git context."""
+        if self.git_context is None:
+            raise RuntimeError("This repository has no git context.")
+        return func(self, *args, **kwargs)
+
+    return wrapper
+
+
 class Repository:
     def __init__(self, path: str):
+
         if not os.path.exists(path):
             raise FileNotFoundError(f"Repository {path} does not exist.")
         elif os.path.isfile(path):
@@ -19,6 +33,10 @@ def __init__(self, path: str):
         if not os.path.exists(self.path / ".git"):
             # TODO: to be converted to use logger
             print("Warning: The repository is not a git repository.")
+            self.git_context = None
+        else:
+            self.git_context = GitContext(self.path)
+
         self.files = []
         self.fileext_language_map = {
             '.js': 'JavaScript',
@@ -33,6 +51,13 @@ def __init__(self, path: str):
             '.c': 'C'
         }
         self.lf_map = self._get_language_file_map()
+        self.ffl_map = self._get_file_function_lineno_map()
+
+    @requires_git_context
+    def get_git_direct_link(self, file: str,
+                            lineno: Optional[int] = None) -> str:
+        return self.git_context.construct_remote_link_to_file(file,
+                                                               line_num=lineno)
 
     def _get_all_files(self, include_git_dir: bool = False):
         file_paths = []
@@ -46,14 +71,31 @@ def _get_all_files(self, include_git_dir: bool = False):
                 file_paths.append(f'{root}/{file}')
         return file_paths
 
-    def _get_language_file_map(self):
-        file_language_map = defaultdict(list)
+    def _get_language_file_map(self) -> dict[str, list[str]]:
+        language_file_map = defaultdict(list)
         files = self._get_all_files()
         for file in files:
             for k, v in self.fileext_language_map.items():
                 if file.endswith(k):
-                    file_language_map[v].append(file)
-        return file_language_map
+                    language_file_map[v].append(file)
+        return language_file_map
+
+    def _get_file_function_lineno_map(self) -> dict[str, dict[str, dict[str, int]]]:
+        file_function_lineno_map = defaultdict(lambda: defaultdict(int))
+        for lang, files in self.lf_map.items():
+            # TODO: only Python is supported now
+            if lang == "Python":
+                ast = PythonASTCodeAnalyzer()
+                naive = PythonNaiveCodeAnalyzer()
+                for file in files:
+                    try:
+                        ast.read(file)
+                        file_function_lineno_map[lang][file] = ast._get_function_lineno_map()
+                    except Exception as e:
+                        logger.info("Exception occurred when parsing using ast (Python 2 code?) Using naive parser...")
+                        naive.read(file)
+                        file_function_lineno_map[lang][file] = naive._get_function_lineno_map()
+        return file_function_lineno_map
 
     def list_languages(self):
         return list(self.lf_map.keys())
@@ -70,7 +112,7 @@ def list_packages(self):
         packages = list(set(packages))
         return packages
 
-    def list_test_files(self) -> Dict[str, List[str]]:
+    def list_test_files(self) -> dict[str, list[str]]:
         testfiles = defaultdict(list)
         # for now only Python is supported
         files = self.lf_map.get("Python", [])
""" - def __init__(self, llm: LanguageModelLike, prompt_format: PromptFormat, repository: Repository, - checklist: Checklist): + + def __init__(self, llm: LanguageModelLike, prompt_format: PromptFormat, + repository: Repository, checklist: Checklist): self.llm = llm self.checklist = checklist self.repository = repository self.prompt_format = prompt_format - self.test_items = None + self._test_items = None self.chain = self.prompt_format.prompt | self.llm | self.prompt_format.parser class PerFileTestEvaluator(TestEvaluator): """Concrete test evaluator that performs per-file evaluation.""" - def __init__(self, llm: LanguageModelLike, prompt_format: PromptFormat, repository: Repository, - checklist: Checklist, retries: int = 3): + + def __init__(self, llm: LanguageModelLike, prompt_format: PromptFormat, + repository: Repository, checklist: Checklist, retries: int = 3): super().__init__(llm, prompt_format, repository, checklist) self.retries = retries @@ -53,7 +54,8 @@ def __init__(self, llm: LanguageModelLike, prompt_format: PromptFormat, reposito if not self._files: print("File loader returned no files!") - self._test_items = self.checklist.get_all_tests(['ID', 'Title', 'Requirement']) + self._test_items = self.checklist.get_all_tests(['ID', 'Title', + 'Requirement']) if not self._test_items: print("Loaded checklist successfully, but it contains no test items!") @@ -67,15 +69,18 @@ def _load_test_file_into_splits(file_path: str) -> List[Document]: def _validate_response(self, raw_response: dict) -> None: """Validation logics that are not covered by pydantic or langchain.""" - # ensures the number of items in the response is the same as provided checklists - if len(raw_response['results']) != len(self.test_items): - raise ValidationError("Number of items returned from LLM does not match that in checklist.") + # ensures the number of items in the response is the same as provided + # checklists + if len(raw_response['results']) != len(self._test_items): + raise AssertionError("Number of items returned from LLM does not match that in checklist.") + if not all(['Functions' in item for item in raw_response['results']]): + raise AssertionError("Not all items returned contain the attribute `Functions`.") def run(self, verbose: bool = False) -> EvaluationResponse: eval_response = EvaluationResponse( model={'name': self.llm.model_name, 'temperature': self.llm.temperature}, - repository_path=self.repository.path, - checklist_path=self.checklist.path + repository={'path': self.repository.path, 'object': self.repository}, + checklist={'path': self.checklist.path, 'object': self.checklist} ) for fp in tqdm(self._files): @@ -84,7 +89,6 @@ def run(self, verbose: bool = False) -> EvaluationResponse: splits = self._load_test_file_into_splits(fp) if verbose: print(f"# splits: {len(self._files)}") - # FIXME: it sometimes tests only part of the checklist items response = None retry_count = 0 @@ -107,8 +111,11 @@ def run(self, verbose: bool = False) -> EvaluationResponse: self._validate_response(response) except Exception as e: + if verbose: + print(f"error occurred: {e.__class__.__name__} - {str(e)}") errors.append({'name': e.__class__.__name__, 'description': str(e)}) retry_count += 1 + response = None continue if not response: diff --git a/src/test_creation/modules/workflow/parse.py b/src/test_creation/modules/workflow/parse.py index 6272d65..e8e9b1e 100644 --- a/src/test_creation/modules/workflow/parse.py +++ b/src/test_creation/modules/workflow/parse.py @@ -10,11 +10,36 @@ class 
diff --git a/src/test_creation/modules/workflow/parse.py b/src/test_creation/modules/workflow/parse.py
index 6272d65..e8e9b1e 100644
--- a/src/test_creation/modules/workflow/parse.py
+++ b/src/test_creation/modules/workflow/parse.py
@@ -10,11 +10,36 @@ class ResponseParser(ExportableMixin):
     def __init__(self, response: EvaluationResponse):
+        # FIXME: repository is required to extract the line numbers for functions
+        # I added an optional argument "repository" here, can't think of any better way to handle it yet
         super().__init__()
         self.response = response
         self.evaluation_report = None
+        self.repository = self.response.repository.object
+        self.git_context = self.repository.git_context
+        self.items = []
+
+    def _parse_items(self):
+        items = []
+        for result in self.response.call_results:
+            response = result.parsed_response['results']
+            for item in response:
+                fp = result.files_evaluated[0]
+                item['File Path'] = fp
+                if self.repository:
+                    item['lineno'] = [self.repository.ffl_map['Python'][fp][func] for func in item['Functions']]
+                else:
+                    item['lineno'] = []
+                item['Referenced Functions'] = [
+                    f"[{func}]({self.repository.get_git_direct_link(fp, lineno)})"
+                    for func, lineno in zip(item['Functions'], item['lineno'])
+                ]
+                items.append(item)
+        self.items = items
+        return items
 
     def get_completeness_score(self, score_format: str = 'fraction', verbose: bool = False) -> Optional[float]:
+        """Compute Evaluation Report and Completeness Score."""
         # TODO: change this after putting the logic to load data from JSON file
         # instead of from a Python object.
@@ -30,16 +55,10 @@ def get_completeness_score(self, score_format: str = 'fraction', verbose: bool =
             print("failed to obtain valid response, cannot calculate completeness score")
             return None
 
-        report = []
-        for result in self.response.call_results:
-            response = result.parsed_response['results']
-            for item in response:
-                item['file'] = result.files_evaluated[0]  # FIXME: it might fail if the evaluation is on multiple files
-                report.append(item)
+        items = self._parse_items()
 
-        report_df = pd.DataFrame(report)
-        report_df = report_df.rename(columns={"file": "File Path"})
-        report_df['Function References'] = report_df[['File Path', 'Functions']].to_dict(orient='records')
+        report_df = pd.DataFrame(items)
+        report_df['Function References'] = report_df[['File Path', 'Referenced Functions']].to_dict(orient='records')
         report_df['Observation'] = '(' + report_df['File Path'].apply(lambda x: os.path.split(x)[-1]) + ') ' + \
                                    report_df['Observation']
         report_df = report_df.groupby(['ID', 'Title']).agg({
@@ -58,7 +77,7 @@ def get_completeness_score(self, score_format: str = 'fraction', verbose: bool =
 
         if verbose:
             print("Report:")
-            print(report_df)
+            print(report_df[['is_Satisfied', 'n_files_tested']])
             print()
             print(f'Score: {score}')
             print()
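After `_parse_items`, every checklist item carries its source file, resolved line numbers from `ffl_map`, and Markdown links built via `get_git_direct_link`. An item ends up shaped like this (values illustrative):

    item = {
        'ID': '4.2',
        'Title': 'Verify Data Split Proportion',
        'File Path': 'tests/test_cross_validation.py',
        'Functions': ['test_split'],
        'lineno': [42],
        'Referenced Functions': [
            '[test_split](https://github.com/<org>/<repo>/blob/main/'
            'tests/test_cross_validation.py#L42)'
        ],
    }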
diff --git a/src/test_creation/modules/workflow/prompt_format.py b/src/test_creation/modules/workflow/prompt_format.py
index 57ea354..205c161 100644
--- a/src/test_creation/modules/workflow/prompt_format.py
+++ b/src/test_creation/modules/workflow/prompt_format.py
@@ -25,7 +25,7 @@ class TestItemEvaluation(BaseModel):
     Title: str = Field(description="The corresponding `Title` of the checklist item provided")
     Requirement: str = Field(description="The corresponding `Requirement` of the checklist item provided")
     Observation: str = Field(description="Your detailed observation of the code in accordance to the given checklist item")
-    Functions: List[str] = Field(description="Test functions that satisfy the given requirement (if any)")
+    Functions: List[str] = Field(description="Test functions that satisfy the given requirement. If no function satisfies, an empty list i.e. [] should be returned.")
     Evaluation: str = Field(description="The summarized evaluation. Must be one of Satisfied/Partially Satisfied/Not Satisfied.")
     Score: float = Field(description="The score obtained from the given evaluation (1 for Satisfied / 0.5 for Partially Satisfied / 0 for Not Satisfied)")
 
diff --git a/src/test_creation/modules/workflow/response.py b/src/test_creation/modules/workflow/response.py
index 386bd95..07643bf 100644
--- a/src/test_creation/modules/workflow/response.py
+++ b/src/test_creation/modules/workflow/response.py
@@ -2,14 +2,31 @@
 from pathlib import Path
 from typing import List, Dict, Any, Optional, Union
 
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, ConfigDict
 
+from ..code_analyzer.repo import Repository
+from ..checklist.checklist import Checklist
 
-class LLM(BaseModel):
+
+class LLMInfo(BaseModel):
     name: str = Field(description="Name of the LLM used")
     temperature: float = Field(description="Temperature of the LLM")
 
 
+class RepositoryInfo(BaseModel):
+    path: Union[str, Path] = Field(description="Path of the repository")
+    object: Repository = Field(description="Repository object")
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+
+class ChecklistInfo(BaseModel):
+    path: Union[str, Path] = Field(description="Path of the checklist")
+    object: Checklist = Field(description="Checklist object")
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+
 class Error(BaseModel):
     name: str = Field(description="Class name of the error")
     description: str = Field(description="Description of the error")
@@ -35,8 +52,14 @@ class EvaluationResponse(BaseModel):
             name
             temperature
         }
-        repository_path
-        checklist_path
+        repository {
+            path
+            object
+        }
+        checklist {
+            path
+            object
+        }
         call_results [{
             start_time
             end_time
@@ -52,7 +75,7 @@ class EvaluationResponse(BaseModel):
         }]
     }
     """
-    model: LLM = Field(description="LLM-related information")
-    repository_path: Union[str, Path] = Field(description="Repository path")
-    checklist_path: Union[str, Path] = Field(description="Checklist path")
+    model: LLMInfo = Field(description="LLM-related information")
+    repository: RepositoryInfo = Field(description="Repository-related information")
+    checklist: ChecklistInfo = Field(description="Checklist-related information")
     call_results: List[CallResult] = Field(description="List of call results", default=[])
diff --git a/tests/test_repo.py b/tests/test_repo.py
new file mode 100644
index 0000000..fb29a73
--- /dev/null
+++ b/tests/test_repo.py
@@ -0,0 +1,91 @@
+from pathlib import Path
+
+import pytest
+from test_creation.modules.code_analyzer import repo as r
+from test_creation.modules.code_analyzer.git import GitContext
+
+
+@pytest.fixture()
+def test_git_repo(git_repo):
+    # The fixture derives from `workspace` in `pytest-shutil`, so it contains
+    # a handle to a `path` object (see https://path.readthedocs.io/)
+    path = git_repo.workspace
+    txt_file = path / 'hello.txt'
+    txt_file.write_text('hello world!')
+
+    py_file = Path(path / 'src/python/main.py')
+    py_file.parent.mkdir(parents=True, exist_ok=True)
+    py_file.write_text('print("hello world!")')
+
+    # We can run commands relative to the working directory
+    git_repo.run('git add .')
+
+    # It's better to use the GitPython api directly - the 'api' attribute is
+    # a handle to the repository object.
+    git_repo.api.index.commit("Initial commit")
+
+    # The fixture has a URI property you can use in downstream systems
+    assert git_repo.uri.startswith('file://')
+
+    return git_repo
+
+
+################################################################################
+# Repository                                                                   #
+################################################################################
+def test_repository_should_be_able_to_read_git_repo(test_git_repo):
+    path = test_git_repo.workspace
+    repo = r.Repository(path)
+    assert any(['src/python/main.py' in file for file in repo._get_all_files()])
+
+
+################################################################################
+# GitContext                                                                   #
+################################################################################
+@pytest.mark.parametrize(
+    "fixture_name, remote_name, remote_url, expected",
+    [
+        (
+            "test_git_repo",
+            "origin",
+            "git@github.internal.com:UBC-MDS/testing-repo_1234.git",
+            ("github.internal.com", "UBC-MDS", "testing-repo_1234")
+        ),
+        (
+            "test_git_repo",
+            "export",
+            "ssh://git@github.internal.com:UBC-MDS/testing-repo_1234.git",
+            ("github.internal.com", "UBC-MDS", "testing-repo_1234")
+        ),
+        (
+            "test_git_repo",
+            "internal",
+            "https://github.com:8080/UBC-MDS/test-creation.git",
+            ("github.com:8080", "UBC-MDS", "test-creation")
+        ),
+        (
+            "test_git_repo",
+            "origin",
+            "http://gitlab.example.com:8080/UBC-MDS/test-creation.git",
+            ("gitlab.example.com:8080", "UBC-MDS", "test-creation")
+        ),
+        (
+            "test_git_repo",
+            "a",
+            "ftp://github.com/SoloSynth1/Solosynth1",
+            ("github.com", "SoloSynth1", "Solosynth1")
+        ),
+    ]
+)
+def test_git_context_can_extract_remote_git_urls(fixture_name, remote_name,
+                                                 remote_url, expected, request):
+    repo = request.getfixturevalue(fixture_name)
+    repo.api.create_remote(remote_name, remote_url)
+    gc = GitContext(repo.workspace)
+    assert (gc.host, gc.org, gc.repo_name) == expected
+
+
+def test_git_context_gives_out_local_link_when_no_remote(test_git_repo):
+    context = GitContext(test_git_repo.workspace)
+    link = context.construct_remote_link_to_file("src/python/main.py")
+    assert link == f"file://{test_git_repo.workspace}/src/python/main.py"
\ No newline at end of file
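All the parametrized cases above route through the URL pattern in `GitContext._get_remote_info`; a standalone check of the SCP-style case:

    import re

    pattern = r"(?:\w+:\/\/)?(?:\w+@)?(.+)[\/:](.+)\/([^\.]+)(?:\.git)?"
    url = "git@github.internal.com:UBC-MDS/testing-repo_1234.git"
    assert re.search(pattern, url).groups() == \
        ("github.internal.com", "UBC-MDS", "testing-repo_1234")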
"gz"), + (Path("test/README"), "") + ] +) +def test_extension_from_string_can_be_extracted_correctly(path, expected): + assert get_extension(path) == expected def test_extracted_extension_does_not_start_with_dot():