diff --git a/src/repo_gpt/file_handler/abstract_handler.py b/src/repo_gpt/file_handler/abstract_handler.py index a008ae2..5e84b38 100644 --- a/src/repo_gpt/file_handler/abstract_handler.py +++ b/src/repo_gpt/file_handler/abstract_handler.py @@ -10,11 +10,12 @@ class CodeType(Enum): FUNCTION = "function" CLASS = "class" METHOD = "method" + GLOBAL = "global" @dataclass class ParsedCode: - name: str + name: Union[str, None] code_type: CodeType code: str summary: Union[str, None] @@ -24,7 +25,7 @@ class ParsedCode: file_checksum: str = None def __lt__(self, other: "ParsedCode"): - return self.name < other.name + return self.code < other.code class AbstractHandler(ABC): diff --git a/src/repo_gpt/file_handler/generic_code_file_handler.py b/src/repo_gpt/file_handler/generic_code_file_handler.py index 104b0dd..029ffc4 100644 --- a/src/repo_gpt/file_handler/generic_code_file_handler.py +++ b/src/repo_gpt/file_handler/generic_code_file_handler.py @@ -57,21 +57,37 @@ def extract_code(self, filepath: Path) -> List[ParsedCode]: try: code = source_code.read() tree = self.parser.parse(bytes(code, "utf8")) - return self.parse_tree(tree, filepath) + return self.parse_tree(tree) except Exception as e: print(f"Failed to parse file {filepath}: {e}") raise - def parse_tree(self, tree, filepath: str) -> List[ParsedCode]: + def parse_tree(self, tree) -> List[ParsedCode]: parsed_nodes = [] root_node = tree.root_node + global_nodes = [] for node in root_node.children: if node.type == self.function_node_type: parsed_nodes.append(self.get_function_parsed_code(node)) elif node.type == self.class_node_type: parsed_nodes.extend(self.get_class_and_method_parsed_code(node)) + else: + global_nodes.append(node) + if len(global_nodes) > 0: + parsed_nodes.append(self.get_global_code(global_nodes)) return parsed_nodes + def get_global_code(self, global_nodes: []) -> ParsedCode: + code = "\n".join([node.text.decode("utf8") for node in global_nodes]) + return ParsedCode( + name=None, + 
code_type=CodeType.GLOBAL, + code=code, + summary=None, + inputs=None, + outputs=None, + ) + def get_function_parsed_code(self, function_node, is_method=False) -> ParsedCode: name = self.get_function_name(function_node) input_params, output_params = self.get_function_parameters(function_node) diff --git a/test/file_handler/test_php_extract_code.py b/test/file_handler/test_php_extract_code.py index fee31db..eb70051 100644 --- a/test/file_handler/test_php_extract_code.py +++ b/test/file_handler/test_php_extract_code.py @@ -36,6 +36,14 @@ class TestClass extends BaseClass { summary=None, outputs=("string",), ), + ParsedCode( + name=None, + code_type=CodeType.GLOBAL, + code="", + inputs=None, + summary=None, + outputs=None, + ), ] EXPECTED_PHP_CLASS_PARSED_CODE = [ @@ -55,6 +63,14 @@ class TestClass extends BaseClass { summary=None, outputs=None, ), + ParsedCode( + name=None, + code_type=CodeType.GLOBAL, + code="", + inputs=None, + summary=None, + outputs=None, + ), ] @@ -73,8 +89,8 @@ def test_php_normal_operation(tmp_path, input_text, expected_output): assert isinstance(parsed_code, list) assert all(isinstance(code, ParsedCode) for code in parsed_code) - parsed_code.sort(key=lambda x: x.name) - expected_output.sort(key=lambda x: x.name) + parsed_code.sort() + expected_output.sort() assert len(parsed_code) == len(expected_output) assert parsed_code == expected_output @@ -82,16 +98,15 @@ def test_php_normal_operation(tmp_path, input_text, expected_output): def test_no_function_in_file(tmp_path): # Test PHP file with no functions or classes p = tmp_path / "no_function_class_php_file.php" - p.write_text( - """ - $x = 10; - $y = 20; - $z = $x + $y; - """ - ) + code = """$x = 10; +$y = 20; +$z = $x + $y;""" + p.write_text(code) parsed_code = handler.extract_code(p) assert isinstance(parsed_code, list) - assert len(parsed_code) == 0 + assert len(parsed_code) == 1 + assert parsed_code[0].code_type == CodeType.GLOBAL + assert parsed_code[0].code == code.strip() def 
test_edge_cases(tmp_path): @@ -103,10 +118,15 @@ def test_edge_cases(tmp_path): assert len(parsed_code) == 0 # Test non-PHP file - p = tmp_path / "non_php_file.txt" - p.write_text("This is a text file, not a PHP file.") + p = ( + tmp_path / "non_php_file.txt" + ) # the function doesn't check file types just parses text code + text = "This is a text file, not a PHP file." + p.write_text(text) parsed_code = handler.extract_code(p) - assert len(parsed_code) == 0 + assert len(parsed_code) == 1 + assert parsed_code[0].code_type == CodeType.GLOBAL + assert parsed_code[0].code == text # Test non-existent file p = tmp_path / "non_existent_file.php" diff --git a/test/file_handler/test_python_extract_code.py b/test/file_handler/test_python_extract_code.py index 56eca07..155a842 100644 --- a/test/file_handler/test_python_extract_code.py +++ b/test/file_handler/test_python_extract_code.py @@ -7,6 +7,8 @@ # Define input text SAMPLE_FUNCTION_INPUT_TEXT = """ +foo = "bar" + def hello_world() -> str: return "Hello, world!" 
""" @@ -27,6 +29,14 @@ def test_method(self): summary=None, outputs=("str",), ), + ParsedCode( + name=None, + code_type=CodeType.GLOBAL, + code='foo = "bar"', + inputs=None, + summary=None, + outputs=None, + ), ] EXPECTED_CLASS_PARSED_CODE = [ @@ -73,16 +83,16 @@ def test_normal_operation(tmp_path, input_text, expected_output): def test_no_function_in_file(tmp_path): # Test Python file with no functions or classes p = tmp_path / "no_function_class_python_file.py" - p.write_text( - """ - x = 10 - y = 20 - z = x + y + code = """x = 10 +y = 20 +z = x + y """ - ) + p.write_text(code) parsed_code = handler.extract_code(p) assert isinstance(parsed_code, list) - assert len(parsed_code) == 0 + assert len(parsed_code) == 1 + assert parsed_code[0].code_type == CodeType.GLOBAL + assert parsed_code[0].code == code.strip() def test_edge_cases(tmp_path): @@ -94,10 +104,15 @@ def test_edge_cases(tmp_path): assert len(parsed_code) == 0 # Test non-Python file - p = tmp_path / "non_python_file.txt" - p.write_text("This is a text file, not a Python file.") + p = ( + tmp_path / "non_python_file.txt" + ) # This function doesn't check if the file or function is valid Python + text = "This is a text file, not a Python file." 
+ p.write_text(text) parsed_code = handler.extract_code(p) - assert len(parsed_code) == 0 + assert len(parsed_code) == 1 + assert parsed_code[0].code_type == CodeType.GLOBAL + assert parsed_code[0].code == text # Test non-existent file p = tmp_path / "non_existent_file.py" diff --git a/test/file_handler/test_typescript_extract_code.py b/test/file_handler/test_typescript_extract_code.py index c6a32e6..2fc405e 100644 --- a/test/file_handler/test_typescript_extract_code.py +++ b/test/file_handler/test_typescript_extract_code.py @@ -70,24 +70,22 @@ def test_ts_normal_operation(tmp_path, input_text, expected_output): assert isinstance(parsed_code, list) assert all(isinstance(code, ParsedCode) for code in parsed_code) assert len(parsed_code) == len(expected_output) - parsed_code.sort(key=lambda x: x.name) - expected_output.sort(key=lambda x: x.name) + parsed_code.sort() + expected_output.sort() assert parsed_code == expected_output def test_no_function_in_file(tmp_path): # Test TypeScript file with no functions or classes p = tmp_path / "no_function_class_ts_file.ts" - p.write_text( - """ - let x = 10; - let y = 20; - let z = x + y; - """ - ) + code = """let x = 10; +let y = 20; +let z = x + y;""" + p.write_text(code) parsed_code = handler.extract_code(p) assert isinstance(parsed_code, list) - assert len(parsed_code) == 0 + assert len(parsed_code) == 1 + assert parsed_code[0].code == code.strip() def test_edge_cases(tmp_path): @@ -99,10 +97,15 @@ def test_edge_cases(tmp_path): assert len(parsed_code) == 0 # Test non-TypeScript file - p = tmp_path / "non_ts_file.txt" - p.write_text("This is a text file, not a TypeScript file.") + p = ( + tmp_path / "non_ts_file.txt" + ) # This function doesn't check file types only parses the text / code + text = "This is a text file, not a TypeScript file." 
+ p.write_text(text) parsed_code = handler.extract_code(p) - assert len(parsed_code) == 0 + assert len(parsed_code) == 1 + assert parsed_code[0].code == text + assert parsed_code[0].code_type == CodeType.GLOBAL # Test non-existent file p = tmp_path / "non_existent_file.ts"