Skip to content

Commit

Permalink
feat: Add csharp support to repo_map (#1073)
Browse files Browse the repository at this point in the history
  • Loading branch information
kostya9 authored Jan 14, 2025
1 parent 2d5306b commit 143519a
Show file tree
Hide file tree
Showing 6 changed files with 190 additions and 0 deletions.
2 changes: 2 additions & 0 deletions Build.ps1
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@ function Build-FromSource($feature) {
$SCRIPT_DIR = $PSScriptRoot
$targetTokenizerFile = "avante_tokenizers.dll"
$targetTemplatesFile = "avante_templates.dll"
$targetRepoMapFile = "avante_repo_map.dll"
Copy-Item (Join-Path $SCRIPT_DIR "target\release\avante_tokenizers.dll") (Join-Path $BuildDir $targetTokenizerFile)
Copy-Item (Join-Path $SCRIPT_DIR "target\release\avante_templates.dll") (Join-Path $BuildDir $targetTemplatesFile)
Copy-Item (Join-Path $SCRIPT_DIR "target\release\avante_repo_map.dll") (Join-Path $BuildDir $targetRepoMapFile)

Remove-Item -Recurse -Force "target"
}
Expand Down
11 changes: 11 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/avante-repo-map/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ tree-sitter-ruby = "0.23"
tree-sitter-zig = "1.0.2"
tree-sitter-scala = "0.23"
tree-sitter-elixir = "0.3.1"
tree-sitter-c-sharp = "0.23"

[lints]
workspace = true
Expand Down
25 changes: 25 additions & 0 deletions crates/avante-repo-map/queries/tree-sitter-c-sharp-defs.scm
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
;; Tree-sitter definition queries for C#.
;; Capture names used below: @class, @method, @class_variable, @enum_item.

;; Class declarations: the name identifier is captured as @class. When the
;; declaration has a parameter_list child, that list is captured as @method.
(class_declaration
name: (identifier) @class
(parameter_list)? @method) ;; Primary constructor

;; Record declarations: same shape as classes, including the optional
;; parameter_list captured as @method.
(record_declaration
name: (identifier) @class
(parameter_list)? @method) ;; Primary constructor

;; Interface declarations: only the name is captured, as @class.
(interface_declaration
name: (identifier) @class)

;; Every method declaration node is captured whole as @method.
(method_declaration) @method

;; Every constructor declaration node is captured whole as @method.
(constructor_declaration) @method

;; Every property declaration node is captured whole as @class_variable.
(property_declaration) @class_variable

;; Fields: the capture is on the outer field_declaration node; the nested
;; variable_declaration / variable_declarator only constrain the match.
(field_declaration
(variable_declaration
(variable_declarator))) @class_variable

;; Enums: each member inside the enum body is captured as @enum_item.
;; The enum's own name is not captured by this pattern.
(enum_declaration
body: (enum_member_declaration_list
(enum_member_declaration) @enum_item))

150 changes: 150 additions & 0 deletions crates/avante-repo-map/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ fn get_ts_language(language: &str) -> Option<LanguageFn> {
"zig" => Some(tree_sitter_zig::LANGUAGE),
"scala" => Some(tree_sitter_scala::LANGUAGE),
"elixir" => Some(tree_sitter_elixir::LANGUAGE),
"csharp" => Some(tree_sitter_c_sharp::LANGUAGE),
_ => None,
}
}
Expand All @@ -79,6 +80,7 @@ const TYPESCRIPT_QUERY: &str = include_str!("../queries/tree-sitter-typescript-d
const RUBY_QUERY: &str = include_str!("../queries/tree-sitter-ruby-defs.scm");
const SCALA_QUERY: &str = include_str!("../queries/tree-sitter-scala-defs.scm");
const ELIXIR_QUERY: &str = include_str!("../queries/tree-sitter-elixir-defs.scm");
const CSHARP_QUERY: &str = include_str!("../queries/tree-sitter-c-sharp-defs.scm");

fn get_definitions_query(language: &str) -> Result<Query, String> {
let ts_language = get_ts_language(language);
Expand All @@ -99,6 +101,7 @@ fn get_definitions_query(language: &str) -> Result<Query, String> {
"ruby" => RUBY_QUERY,
"scala" => SCALA_QUERY,
"elixir" => ELIXIR_QUERY,
"csharp" => CSHARP_QUERY,
_ => return Err(format!("Unsupported language: {language}")),
};
let query = Query::new(&ts_language.into(), contents)
Expand Down Expand Up @@ -129,6 +132,20 @@ fn find_ancestor_by_type<'a>(node: &'a Node, parent_type: &str) -> Option<Node<'
None
}

