diff --git a/.github/workflow/ci.yml b/.github/workflows/ci.yml
similarity index 91%
rename from .github/workflow/ci.yml
rename to .github/workflows/ci.yml
index 2c72e25..5d8e62a 100644
--- a/.github/workflow/ci.yml
+++ b/.github/workflows/ci.yml
@@ -28,7 +28,7 @@ jobs:
run: sudo apt-get update && sudo apt-get install -y pandoc
- name: Install tectonic
- run: sudo snap update && sudo snap install tectonic
+ run: sudo snap refresh && sudo snap install tectonic
- name: Test with pytest
run: poetry run pytest tests
diff --git a/poetry.lock b/poetry.lock
index d3f6288..8467de9 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -573,6 +573,17 @@ traitlets = ">=4"
[package.extras]
test = ["pytest"]
+[[package]]
+name = "contextlib2"
+version = "21.6.0"
+description = "Backports and enhancements for the contextlib module"
+optional = false
+python-versions = ">=3.6"
+files = [
+ {file = "contextlib2-21.6.0-py2.py3-none-any.whl", hash = "sha256:3fbdb64466afd23abaf6c977627b75b6139a5a3e8ce38405c5b413aed7a0471f"},
+ {file = "contextlib2-21.6.0.tar.gz", hash = "sha256:ab1e2bfe1d01d968e1b7e8d9023bc51ef3509bba217bb730cee3827e1ee82869"},
+]
+
[[package]]
name = "dataclasses-json"
version = "0.6.6"
@@ -663,6 +674,20 @@ files = [
{file = "docutils-0.20.1.tar.gz", hash = "sha256:f08a4e276c3a1583a86dce3e34aba3fe04d02bba2dd51ed16106244e8a923e3b"},
]
+[[package]]
+name = "execnet"
+version = "2.1.1"
+description = "execnet: rapid multi-Python deployment"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "execnet-2.1.1-py3-none-any.whl", hash = "sha256:26dee51f1b80cebd6d0ca8e74dd8745419761d3bef34163928cbebbdc4749fdc"},
+ {file = "execnet-2.1.1.tar.gz", hash = "sha256:5189b52c6121c24feae288166ab41b32549c7e2348652736540b9e6e7d4e72e3"},
+]
+
+[package.extras]
+testing = ["hatch", "pre-commit", "pytest", "tox"]
+
[[package]]
name = "executing"
version = "2.0.1"
@@ -802,6 +827,38 @@ files = [
{file = "frozenlist-1.4.1.tar.gz", hash = "sha256:c037a86e8513059a2613aaba4d817bb90b9d9b6b69aace3ce9c877e8c8ed402b"},
]
+[[package]]
+name = "gitdb"
+version = "4.0.11"
+description = "Git Object Database"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "gitdb-4.0.11-py3-none-any.whl", hash = "sha256:81a3407ddd2ee8df444cbacea00e2d038e40150acfa3001696fe0dcf1d3adfa4"},
+ {file = "gitdb-4.0.11.tar.gz", hash = "sha256:bf5421126136d6d0af55bc1e7c1af1c397a34f5b7bd79e776cd3e89785c2b04b"},
+]
+
+[package.dependencies]
+smmap = ">=3.0.1,<6"
+
+[[package]]
+name = "gitpython"
+version = "3.1.43"
+description = "GitPython is a Python library used to interact with Git repositories"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "GitPython-3.1.43-py3-none-any.whl", hash = "sha256:eec7ec56b92aad751f9912a73404bc02ba212a23adb2c7098ee668417051a1ff"},
+ {file = "GitPython-3.1.43.tar.gz", hash = "sha256:35f314a9f878467f5453cc1fee295c3e18e52f1b99f10f6cf5b1682e968a9e7c"},
+]
+
+[package.dependencies]
+gitdb = ">=4.0.1,<5"
+
+[package.extras]
+doc = ["sphinx (==4.3.2)", "sphinx-autodoc-typehints", "sphinx-rtd-theme", "sphinxcontrib-applehelp (>=1.0.2,<=1.0.4)", "sphinxcontrib-devhelp (==1.0.2)", "sphinxcontrib-htmlhelp (>=2.0.0,<=2.0.1)", "sphinxcontrib-qthelp (==1.0.3)", "sphinxcontrib-serializinghtml (==1.1.5)"]
+test = ["coverage[toml]", "ddt (>=1.1.1,!=1.4.3)", "mock", "mypy", "pre-commit", "pytest (>=7.3.1)", "pytest-cov", "pytest-instafail", "pytest-mock", "pytest-sugar", "typing-extensions"]
+
[[package]]
name = "greenlet"
version = "3.0.3"
@@ -1856,6 +1913,22 @@ files = [
{file = "mistune-3.0.2.tar.gz", hash = "sha256:fc7f93ded930c92394ef2cb6f04a8aabab4117a91449e72dcc8dfa646a508be8"},
]
+[[package]]
+name = "mock"
+version = "5.1.0"
+description = "Rolling backport of unittest.mock for all Pythons"
+optional = false
+python-versions = ">=3.6"
+files = [
+ {file = "mock-5.1.0-py3-none-any.whl", hash = "sha256:18c694e5ae8a208cdb3d2c20a993ca1a7b0efa258c247a1e565150f477f83744"},
+ {file = "mock-5.1.0.tar.gz", hash = "sha256:5e96aad5ccda4718e0a229ed94b2024df75cc2d55575ba5762d31f5767b8767d"},
+]
+
+[package.extras]
+build = ["blurb", "twine", "wheel"]
+docs = ["sphinx"]
+test = ["pytest", "pytest-cov"]
+
[[package]]
name = "multidict"
version = "6.0.5"
@@ -2391,6 +2464,39 @@ files = [
qa = ["flake8 (==5.0.4)", "mypy (==0.971)", "types-setuptools (==67.2.0.1)"]
testing = ["docopt", "pytest"]
+[[package]]
+name = "path"
+version = "16.14.0"
+description = "A module wrapper for os.path"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "path-16.14.0-py3-none-any.whl", hash = "sha256:8ee37703cbdc7cc83835ed4ecc6b638226fb2b43b7b45f26b620589981a109a5"},
+ {file = "path-16.14.0.tar.gz", hash = "sha256:dbaaa7efd4602fd6ba8d82890dc7823d69e5de740a6e842d9919b0faaf2b6a8e"},
+]
+
+[package.extras]
+docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"]
+testing = ["appdirs", "more-itertools", "packaging", "pygments", "pytest (>=6,!=8.1.1)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-ruff (>=0.2.1)", "pywin32"]
+
+[[package]]
+name = "path-py"
+version = "12.5.0"
+description = "A module wrapper for os.path"
+optional = false
+python-versions = ">=3.5"
+files = [
+ {file = "path.py-12.5.0-py3-none-any.whl", hash = "sha256:a43e82eb2c344c3fd0b9d6352f6b856f40b8b7d3d65cc05978b42c3715668496"},
+ {file = "path.py-12.5.0.tar.gz", hash = "sha256:8d885e8b2497aed005703d94e0fd97943401f035e42a136810308bff034529a8"},
+]
+
+[package.dependencies]
+path = "*"
+
+[package.extras]
+docs = ["jaraco.packaging (>=3.2)", "rst.linker (>=1.9)", "sphinx"]
+testing = ["appdirs", "packaging", "pygments", "pytest (>=3.5,!=3.7.3)", "pytest-black-multipy", "pytest-checkdocs (>=1.2.3)", "pytest-cov", "pytest-flake8"]
+
[[package]]
name = "pexpect"
version = "4.9.0"
@@ -2745,6 +2851,45 @@ pluggy = ">=1.5,<2.0"
[package.extras]
dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
+[[package]]
+name = "pytest-git"
+version = "1.7.0"
+description = "Git repository fixture for py.test"
+optional = false
+python-versions = "*"
+files = [
+ {file = "pytest-git-1.7.0.tar.gz", hash = "sha256:356fef84eb0d663d2a5eceafb3ff6b2c3043b2b964b1872b67e51979dbbb43f8"},
+ {file = "pytest_git-1.7.0-py2.py3-none-any.whl", hash = "sha256:f0737e688bb6d53b4a501d9eba340885e63522ee57e17c24137525c7d9a17361"},
+]
+
+[package.dependencies]
+gitpython = "*"
+pytest = "*"
+pytest-shutil = "*"
+
+[[package]]
+name = "pytest-shutil"
+version = "1.7.0"
+description = "A goodie-bag of unix shell and environment tools for py.test"
+optional = false
+python-versions = "*"
+files = [
+ {file = "pytest-shutil-1.7.0.tar.gz", hash = "sha256:d8165261de76e7508505c341d94c02b113dc963f274543abca74dbfabd021261"},
+ {file = "pytest_shutil-1.7.0-py2.py3-none-any.whl", hash = "sha256:b3568a675cb092c9b15c789ebd3046b79cfaca476868939748729d14557a98ff"},
+]
+
+[package.dependencies]
+contextlib2 = "*"
+execnet = "*"
+mock = "*"
+"path.py" = "*"
+pytest = "*"
+six = "*"
+termcolor = "*"
+
+[package.extras]
+tests = ["pytest"]
+
[[package]]
name = "python-dateutil"
version = "2.9.0.post0"
@@ -3410,6 +3555,17 @@ files = [
{file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"},
]
+[[package]]
+name = "smmap"
+version = "5.0.1"
+description = "A pure Python implementation of a sliding window memory map manager"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "smmap-5.0.1-py3-none-any.whl", hash = "sha256:e6d8668fa5f93e706934a62d7b4db19c8d9eb8cf2adbb75ef1b675aa332b69da"},
+ {file = "smmap-5.0.1.tar.gz", hash = "sha256:dceeb6c0028fdb6734471eb07c0cd2aae706ccaecab45965ee83f11c8d3b1f62"},
+]
+
[[package]]
name = "sniffio"
version = "1.3.1"
@@ -4356,4 +4512,4 @@ test = ["big-O", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-it
[metadata]
lock-version = "2.0"
python-versions = "^3.12"
-content-hash = "3d29f42d5d54d85cf30292bf7fdbff588c68f9be3c873a24c82c48e13ee9d1db"
+content-hash = "c29a8c64d644a7d809b5f4b5e910a4fe50dc0bd19859d60b94a727a270b929a7"
diff --git a/pyproject.toml b/pyproject.toml
index 6ca0046..a681543 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -18,12 +18,14 @@ langchain = "^0.2.1"
langchain-openai = "^0.1.7"
langchain-community = "^0.2.1"
langchain-core = "^0.2.1"
+gitpython = "^3.1.43"
[tool.poetry.group.dev.dependencies]
jupyter = "^1.0.0"
jupyter-book = "^1.0.0"
jupyterlab = "^4.2.1"
pytest = "^8.2.1"
+pytest-git = "^1.7.0"
[tool.poetry.scripts]
test-creation = "test_creation:cli_main"
diff --git a/src/test_creation/demo.ipynb b/src/test_creation/demo.ipynb
index fe97609..e70f971 100644
--- a/src/test_creation/demo.ipynb
+++ b/src/test_creation/demo.ipynb
@@ -33,7 +33,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:32<00:00, 10.71s/it]\n"
+ "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:33<00:00, 11.16s/it]\n"
]
}
],
@@ -55,16 +55,6 @@
"output_type": "stream",
"text": [
"Report:\n",
- " Requirement \\\n",
- "ID Title \n",
- "2.1 Ensure Data File Loads as Expected Ensure that data-loading functions correctly l... \n",
- "3.2 Data in the Expected Format Verify that the data to be ingested matches th... \n",
- "3.5 Check for Duplicate Records in Data Check for duplicate records in the dataset and... \n",
- "4.2 Verify Data Split Proportion Check that the data is split into training and... \n",
- "5.3 Ensure Model Output Shape Aligns with Expectation Ensure the shape of the model's output aligns ... \n",
- "6.1 Verify Evaluation Metrics Implementation Verify that the evaluation metrics are correct... \n",
- "6.2 Evaluate Model's Performance Against Thresholds Compute evaluation metrics for both the traini... \n",
- "\n",
" is_Satisfied \\\n",
"ID Title \n",
"2.1 Ensure Data File Loads as Expected 0.0 \n",
@@ -72,47 +62,27 @@
"3.5 Check for Duplicate Records in Data 0.0 \n",
"4.2 Verify Data Split Proportion 0.5 \n",
"5.3 Ensure Model Output Shape Aligns with Expectation 0.0 \n",
- "6.1 Verify Evaluation Metrics Implementation 1.0 \n",
+ "6.1 Verify Evaluation Metrics Implementation 0.5 \n",
"6.2 Evaluate Model's Performance Against Thresholds 0.5 \n",
"\n",
- " n_files_tested \\\n",
- "ID Title \n",
- "2.1 Ensure Data File Loads as Expected 3 \n",
- "3.2 Data in the Expected Format 3 \n",
- "3.5 Check for Duplicate Records in Data 3 \n",
- "4.2 Verify Data Split Proportion 3 \n",
- "5.3 Ensure Model Output Shape Aligns with Expectation 3 \n",
- "6.1 Verify Evaluation Metrics Implementation 3 \n",
- "6.2 Evaluate Model's Performance Against Thresholds 3 \n",
- "\n",
- " Observations \\\n",
- "ID Title \n",
- "2.1 Ensure Data File Loads as Expected [(test_cross_validation.py) The code does not ... \n",
- "3.2 Data in the Expected Format [(test_cross_validation.py) The code does not ... \n",
- "3.5 Check for Duplicate Records in Data [(test_cross_validation.py) The code does not ... \n",
- "4.2 Verify Data Split Proportion [(test_cross_validation.py) The code tests the... \n",
- "5.3 Ensure Model Output Shape Aligns with Expectation [(test_cross_validation.py) The code does not ... \n",
- "6.1 Verify Evaluation Metrics Implementation [(test_cross_validation.py) The code does not ... \n",
- "6.2 Evaluate Model's Performance Against Thresholds [(test_cross_validation.py) The code does not ... \n",
+ " n_files_tested \n",
+ "ID Title \n",
+ "2.1 Ensure Data File Loads as Expected 3 \n",
+ "3.2 Data in the Expected Format 3 \n",
+ "3.5 Check for Duplicate Records in Data 3 \n",
+ "4.2 Verify Data Split Proportion 3 \n",
+ "5.3 Ensure Model Output Shape Aligns with Expectation 3 \n",
+ "6.1 Verify Evaluation Metrics Implementation 3 \n",
+ "6.2 Evaluate Model's Performance Against Thresholds 3 \n",
"\n",
- " Function References \n",
- "ID Title \n",
- "2.1 Ensure Data File Loads as Expected [{'File Path': '../../data/raw/openja/lightfm_... \n",
- "3.2 Data in the Expected Format [{'File Path': '../../data/raw/openja/lightfm_... \n",
- "3.5 Check for Duplicate Records in Data [{'File Path': '../../data/raw/openja/lightfm_... \n",
- "4.2 Verify Data Split Proportion [{'File Path': '../../data/raw/openja/lightfm_... \n",
- "5.3 Ensure Model Output Shape Aligns with Expectation [{'File Path': '../../data/raw/openja/lightfm_... \n",
- "6.1 Verify Evaluation Metrics Implementation [{'File Path': '../../data/raw/openja/lightfm_... \n",
- "6.2 Evaluate Model's Performance Against Thresholds [{'File Path': '../../data/raw/openja/lightfm_... \n",
- "\n",
- "Score: 2.0/7\n",
+ "Score: 1.5/7\n",
"\n"
]
},
{
"data": {
"text/plain": [
- "'2.0/7'"
+ "'1.5/7'"
]
},
"execution_count": 4,
@@ -121,7 +91,7 @@
}
],
"source": [
- "parser = ResponseParser(response)\n",
+ "parser = ResponseParser(response, repo)\n",
"parser.get_completeness_score(verbose=True)"
]
},
@@ -199,7 +169,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:35<00:00, 11.94s/it]\n"
+ "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:31<00:00, 10.40s/it]\n"
]
},
{
@@ -213,7 +183,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:32<00:00, 10.70s/it]\n"
+ "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:31<00:00, 10.36s/it]\n"
]
},
{
@@ -228,7 +198,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:38<00:00, 12.83s/it]\n"
+ "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:39<00:00, 13.16s/it]\n"
]
},
{
@@ -242,7 +212,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:40<00:00, 13.34s/it]\n"
+ "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:42<00:00, 14.24s/it]\n"
]
}
],
@@ -306,14 +276,14 @@
" \n",
"
\n",
" 1 | \n",
- " 0.142857 | \n",
+ " 0.214286 | \n",
" ID ... | \n",
" gpt-3.5-turbo | \n",
" 2 | \n",
"
\n",
" \n",
" 2 | \n",
- " 0.714286 | \n",
+ " 0.642857 | \n",
" ID ... | \n",
" gpt-4o | \n",
" 1 | \n",
@@ -332,8 +302,8 @@
"text/plain": [
" score report model_name \\\n",
"0 0.214286 ID ... gpt-3.5-turbo \n",
- "1 0.142857 ID ... gpt-3.5-turbo \n",
- "2 0.714286 ID ... gpt-4o \n",
+ "1 0.214286 ID ... gpt-3.5-turbo \n",
+ "2 0.642857 ID ... gpt-4o \n",
"3 0.714286 ID ... gpt-4o \n",
"\n",
" test_no \n",
@@ -428,7 +398,7 @@
" Check that the data is split into training and... | \n",
" 0.5 | \n",
" 3 | \n",
- " [(test_cross_validation.py) The code does spli... | \n",
+ " [(test_cross_validation.py) The code includes ... | \n",
" [{'File Path': '../../data/raw/openja/lightfm_... | \n",
"
\n",
" \n",
@@ -488,7 +458,7 @@
"0 3 [(test_cross_validation.py) The code does not ... \n",
"1 3 [(test_cross_validation.py) The code does not ... \n",
"2 3 [(test_cross_validation.py) The code does not ... \n",
- "3 3 [(test_cross_validation.py) The code does spli... \n",
+ "3 3 [(test_cross_validation.py) The code includes ... \n",
"4 3 [(test_cross_validation.py) The code does not ... \n",
"5 3 [(test_cross_validation.py) The code does not ... \n",
"6 3 [(test_cross_validation.py) The code does not ... \n",
@@ -559,12 +529,12 @@
"
\n",
" \n",
" gpt-3.5-turbo | \n",
- " 0.002551 | \n",
+ " 0.000000 | \n",
" 2 | \n",
"
\n",
" \n",
" gpt-4o | \n",
- " 0.000000 | \n",
+ " 0.002551 | \n",
" 2 | \n",
"
\n",
" \n",
@@ -575,8 +545,8 @@
" score \n",
" var count\n",
"model_name \n",
- "gpt-3.5-turbo 0.002551 2\n",
- "gpt-4o 0.000000 2"
+ "gpt-3.5-turbo 0.000000 2\n",
+ "gpt-4o 0.002551 2"
]
},
"execution_count": 11,
@@ -597,19 +567,11 @@
"id": "5b1f94c8-1883-4435-84c7-b0687a6e6387",
"metadata": {},
"outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/var/folders/vd/r3dvzdx10pxf47gvdqf81r9h0000gn/T/ipykernel_42405/1426530661.py:5: RuntimeWarning: divide by zero encountered in scalar divide\n",
- " f_score = score_var[('score', 'var')]['gpt-3.5-turbo'] / score_var[('score', 'var')]['gpt-4o'] # var(prev) / var(curr)\n"
- ]
- },
{
"name": "stdout",
"output_type": "stream",
"text": [
- "p-value: 0.0\n",
+ "p-value: 1.0\n",
"\n",
"2-tail test:\n",
" Successfully reject the null hypothesis: Var(Completeness_Score(Current Version)) == Var(Completeness_Score(Last Week Version))\n"
diff --git a/src/test_creation/modules/code_analyzer/analyzers/__init__.py b/src/test_creation/modules/code_analyzer/analyzers/__init__.py
index 7f54637..e69de29 100644
--- a/src/test_creation/modules/code_analyzer/analyzers/__init__.py
+++ b/src/test_creation/modules/code_analyzer/analyzers/__init__.py
@@ -1,7 +0,0 @@
-from abc import ABC
-
-
-class CodeAnalyzer(ABC):
- def __init__(self):
- pass
-
diff --git a/src/test_creation/modules/code_analyzer/analyzers/python.py b/src/test_creation/modules/code_analyzer/analyzers/python.py
index 43f2385..9db3c5d 100644
--- a/src/test_creation/modules/code_analyzer/analyzers/python.py
+++ b/src/test_creation/modules/code_analyzer/analyzers/python.py
@@ -1,18 +1,40 @@
+from abc import ABC, abstractmethod
import ast
+from typing import Union
+from pathlib import Path
from functools import wraps
-
-from . import CodeAnalyzer
+from collections import defaultdict
def assert_have_read_content(f):
@wraps(f)
- def decorator(*args, **kwargs):
- if args[0].content is None:
+ def decorator(self, *args, **kwargs):
+ if self.content is None:
raise RuntimeError("No content has been read yet.")
- return f(*args, **kwargs)
+ return f(self, *args, **kwargs)
+
return decorator
+class CodeAnalyzer(ABC):
+
+ @abstractmethod
+ def read(self, file_path: Union[str, Path]) -> None:
+ pass
+
+ @abstractmethod
+ def list_imported_packages(self):
+ pass
+
+ @abstractmethod
+ def list_all_functions(self):
+ pass
+
+ @abstractmethod
+ def contains_test(self):
+ pass
+
+
class PythonASTCodeAnalyzer(CodeAnalyzer):
def __init__(self):
super().__init__()
@@ -24,6 +46,14 @@ def read(self, file_path: str):
self.content = f.read()
self._tree = ast.parse(self.content)
+ @assert_have_read_content
+ def _get_function_lineno_map(self):
+ function_lineno_map = defaultdict(int)
+ for node in ast.walk(self._tree):
+ if isinstance(node, ast.FunctionDef):
+ function_lineno_map[node.name] = node.lineno
+ return function_lineno_map
+
@assert_have_read_content
def list_imported_packages(self):
packages = set()
@@ -36,7 +66,7 @@ def list_imported_packages(self):
@assert_have_read_content
def list_all_functions(self):
- raise NotImplementedError()
+ return self._get_function_lineno_map().keys()
@assert_have_read_content
def contains_test(self):
@@ -60,6 +90,15 @@ def read(self, file_path: str):
with open(file_path, 'r') as f:
self.content = f.readlines()
+ @assert_have_read_content
+ def _get_function_lineno_map(self):
+ function_lineno_map = defaultdict(int)
+ for line_num, line in enumerate(self.content):
+ if line.lstrip().startswith('def '):
+ func_name = line.lstrip().split('(')[0].split(' ')[1]
+ function_lineno_map[func_name] = line_num + 1 # line starts with 1
+ return function_lineno_map
+
@assert_have_read_content
def list_imported_packages(self):
packages = set()
diff --git a/src/test_creation/modules/code_analyzer/git.py b/src/test_creation/modules/code_analyzer/git.py
new file mode 100644
index 0000000..8a43010
--- /dev/null
+++ b/src/test_creation/modules/code_analyzer/git.py
@@ -0,0 +1,76 @@
+import re
+from typing import Union, Optional
+from pathlib import Path
+from copy import copy
+
+from git import Repo
+
+
+class GitContext:
+ def __init__(self, git_dir: Union[str, Path]):
+ self.git_dir = Path(git_dir)
+ self.git_repo = Repo(self.git_dir)
+
+ self.branch = self._get_current_branch()
+ self.host, self.org, self.repo_name = self._get_remote_info()
+
+ self.remote_link_format_map = {
+ "github": "{host}/{org}/{repo}/blob/{branch}/{path}#L{line_num}",
+ "gitlab": "{host}/{org}/{repo}/blob/{branch}/{path}#L{line_num}",
+ "bitbucket": "{host}/{org}/{repo}/src/{branch}/{path}#lines-{"
+ "line_num}",
+            "gitee": "{host}/{org}/{repo}/blob/{branch}/{path}#L{line_num}"
+ }
+ self.remote_protocol = "https"
+ self.remote_service_family = self.__get_remote_service_family()
+
+ def __get_remote_service_family(self):
+ result = None
+ if self.host:
+ hits = [key for key in self.remote_link_format_map.keys() if
+ key in self.host]
+ if hits:
+ result = hits[0]
+ return result
+
+ def _get_current_branch(self):
+ if self.git_repo.head.is_detached:
+ return self.git_repo.head.commit.hexsha
+ else:
+ return self.git_repo.active_branch.name
+
+ def _get_remote_info(self) -> tuple[Optional[str], Optional[str], str]:
+ if self.git_repo.remotes:
+ if 'origin' in [r.name for r in self.git_repo.remotes]:
+ remote = self.git_repo.remote()
+ else:
+ remote = self.git_repo.remotes[0]
+ remote_url = remote.url
+ # git urls:
+ # https://git-scm.com/docs/git-clone#URLS
+ pattern = r"(?:\w+:\/\/)?(?:\w+@)?(.+)[\/:](.+)\/([^\.]+)(?:\.git)?"
+ host, org, repo_name = re.search(pattern, remote_url).groups()
+ return host, org, repo_name
+ else:
+ print("This git repository has no remote")
+ return None, None, "."
+
+ def construct_remote_link_to_file(self, file_path: Union[str, Path],
+ line_num: Optional[int] = None) -> str:
+ path = Path(file_path)
+ if path.is_absolute():
+ rel_path = path.relative_to(self.git_dir)
+ else:
+ rel_path = path
+ if self.remote_service_family:
+ f_str = copy(self.remote_link_format_map[self.remote_service_family])
+ if line_num is None:
+ f_str = f_str.split("#")[0]
+ injected_str = f"{self.remote_protocol}://" + \
+ f_str.format(host=self.host, org=self.org, repo=self.repo_name,
+ branch=self.branch, path=rel_path,
+ line_num=line_num)
+ return injected_str
+ else:
+ print("No matching service. Using local link instead...")
+ return f"file://{str(self.git_dir)}/{rel_path}"
diff --git a/src/test_creation/modules/code_analyzer/repo.py b/src/test_creation/modules/code_analyzer/repo.py
index c060f25..bf9c50e 100644
--- a/src/test_creation/modules/code_analyzer/repo.py
+++ b/src/test_creation/modules/code_analyzer/repo.py
@@ -1,16 +1,30 @@
import os
import logging
+from functools import wraps
from pathlib import Path
from collections import defaultdict
-from typing import Dict, List
+from typing import Optional
from .analyzers.python import PythonNaiveCodeAnalyzer, PythonASTCodeAnalyzer
+from .git import GitContext
logger = logging.getLogger("test-creation.repo")
+def requires_git_context(func):
+ @wraps(func)
+ def wrapper(self, *args, **kwargs):
+ """wrapper function to check if we have git context."""
+ if self.git_context is None:
+ raise RuntimeError("This repository has no git context.")
+ return func(self, *args, **kwargs)
+
+ return wrapper
+
+
class Repository:
def __init__(self, path: str):
+
if not os.path.exists(path):
raise FileNotFoundError(f"Repository {path} does not exist.")
elif os.path.isfile(path):
@@ -19,6 +33,10 @@ def __init__(self, path: str):
if not os.path.exists(self.path / ".git"):
# TODO: to be converted to use logger
print("Warning: The repository is not a git repository.")
+ self.git_context = None
+ else:
+ self.git_context = GitContext(self.path)
+
self.files = []
self.fileext_language_map = {
'.js': 'JavaScript',
@@ -33,6 +51,13 @@ def __init__(self, path: str):
'.c': 'C'
}
self.lf_map = self._get_language_file_map()
+ self.ffl_map = self._get_file_function_lineno_map()
+
+ @requires_git_context
+ def get_git_direct_link(self, file: str,
+ lineno: Optional[int] = None) -> str:
+ return self.git_context.construct_remote_link_to_file(file,
+ line_num=lineno)
def _get_all_files(self, include_git_dir: bool = False):
file_paths = []
@@ -46,14 +71,31 @@ def _get_all_files(self, include_git_dir: bool = False):
file_paths.append(f'{root}/{file}')
return file_paths
- def _get_language_file_map(self):
- file_language_map = defaultdict(list)
+ def _get_language_file_map(self) -> dict[str, list[str]]:
+ language_file_map = defaultdict(list)
files = self._get_all_files()
for file in files:
for k, v in self.fileext_language_map.items():
if file.endswith(k):
- file_language_map[v].append(file)
- return file_language_map
+ language_file_map[v].append(file)
+ return language_file_map
+
+    def _get_file_function_lineno_map(self) -> dict[str, dict[str, dict[str, int]]]:
+ file_function_lineno_map = defaultdict(lambda: defaultdict(int))
+ for lang, files in self.lf_map.items():
+ # TODO: only Python is supported now
+ if lang == "Python":
+ ast = PythonASTCodeAnalyzer()
+ naive = PythonNaiveCodeAnalyzer()
+ for file in files:
+ try:
+ ast.read(file)
+ file_function_lineno_map[lang][file] = ast._get_function_lineno_map()
+ except Exception as e:
+ logger.info("Exception occurred when parsing using ast (Python 2 code?) Using naive parser...")
+ naive.read(file)
+ file_function_lineno_map[lang][file] = naive._get_function_lineno_map()
+ return file_function_lineno_map
def list_languages(self):
return list(self.lf_map.keys())
@@ -70,7 +112,7 @@ def list_packages(self):
packages = list(set(packages))
return packages
- def list_test_files(self) -> Dict[str, List[str]]:
+ def list_test_files(self) -> dict[str, list[str]]:
testfiles = defaultdict(list)
# for now only Python is supported
files = self.lf_map.get("Python", [])
diff --git a/src/test_creation/modules/mixins.py b/src/test_creation/modules/mixins.py
index 792cfbe..bd51df3 100644
--- a/src/test_creation/modules/mixins.py
+++ b/src/test_creation/modules/mixins.py
@@ -21,10 +21,16 @@ def _filedump_check(self, output_path: str, exist_ok: bool, expects_directory_if
"provided a flag/argument for "
"file overwriting?)")
elif os.path.exists(normalized_path):
- if expects_directory_if_exists and not os.path.isdir(normalized_path):
- raise NotADirectoryError("An non-directory already exists in the path but the write operation is expecting to overwrite a directory.")
- elif not expects_directory_if_exists and not os.path.isfile(normalized_path):
- raise IsADirectoryError("An non-file object already exists in the path but the write operation is expecting to overwrite a file.")
+ if expects_directory_if_exists and not os.path.isdir(
+ normalized_path):
+                raise NotADirectoryError("A non-directory already exists in "
+ "the path but the write operation is"
+ " expecting to overwrite a directory.")
+ elif not expects_directory_if_exists and not os.path.isfile(
+ normalized_path):
+                raise IsADirectoryError("A non-file object already exists in "
+ "the path but the write operation is "
+ "expecting to overwrite a file.")
if not os.access(normalized_path, os.W_OK):
raise PermissionError(f"Write permission is not granted for the output path: {normalized_path}")
diff --git a/src/test_creation/modules/workflow/evaluator.py b/src/test_creation/modules/workflow/evaluator.py
index 3f4088d..88920d4 100644
--- a/src/test_creation/modules/workflow/evaluator.py
+++ b/src/test_creation/modules/workflow/evaluator.py
@@ -8,7 +8,6 @@
from langchain_text_splitters import RecursiveCharacterTextSplitter, Language
from langchain_core.language_models import LanguageModelLike
from langchain_core.documents import Document
-from pydantic import ValidationError
from ..checklist.checklist import Checklist
from ..code_analyzer.repo import Repository
@@ -30,22 +29,24 @@ class TestEvaluator(PipelineRunner, ABC):
"""Abstract base class for test evaluators
i.e. class object to run evaluation of test files from a given repository.
"""
- def __init__(self, llm: LanguageModelLike, prompt_format: PromptFormat, repository: Repository,
- checklist: Checklist):
+
+ def __init__(self, llm: LanguageModelLike, prompt_format: PromptFormat,
+ repository: Repository, checklist: Checklist):
self.llm = llm
self.checklist = checklist
self.repository = repository
self.prompt_format = prompt_format
- self.test_items = None
+ self._test_items = None
self.chain = self.prompt_format.prompt | self.llm | self.prompt_format.parser
class PerFileTestEvaluator(TestEvaluator):
"""Concrete test evaluator that performs per-file evaluation."""
- def __init__(self, llm: LanguageModelLike, prompt_format: PromptFormat, repository: Repository,
- checklist: Checklist, retries: int = 3):
+
+ def __init__(self, llm: LanguageModelLike, prompt_format: PromptFormat,
+ repository: Repository, checklist: Checklist, retries: int = 3):
super().__init__(llm, prompt_format, repository, checklist)
self.retries = retries
@@ -53,7 +54,8 @@ def __init__(self, llm: LanguageModelLike, prompt_format: PromptFormat, reposito
if not self._files:
print("File loader returned no files!")
- self._test_items = self.checklist.get_all_tests(['ID', 'Title', 'Requirement'])
+ self._test_items = self.checklist.get_all_tests(['ID', 'Title',
+ 'Requirement'])
if not self._test_items:
print("Loaded checklist successfully, but it contains no test items!")
@@ -67,15 +69,18 @@ def _load_test_file_into_splits(file_path: str) -> List[Document]:
def _validate_response(self, raw_response: dict) -> None:
"""Validation logics that are not covered by pydantic or langchain."""
- # ensures the number of items in the response is the same as provided checklists
- if len(raw_response['results']) != len(self.test_items):
- raise ValidationError("Number of items returned from LLM does not match that in checklist.")
+ # ensures the number of items in the response is the same as provided
+ # checklists
+ if len(raw_response['results']) != len(self._test_items):
+ raise AssertionError("Number of items returned from LLM does not match that in checklist.")
+ if not all(['Functions' in item for item in raw_response['results']]):
+ raise AssertionError("Not all items returned contain the attribute `Functions`.")
def run(self, verbose: bool = False) -> EvaluationResponse:
eval_response = EvaluationResponse(
model={'name': self.llm.model_name, 'temperature': self.llm.temperature},
- repository_path=self.repository.path,
- checklist_path=self.checklist.path
+ repository={'path': self.repository.path, 'object': self.repository},
+ checklist={'path': self.checklist.path, 'object': self.checklist}
)
for fp in tqdm(self._files):
@@ -84,7 +89,6 @@ def run(self, verbose: bool = False) -> EvaluationResponse:
splits = self._load_test_file_into_splits(fp)
if verbose:
print(f"# splits: {len(self._files)}")
- # FIXME: it sometimes tests only part of the checklist items
response = None
retry_count = 0
@@ -107,8 +111,11 @@ def run(self, verbose: bool = False) -> EvaluationResponse:
self._validate_response(response)
except Exception as e:
+ if verbose:
+ print(f"error occurred: {e.__class__.__name__} - {str(e)}")
errors.append({'name': e.__class__.__name__, 'description': str(e)})
retry_count += 1
+ response = None
continue
if not response:
diff --git a/src/test_creation/modules/workflow/parse.py b/src/test_creation/modules/workflow/parse.py
index 6272d65..e8e9b1e 100644
--- a/src/test_creation/modules/workflow/parse.py
+++ b/src/test_creation/modules/workflow/parse.py
@@ -10,11 +10,36 @@
class ResponseParser(ExportableMixin):
def __init__(self, response: EvaluationResponse):
+ # FIXME: repository is required to extract the line numbers for functions
+ # I added an optional argument "repository" here, can't think of any better way to handle it yet
super().__init__()
self.response = response
self.evaluation_report = None
+ self.repository = self.response.repository.object
+ self.git_context = self.repository.git_context
+ self.items = []
+
+ def _parse_items(self):
+ items = []
+ for result in self.response.call_results:
+ response = result.parsed_response['results']
+ for item in response:
+ fp = result.files_evaluated[0]
+ item['File Path'] = fp
+ if self.repository:
+ item['lineno'] = [self.repository.ffl_map['Python'][fp][func] for func in item['Functions']]
+ else:
+ item['lineno'] = []
+ item['Referenced Functions'] = [
+ f"[{func}]({self.repository.get_git_direct_link(fp, lineno)})"
+ for func, lineno in zip(item['Functions'], item['lineno'])
+ ]
+ items.append(item)
+ self.items = items
+ return items
def get_completeness_score(self, score_format: str = 'fraction', verbose: bool = False) -> Optional[float]:
+ """Compute Evaluation Report and Completeness Score."""
# TODO: change this after putting the logic to load data from JSON file
# instead of from a Python object.
@@ -30,16 +55,10 @@ def get_completeness_score(self, score_format: str = 'fraction', verbose: bool =
print("failed to obtain valid response, cannot calculate completeness score")
return None
- report = []
- for result in self.response.call_results:
- response = result.parsed_response['results']
- for item in response:
- item['file'] = result.files_evaluated[0] # FIXME: it might fail if the evaluation is on multiple files
- report.append(item)
+ items = self._parse_items()
- report_df = pd.DataFrame(report)
- report_df = report_df.rename(columns={"file": "File Path"})
- report_df['Function References'] = report_df[['File Path', 'Functions']].to_dict(orient='records')
+ report_df = pd.DataFrame(items)
+ report_df['Function References'] = report_df[['File Path', 'Referenced Functions']].to_dict(orient='records')
report_df['Observation'] = '(' + report_df['File Path'].apply(lambda x: os.path.split(x)[-1]) + ') ' + \
report_df['Observation']
report_df = report_df.groupby(['ID', 'Title']).agg({
@@ -58,7 +77,7 @@ def get_completeness_score(self, score_format: str = 'fraction', verbose: bool =
if verbose:
print("Report:")
- print(report_df)
+ print(report_df[['is_Satisfied', 'n_files_tested']])
print()
print(f'Score: {score}')
print()
diff --git a/src/test_creation/modules/workflow/prompt_format.py b/src/test_creation/modules/workflow/prompt_format.py
index 57ea354..205c161 100644
--- a/src/test_creation/modules/workflow/prompt_format.py
+++ b/src/test_creation/modules/workflow/prompt_format.py
@@ -25,7 +25,7 @@ class TestItemEvaluation(BaseModel):
Title: str = Field(description="The corresponding `Title` of the checklist item provided")
Requirement: str = Field(description="The corresponding `Requirement` of the checklist item provided")
Observation: str = Field(description="Your detailed observation of the code in accordance to the given checklist item")
- Functions: List[str] = Field(description="Test functions that satisfy the given requirement (if any)")
+ Functions: List[str] = Field(description="Test functions that satisfy the given requirement. If no function satisfies, an empty list i.e. [] should be returned.")
Evaluation: str = Field(description="The summarized evaluation. Must be one of Satisfied/Partially Satisfied/Not Satisfied.")
Score: float = Field(description="The score obtained from the given evaluation (1 for Satisfied / 0.5 for Partially Satisfied / 0 for Not Satisfied)")
diff --git a/src/test_creation/modules/workflow/response.py b/src/test_creation/modules/workflow/response.py
index 386bd95..07643bf 100644
--- a/src/test_creation/modules/workflow/response.py
+++ b/src/test_creation/modules/workflow/response.py
@@ -2,14 +2,31 @@
from pathlib import Path
from typing import List, Dict, Any, Optional, Union
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, ConfigDict
+from ..code_analyzer.repo import Repository
+from ..checklist.checklist import Checklist
-class LLM(BaseModel):
+
+class LLMInfo(BaseModel):
name: str = Field(description="Name of the LLM used")
temperature: float = Field(description="Temperature of the LLM")
+class RepositoryInfo(BaseModel):
+ path: Union[str, Path] = Field(description="Path of the repository")
+ object: Repository = Field(description="Repository object")
+
+ model_config = ConfigDict(arbitrary_types_allowed=True)
+
+
+class ChecklistInfo(BaseModel):
+ path: Union[str, Path] = Field(description="Path of the checklist")
+ object: Checklist = Field(description="Checklist object")
+
+ model_config = ConfigDict(arbitrary_types_allowed=True)
+
+
class Error(BaseModel):
name: str = Field(description="Class name of the error")
description: str = Field(description="Description of the error")
@@ -35,8 +52,14 @@ class EvaluationResponse(BaseModel):
name
temperature
}
- repository_path
- checklist_path
+ repository {
+ path
+ object
+ }
+ checklist {
+ path
+ object
+ }
call_results [{
start_time
end_time
@@ -52,7 +75,7 @@ class EvaluationResponse(BaseModel):
}]
}
"""
- model: LLM = Field(description="LLM-related information")
- repository_path: Union[str, Path] = Field(description="Repository path")
- checklist_path: Union[str, Path] = Field(description="Checklist path")
+ model: LLMInfo = Field(description="LLM-related information")
+ repository: RepositoryInfo = Field(description="Repository-related information")
+ checklist: ChecklistInfo = Field(description="Checklist-related information")
call_results: List[CallResult] = Field(description="List of call results", default=[])
diff --git a/tests/test_repo.py b/tests/test_repo.py
new file mode 100644
index 0000000..fb29a73
--- /dev/null
+++ b/tests/test_repo.py
@@ -0,0 +1,91 @@
+from pathlib import Path
+
+import pytest
+from test_creation.modules.code_analyzer import repo as r
+from test_creation.modules.code_analyzer.git import GitContext
+
+
+@pytest.fixture()
+def test_git_repo(git_repo):
+ # The fixture derives from `workspace` in `pytest-shutil`, so they contain
+ # a handle to the path `path` object (see https://path.readthedocs.io/)
+ path = git_repo.workspace
+ txt_file = path / 'hello.txt'
+ txt_file.write_text('hello world!')
+
+ py_file = Path(path / 'src/python/main.py')
+ py_file.parent.mkdir(parents=True, exist_ok=True)
+ py_file.write_text('print("hello world!")')
+
+ # We can run commands relative to the working directory
+ git_repo.run('git add .')
+
+ # It's better to use the GitPython api directly - the 'api' attribute is
+ # a handle to the repository object.
+ git_repo.api.index.commit("Initial commit")
+
+ # The fixture has a URI property you can use in downstream systems
+ assert git_repo.uri.startswith('file://')
+
+ return git_repo
+
+
+################################################################################
+# Repository #
+################################################################################
+def test_repository_should_be_able_to_read_git_repo(test_git_repo):
+ path = test_git_repo.workspace
+ repo = r.Repository(path)
+ assert any(['src/python/main.py' in file for file in repo._get_all_files()])
+
+
+################################################################################
+# GitContext #
+################################################################################
+@pytest.mark.parametrize(
+ "fixture_name, remote_name, remote_url, expected",
+ [
+ (
+ "test_git_repo",
+ "origin",
+ "git@github.internal.com:UBC-MDS/testing-repo_1234.git",
+ ("github.internal.com", "UBC-MDS", "testing-repo_1234")
+ ),
+ (
+ "test_git_repo",
+ "export",
+ "ssh://git@github.internal.com:UBC-MDS/testing-repo_1234.git",
+ ("github.internal.com", "UBC-MDS", "testing-repo_1234")
+ ),
+ (
+ "test_git_repo",
+ "internal",
+ "https://github.com:8080/UBC-MDS/test-creation.git",
+ ("github.com:8080", "UBC-MDS", "test-creation")
+ ),
+ (
+ "test_git_repo",
+ "origin",
+ "http://gitlab.example.com:8080/UBC-MDS/test-creation.git",
+ ("gitlab.example.com:8080", "UBC-MDS", "test-creation")
+ ),
+ (
+ "test_git_repo",
+ "a",
+ "ftp://github.com/SoloSynth1/Solosynth1",
+ ("github.com", "SoloSynth1", "Solosynth1")
+ ),
+ ]
+)
+def test_git_context_can_extract_remote_git_urls(fixture_name, remote_name,
+ remote_url, expected, request):
+ repo = request.getfixturevalue(fixture_name)
+ repo.api.create_remote(remote_name, remote_url)
+ gc = GitContext(repo.workspace)
+ assert (gc.host, gc.org, gc.repo_name) == expected
+
+
+def test_git_context_gives_out_local_link_when_no_remote(test_git_repo):
+ context = GitContext(test_git_repo.workspace)
+ link = context.construct_remote_link_to_file("src/python/main.py")
+ assert link == f"file://{test_git_repo.workspace}/src/python/main.py"
\ No newline at end of file
diff --git a/tests/test_test_creation.py b/tests/test_test_creation.py
deleted file mode 100644
index 8b50deb..0000000
--- a/tests/test_test_creation.py
+++ /dev/null
@@ -1 +0,0 @@
-#from test_creation import test_creation
diff --git a/tests/test_utils.py b/tests/test_utils.py
index a978032..cdda766 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,31 +1,21 @@
from pathlib import Path
+import pytest
from test_creation.modules.utils import get_extension
-def test_extension_from_string_can_be_extracted_correctly():
- correct_path = "checklist/checklist.csv"
- assert get_extension(correct_path) == "csv"
-
-
-def test_extension_from_path_can_be_extracted_correctly():
- correct_path = Path("checklist/checklist.csv")
- assert get_extension(correct_path) == "csv"
-
-
-def test_extension_extracted_is_all_lower_cased():
- path = "ALL/CAPITAL/PATH/TEST.ZIP"
- assert get_extension(path) == "zip"
-
-
-def test_only_last_extension_will_be_extracted():
- path = "test/multi_ext.tar.gz"
- assert get_extension(path) == "gz"
-
-
-def test_file_with_no_extension_will_produce_empty_string():
- path = "test/README"
- assert get_extension(path) == ""
+@pytest.mark.parametrize(
+ "path, expected",
+ [
+ ("checklist/checklist.csv", "csv"),
+ (Path("checklist/checklist.csv"), "csv"),
+ ("ALL/CAPITAL/PATH/TEST.ZIP", "zip"),
+ ("test/multi_ext.tar.gz", "gz"),
+ (Path("test/README"), "")
+ ]
+)
+def test_extension_from_string_can_be_extracted_correctly(path, expected):
+ assert get_extension(path) == expected
def test_extracted_extension_does_not_start_with_dot():