From 143519abaee85008e7518f7c48fcea6f49f7eadc Mon Sep 17 00:00:00 2001 From: Kostiantyn Sharovarskyi Date: Tue, 14 Jan 2025 14:12:10 +0200 Subject: [PATCH] feat: Add csharp support to repo_map (#1073) --- Build.ps1 | 2 + Cargo.lock | 11 ++ crates/avante-repo-map/Cargo.toml | 1 + .../queries/tree-sitter-c-sharp-defs.scm | 25 +++ crates/avante-repo-map/src/lib.rs | 150 ++++++++++++++++++ lua/avante/repo_map.lua | 1 + 6 files changed, 190 insertions(+) create mode 100644 crates/avante-repo-map/queries/tree-sitter-c-sharp-defs.scm diff --git a/Build.ps1 b/Build.ps1 index f7f7a2758..5f8710c03 100644 --- a/Build.ps1 +++ b/Build.ps1 @@ -20,8 +20,10 @@ function Build-FromSource($feature) { $SCRIPT_DIR = $PSScriptRoot $targetTokenizerFile = "avante_tokenizers.dll" $targetTemplatesFile = "avante_templates.dll" + $targetRepoMapFile = "avante_repo_map.dll" Copy-Item (Join-Path $SCRIPT_DIR "target\release\avante_tokenizers.dll") (Join-Path $BuildDir $targetTokenizerFile) Copy-Item (Join-Path $SCRIPT_DIR "target\release\avante_templates.dll") (Join-Path $BuildDir $targetTemplatesFile) + Copy-Item (Join-Path $SCRIPT_DIR "target\release\avante_repo_map.dll") (Join-Path $BuildDir $targetRepoMapFile) Remove-Item -Recurse -Force "target" } diff --git a/Cargo.lock b/Cargo.lock index 4e9ab8ce1..262b14673 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -39,6 +39,7 @@ dependencies = [ "serde", "tree-sitter", "tree-sitter-c", + "tree-sitter-c-sharp", "tree-sitter-cpp", "tree-sitter-elixir", "tree-sitter-go", @@ -1331,6 +1332,16 @@ dependencies = [ "tree-sitter-language", ] +[[package]] +name = "tree-sitter-c-sharp" +version = "0.23.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67f06accca7b45351758663b8215089e643d53bd9a660ce0349314263737fcb0" +dependencies = [ + "cc", + "tree-sitter-language", +] + [[package]] name = "tree-sitter-cpp" version = "0.23.1" diff --git a/crates/avante-repo-map/Cargo.toml b/crates/avante-repo-map/Cargo.toml index 001c0733e..2494275ef 100644 --- a/crates/avante-repo-map/Cargo.toml +++ b/crates/avante-repo-map/Cargo.toml @@ -29,6 +29,7 @@ tree-sitter-ruby = "0.23" tree-sitter-zig = "1.0.2" tree-sitter-scala = "0.23" tree-sitter-elixir = "0.3.1" +tree-sitter-c-sharp = "0.23" [lints] workspace = true diff --git a/crates/avante-repo-map/queries/tree-sitter-c-sharp-defs.scm b/crates/avante-repo-map/queries/tree-sitter-c-sharp-defs.scm new file mode 100644 index 000000000..84afcb89a --- /dev/null +++ b/crates/avante-repo-map/queries/tree-sitter-c-sharp-defs.scm @@ -0,0 +1,25 @@ +(class_declaration + name: (identifier) @class + (parameter_list)? @method) ;; Primary constructor + +(record_declaration + name: (identifier) @class + (parameter_list)? @method) ;; Primary constructor + +(interface_declaration + name: (identifier) @class) + +(method_declaration) @method + +(constructor_declaration) @method + +(property_declaration) @class_variable + +(field_declaration + (variable_declaration + (variable_declarator))) @class_variable + +(enum_declaration + body: (enum_member_declaration_list + (enum_member_declaration) @enum_item)) + diff --git a/crates/avante-repo-map/src/lib.rs b/crates/avante-repo-map/src/lib.rs index 1f30b260e..fd0faeb08 100644 --- a/crates/avante-repo-map/src/lib.rs +++ b/crates/avante-repo-map/src/lib.rs @@ -63,6 +63,7 @@ fn get_ts_language(language: &str) -> Option { "zig" => Some(tree_sitter_zig::LANGUAGE), "scala" => Some(tree_sitter_scala::LANGUAGE), "elixir" => Some(tree_sitter_elixir::LANGUAGE), + "csharp" => Some(tree_sitter_c_sharp::LANGUAGE), _ => None, } } @@ -79,6 +80,7 @@ const TYPESCRIPT_QUERY: &str = include_str!("../queries/tree-sitter-typescript-d const RUBY_QUERY: &str = include_str!("../queries/tree-sitter-ruby-defs.scm"); const SCALA_QUERY: &str = include_str!("../queries/tree-sitter-scala-defs.scm"); const ELIXIR_QUERY: &str = include_str!("../queries/tree-sitter-elixir-defs.scm"); +const CSHARP_QUERY: &str = include_str!("../queries/tree-sitter-c-sharp-defs.scm"); fn get_definitions_query(language: &str) -> Result { let ts_language = get_ts_language(language); @@ -99,6 +101,7 @@ fn get_definitions_query(language: &str) -> Result { "ruby" => RUBY_QUERY, "scala" => SCALA_QUERY, "elixir" => ELIXIR_QUERY, + "csharp" => CSHARP_QUERY, _ => return Err(format!("Unsupported language: {language}")), }; let query = Query::new(&ts_language.into(), contents) @@ -129,6 +132,20 @@ fn find_ancestor_by_type<'a>(node: &'a Node, parent_type: &str) -> Option( + node: &'a Node, + possible_parent_types: &[&str], +) -> Option> { + let mut parent = node.parent(); + while let Some(parent_node) = parent { + if possible_parent_types.contains(&parent_node.kind()) { + return Some(parent_node); + } + parent = parent_node.parent(); + } + None +} + fn find_descendant_by_type<'a>(node: &'a Node, child_type: &str) -> Option> { let mut cursor = node.walk(); for i in 0..node.descendant_count() { @@ -189,6 +206,17 @@ fn zig_find_type_in_parent<'a>(node: &'a Node, source: &'a [u8]) -> Option bool { + node.kind() == "parameter_list" + && node.parent().map_or(false, |n| { + n.kind() == "class_declaration" || n.kind() == "record_declaration" + }) +} + +fn csharp_find_parent_type_node<'a>(node: &'a Node) -> Option> { + find_first_ancestor_by_types(node, &["class_declaration", "record_declaration"]) +} + fn ex_find_parent_module_declaration_name<'a>(node: &'a Node, source: &'a [u8]) -> Option { let mut parent = node.parent(); while let Some(parent_node) = parent { @@ -351,6 +379,22 @@ fn extract_definitions(language: &str, source: &str) -> Result, .map(|n| n.utf8_text(source.as_bytes()).unwrap()) .unwrap_or(node_text) .to_string(), + "csharp" => { + let mut identifier = node; + // Handle primary constructors (they are direct children of *_declaration) + if *capture_name == "method" && csharp_is_primary_constructor(&node) { + identifier = node.parent().unwrap_or(node); + } else if *capture_name == "class_variable" { + identifier = + find_descendant_by_type(&node, "variable_declarator").unwrap_or(node); + } + + identifier + .child_by_field_name("name") + .map(|n| n.utf8_text(source.as_bytes()).unwrap()) + .unwrap_or(node_text) + .to_string() + } _ => node .child_by_field_name("name") .map(|n| n.utf8_text(source.as_bytes()).unwrap()) @@ -472,6 +516,23 @@ fn extract_definitions(language: &str, source: &str) -> Result, if !name.is_empty() && language == "go" && !is_first_letter_uppercase(&name) { continue; } + + if language == "csharp" { + let csharp_visibility = find_descendant_by_type(&node, "modifier"); + if csharp_visibility.is_none() && !csharp_is_primary_constructor(&node) { + continue; + } + if csharp_visibility.is_some() { + let csharp_visibility_text = csharp_visibility + .unwrap() + .utf8_text(source.as_bytes()) + .unwrap(); + if csharp_visibility_text == "private" { + continue; + } + } + } + let mut params_node = node .child_by_field_name("parameters") .or_else(|| find_descendant_by_type(&node, "parameter_list")); @@ -494,6 +555,7 @@ fn extract_definitions(language: &str, source: &str) -> Result, .unwrap_or("()"); let mut return_type_node = match language { "cpp" => node.child_by_field_name("type"), + "csharp" => node.child_by_field_name("returns"), _ => node.child_by_field_name("return_type"), }; if language == "cpp" { @@ -509,6 +571,19 @@ fn extract_definitions(language: &str, source: &str) -> Result, } } } + if language == "csharp" { + let type_specifier_node = csharp_find_parent_type_node(&node); + let type_identifier_node = + type_specifier_node.and_then(|n| n.child_by_field_name("name")); + + if let Some(type_identifier_node) = type_identifier_node { + let type_identifier_text = + type_identifier_node.utf8_text(source.as_bytes()).unwrap(); + if name == type_identifier_text { + return_type_node = Some(type_identifier_node); + } + } + } if return_type_node.is_none() { return_type_node = node.child_by_field_name("result"); } @@ -542,6 +617,12 @@ fn extract_definitions(language: &str, source: &str) -> Result, .and_then(|n| n.utf8_text(source.as_bytes()).ok()) .unwrap_or("") .to_string() + } else if language == "csharp" { + csharp_find_parent_type_node(&node) + .and_then(|n| n.child_by_field_name("name")) + .and_then(|n| n.utf8_text(source.as_bytes()).ok()) + .unwrap_or("") + .to_string() } else if let Some(impl_item) = impl_item_node { let impl_type_node = impl_item.child_by_field_name("type"); impl_type_node @@ -645,6 +726,21 @@ fn extract_definitions(language: &str, source: &str) -> Result, .unwrap_or("") .to_string(); } + + if language == "csharp" { + let csharp_visibility = find_descendant_by_type(&node, "modifier"); + if csharp_visibility.is_none() { + continue; + } + let csharp_visibility_text = csharp_visibility + .unwrap() + .utf8_text(source.as_bytes()) + .unwrap(); + if csharp_visibility_text == "private" { + continue; + } + } + if language == "zig" { class_name = zig_find_parent_variable_declaration_name(&node, source.as_bytes()) @@ -1507,6 +1603,60 @@ mod tests { assert_eq!(stringified, expected); } + #[test] + fn test_csharp() { + let source = r#" + using System; + + namespace TestNamespace; + + public class TestClass(TestDependency m) + { + + private int PrivateTestProperty { get; set; } + + private int _privateTestField; + + public int TestProperty { get; set; } + + public string TestField; + + public TestClass() + { + TestProperty = 0; + } + + + public void TestMethod(int a, int b) + { + var innerVarInMethod = 1; + return a + b; + } + + public int TestMethod(int a, int b, int c) => a + b + c; + + private void PrivateMethod() + { + return; + } + + public class MyInnerClass(InnerClassDependency m) {} + + public record MyInnerRecord(int a); + } + + public record TestRecord(int a, int b); + + public enum TestEnum { Value1, Value2 } + "#; + + let definitions = extract_definitions("csharp", source).unwrap(); + let stringified = stringify_definitions(&definitions); + println!("{stringified}"); + let expected = "class MyInnerClass{func MyInnerClass(InnerClassDependency m) -> MyInnerClass;};class MyInnerRecord{func MyInnerRecord(int a) -> MyInnerRecord;};class TestClass{func TestClass(TestDependency m) -> TestClass;func TestClass() -> TestClass;func TestMethod(int a, int b) -> void;func TestMethod(int a, int b, int c) -> int;var TestProperty:int;var TestField:string;};class TestRecord{func TestRecord(int a, int b) -> TestRecord;};enum TestEnum{Value1;Value2;};"; + assert_eq!(stringified, expected); + } + #[test] fn test_unsupported_language() { let source = "print('Hello, world!')"; diff --git a/lua/avante/repo_map.lua b/lua/avante/repo_map.lua index 9a1e6b55c..ad403c55c 100644 --- a/lua/avante/repo_map.lua +++ b/lua/avante/repo_map.lua @@ -6,6 +6,7 @@ local Config = require("avante.config") local filetype_map = { ["javascriptreact"] = "javascript", ["typescriptreact"] = "typescript", + ["cs"] = "csharp", } ---@class AvanteRepoMap