fn find_first_ancestor_by_types<'a>(
node: &'a Node,
possible_parent_types: &[&str],
) -> Option<Node<'a>> {
let mut parent = node.parent();
while let Some(parent_node) = parent {
if possible_parent_types.contains(&parent_node.kind()) {
return Some(parent_node);
}
parent = parent_node.parent();
}
None
}

fn find_descendant_by_type<'a>(node: &'a Node, child_type: &str) -> Option<Node<'a>> {
let mut cursor = node.walk();
for i in 0..node.descendant_count() {
Expand Down Expand Up @@ -189,6 +206,17 @@ fn zig_find_type_in_parent<'a>(node: &'a Node, source: &'a [u8]) -> Option<Strin
None
}

/// Reports whether `node` is the parameter list of a C# primary constructor.
///
/// This is the case when `node` is a `parameter_list` whose direct parent is
/// a `class_declaration` or a `record_declaration`.
fn csharp_is_primary_constructor(node: &Node) -> bool {
    // Anything other than a parameter list can be rejected without
    // looking at the parent.
    if node.kind() != "parameter_list" {
        return false;
    }

    match node.parent() {
        Some(owner) => matches!(owner.kind(), "class_declaration" | "record_declaration"),
        None => false,
    }
}

/// Returns the nearest enclosing C# class or record declaration of `node`,
/// or `None` when `node` is not nested inside one.
fn csharp_find_parent_type_node<'a>(node: &'a Node) -> Option<Node<'a>> {
    // Node kinds that count as an owning type for this lookup.
    const OWNING_TYPE_KINDS: [&str; 2] = ["class_declaration", "record_declaration"];

    find_first_ancestor_by_types(node, &OWNING_TYPE_KINDS)
}

fn ex_find_parent_module_declaration_name<'a>(node: &'a Node, source: &'a [u8]) -> Option<String> {
let mut parent = node.parent();
while let Some(parent_node) = parent {
Expand Down Expand Up @@ -351,6 +379,22 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
.unwrap_or(node_text)
.to_string(),
"csharp" => {
let mut identifier = node;
// Handle primary constructors (they are direct children of *_declaration)
if *capture_name == "method" && csharp_is_primary_constructor(&node) {
identifier = node.parent().unwrap_or(node);
} else if *capture_name == "class_variable" {
identifier =
find_descendant_by_type(&node, "variable_declarator").unwrap_or(node);
}

identifier
.child_by_field_name("name")
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
.unwrap_or(node_text)
.to_string()
}
_ => node
.child_by_field_name("name")
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
Expand Down Expand Up @@ -472,6 +516,23 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
if !name.is_empty() && language == "go" && !is_first_letter_uppercase(&name) {
continue;
}

if language == "csharp" {
let csharp_visibility = find_descendant_by_type(&node, "modifier");
if csharp_visibility.is_none() && !csharp_is_primary_constructor(&node) {
continue;
}
if csharp_visibility.is_some() {
let csharp_visibility_text = csharp_visibility
.unwrap()
.utf8_text(source.as_bytes())
.unwrap();
if csharp_visibility_text == "private" {
continue;
}
}
}

let mut params_node = node
.child_by_field_name("parameters")
.or_else(|| find_descendant_by_type(&node, "parameter_list"));
Expand All @@ -494,6 +555,7 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
.unwrap_or("()");
let mut return_type_node = match language {
"cpp" => node.child_by_field_name("type"),
"csharp" => node.child_by_field_name("returns"),
_ => node.child_by_field_name("return_type"),
};
if language == "cpp" {
Expand All @@ -509,6 +571,19 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
}
}
}
if language == "csharp" {
let type_specifier_node = csharp_find_parent_type_node(&node);
let type_identifier_node =
type_specifier_node.and_then(|n| n.child_by_field_name("name"));

if let Some(type_identifier_node) = type_identifier_node {
let type_identifier_text =
type_identifier_node.utf8_text(source.as_bytes()).unwrap();
if name == type_identifier_text {
return_type_node = Some(type_identifier_node);
}
}
}
if return_type_node.is_none() {
return_type_node = node.child_by_field_name("result");
}
Expand Down Expand Up @@ -542,6 +617,12 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
.and_then(|n| n.utf8_text(source.as_bytes()).ok())
.unwrap_or("")
.to_string()
} else if language == "csharp" {
csharp_find_parent_type_node(&node)
.and_then(|n| n.child_by_field_name("name"))
.and_then(|n| n.utf8_text(source.as_bytes()).ok())
.unwrap_or("")
.to_string()
} else if let Some(impl_item) = impl_item_node {
let impl_type_node = impl_item.child_by_field_name("type");
impl_type_node
Expand Down Expand Up @@ -645,6 +726,21 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
.unwrap_or("")
.to_string();
}

if language == "csharp" {
let csharp_visibility = find_descendant_by_type(&node, "modifier");
if csharp_visibility.is_none() {
continue;
}
let csharp_visibility_text = csharp_visibility
.unwrap()
.utf8_text(source.as_bytes())
.unwrap();
if csharp_visibility_text == "private" {
continue;
}
}

if language == "zig" {
class_name =
zig_find_parent_variable_declaration_name(&node, source.as_bytes())
Expand Down Expand Up @@ -1507,6 +1603,60 @@ mod tests {
assert_eq!(stringified, expected);
}

/// End-to-end check of C# definition extraction.
///
/// The fixture mixes public and private members, primary constructors on a
/// class and on records, nested types and an enum. The expected string lists
/// only the non-private members, reports each primary constructor as a
/// function returning its own type, and lists the enum members.
#[test]
fn test_csharp() {
    let expected = "class MyInnerClass{func MyInnerClass(InnerClassDependency m) -> MyInnerClass;};class MyInnerRecord{func MyInnerRecord(int a) -> MyInnerRecord;};class TestClass{func TestClass(TestDependency m) -> TestClass;func TestClass() -> TestClass;func TestMethod(int a, int b) -> void;func TestMethod(int a, int b, int c) -> int;var TestProperty:int;var TestField:string;};class TestRecord{func TestRecord(int a, int b) -> TestRecord;};enum TestEnum{Value1;Value2;};";

    let csharp_code = r#"
using System;
namespace TestNamespace;
public class TestClass(TestDependency m)
{
private int PrivateTestProperty { get; set; }
private int _privateTestField;
public int TestProperty { get; set; }
public string TestField;
public TestClass()
{
TestProperty = 0;
}
public void TestMethod(int a, int b)
{
var innerVarInMethod = 1;
return a + b;
}
public int TestMethod(int a, int b, int c) => a + b + c;
private void PrivateMethod()
{
return;
}
public class MyInnerClass(InnerClassDependency m) {}
public record MyInnerRecord(int a);
}
public record TestRecord(int a, int b);
public enum TestEnum { Value1, Value2 }
"#;

    let stringified = stringify_definitions(&extract_definitions("csharp", csharp_code).unwrap());
    // Printed so the actual output is visible when the assertion fails.
    println!("{stringified}");

    assert_eq!(stringified, expected);
}

#[test]
fn test_unsupported_language() {
let source = "print('Hello, world!')";
Expand Down
1 change: 1 addition & 0 deletions lua/avante/repo_map.lua
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ local Config = require("avante.config")
-- Maps a filetype name to the language name that should be used in its place.
-- NOTE(review): presumably the keys are Neovim filetypes and the values are the
-- language identifiers the repo-map backend accepts; confirm against callers.
local filetype_map = {
  javascriptreact = "javascript",
  typescriptreact = "typescript",
  cs = "csharp",
}

---@class AvanteRepoMap
Expand Down

0 comments on commit 143519a

Please sign in to comment.