Skip to content

Commit

Permalink
feat: Add import path and tree refactor (#48)
Browse files Browse the repository at this point in the history
Rename some confusing tree functions and add a tree method to get all
import paths of a file.
  • Loading branch information
coder3101 authored Jan 12, 2025
1 parent 7709f3c commit f5cda97
Show file tree
Hide file tree
Showing 18 changed files with 123 additions and 86 deletions.
20 changes: 0 additions & 20 deletions src/lsp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,25 +51,6 @@ impl LanguageServer for ProtoLanguageServer {
},
}];

let worktoken = params.work_done_progress_params.work_done_token;
let (tx, rx) = mpsc::channel();
let mut socket = self.client.clone();

thread::spawn(move || {
let Some(token) = worktoken else {
return;
};

while let Ok(value) = rx.recv() {
if let Err(e) = socket.progress(ProgressParams {
token: token.clone(),
value,
}) {
error!(error=%e, "failed to report parse progress");
}
}
});

let file_registration_option = FileOperationRegistrationOptions {
filters: file_operation_filers.clone(),
};
Expand All @@ -80,7 +61,6 @@ impl LanguageServer for ProtoLanguageServer {
for workspace in folders {
info!("Workspace folder: {workspace:?}");
self.configs.add_workspace(&workspace);
self.state.add_workspace_folder_async(workspace, tx.clone());
}
workspace_capabilities = Some(WorkspaceServerCapabilities {
workspace_folders: Some(WorkspaceFoldersServerCapabilities {
Expand Down
6 changes: 6 additions & 0 deletions src/nodekind.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ pub enum NodeKind {
ServiceName,
RpcName,
PackageName,
PackageImport,
}

#[allow(unused)]
Expand All @@ -26,6 +27,7 @@ impl NodeKind {
NodeKind::ServiceName => "service_name",
NodeKind::RpcName => "rpc_name",
NodeKind::PackageName => "full_ident",
NodeKind::PackageImport => "import",
}
}

Expand All @@ -37,6 +39,10 @@ impl NodeKind {
n.kind() == Self::Error.as_str()
}

pub fn is_import_path(n: &Node) -> bool {
n.kind() == Self::PackageImport.as_str()
}

pub fn is_package_name(n: &Node) -> bool {
n.kind() == Self::PackageName.as_str()
}
Expand Down
4 changes: 2 additions & 2 deletions src/parser/definition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ impl ParsedTree {
match identifier.split_once('.') {
Some((parent_identifier, remaining)) => {
let child_node = self
.filter_nodes_from(n, NodeKind::is_userdefined)
.find_all_nodes_from(n, NodeKind::is_userdefined)
.into_iter()
.find(|n| {
n.utf8_text(content.as_ref()).expect("utf8-parse error")
Expand All @@ -39,7 +39,7 @@ impl ParsedTree {
}
None => {
let locations: Vec<Location> = self
.filter_nodes_from(n, NodeKind::is_userdefined)
.find_all_nodes_from(n, NodeKind::is_userdefined)
.into_iter()
.filter(|n| {
n.utf8_text(content.as_ref()).expect("utf-8 parse error") == identifier
Expand Down
2 changes: 1 addition & 1 deletion src/parser/diagnostics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use super::ParsedTree;
impl ParsedTree {
pub fn collect_parse_errors(&self) -> PublishDiagnosticsParams {
let diagnostics = self
.filter_nodes(NodeKind::is_error)
.find_all_nodes(NodeKind::is_error)
.into_iter()
.map(|n| Diagnostic {
range: Range {
Expand Down
4 changes: 2 additions & 2 deletions src/parser/hover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ impl ParsedTree {
match identifier.split_once('.') {
Some((parent, child)) => {
let child_node = self
.filter_nodes_from(n, NodeKind::is_userdefined)
.find_all_nodes_from(n, NodeKind::is_userdefined)
.into_iter()
.find(|n| n.utf8_text(content.as_ref()).expect("utf8-parse error") == parent)
.and_then(|n| n.parent());
Expand All @@ -77,7 +77,7 @@ impl ParsedTree {
}
None => {
let comments: Vec<String> = self
.filter_nodes_from(n, NodeKind::is_userdefined)
.find_all_nodes_from(n, NodeKind::is_userdefined)
.into_iter()
.filter(|n| {
n.utf8_text(content.as_ref()).expect("utf-8 parse error") == identifier
Expand Down
3 changes: 3 additions & 0 deletions src/parser/input/test_filter.proto
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ syntax = "proto3";

package com.parser;

import "foo/bar.proto";
import "baz/bar.proto";

message Book {

message Author {
Expand Down
6 changes: 3 additions & 3 deletions src/parser/rename.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ impl ParsedTree {
content: impl AsRef<[u8]>,
) -> Option<Vec<Node<'a>>> {
n.parent().map(|p| {
self.filter_nodes_from(p, NodeKind::is_field_name)
self.find_all_nodes_from(p, NodeKind::is_field_name)
.into_iter()
.filter(|i| i.utf8_text(content.as_ref()).expect("utf-8 parse error") == identifier)
.collect()
Expand Down Expand Up @@ -114,7 +114,7 @@ impl ParsedTree {
new_identifier: &str,
content: impl AsRef<[u8]>,
) -> Vec<TextEdit> {
self.filter_nodes(NodeKind::is_field_name)
self.find_all_nodes(NodeKind::is_field_name)
.into_iter()
.filter(|n| {
let ntext = n.utf8_text(content.as_ref()).expect("utf-8 parse error");
Expand All @@ -135,7 +135,7 @@ impl ParsedTree {
}

pub fn reference_field(&self, id: &str, content: impl AsRef<[u8]>) -> Vec<Location> {
self.filter_nodes(NodeKind::is_field_name)
self.find_all_nodes(NodeKind::is_field_name)
.into_iter()
.filter(|n| n.utf8_text(content.as_ref()).expect("utf-8 parse error") == id)
.map(|n| Location {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
source: src/parser/tree.rs
expression: imports
---
- foo/bar.proto
- baz/bar.proto
34 changes: 24 additions & 10 deletions src/parser/tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::{nodekind::NodeKind, utils::lsp_to_ts_point};
use super::ParsedTree;

impl ParsedTree {
pub(super) fn walk_and_collect_filter<'a>(
pub(super) fn walk_and_filter<'a>(
cursor: &mut TreeCursor<'a>,
f: fn(&Node) -> bool,
early: bool,
Expand All @@ -24,7 +24,7 @@ impl ParsedTree {
}

if cursor.goto_first_child() {
v.extend(Self::walk_and_collect_filter(cursor, f, early));
v.extend(Self::walk_and_filter(cursor, f, early));
cursor.goto_parent();
}

Expand Down Expand Up @@ -110,29 +110,41 @@ impl ParsedTree {
self.tree.root_node().descendant_for_point_range(pos, pos)
}

pub fn filter_nodes(&self, f: fn(&Node) -> bool) -> Vec<Node> {
self.filter_nodes_from(self.tree.root_node(), f)
pub fn find_all_nodes(&self, f: fn(&Node) -> bool) -> Vec<Node> {
self.find_all_nodes_from(self.tree.root_node(), f)
}

pub fn filter_nodes_from<'a>(&self, n: Node<'a>, f: fn(&Node) -> bool) -> Vec<Node<'a>> {
pub fn find_all_nodes_from<'a>(&self, n: Node<'a>, f: fn(&Node) -> bool) -> Vec<Node<'a>> {
let mut cursor = n.walk();
Self::walk_and_collect_filter(&mut cursor, f, false)
Self::walk_and_filter(&mut cursor, f, false)
}

pub fn find_node(&self, f: fn(&Node) -> bool) -> Vec<Node> {
pub fn find_first_node(&self, f: fn(&Node) -> bool) -> Vec<Node> {
self.find_node_from(self.tree.root_node(), f)
}

pub fn find_node_from<'a>(&self, n: Node<'a>, f: fn(&Node) -> bool) -> Vec<Node<'a>> {
let mut cursor = n.walk();
Self::walk_and_collect_filter(&mut cursor, f, true)
Self::walk_and_filter(&mut cursor, f, true)
}

pub fn get_package_name<'a>(&self, content: &'a [u8]) -> Option<&'a str> {
self.find_node(NodeKind::is_package_name)
self.find_first_node(NodeKind::is_package_name)
.first()
.map(|n| n.utf8_text(content).expect("utf-8 parse error"))
}
pub fn get_import_path<'a>(&self, content: &'a [u8]) -> Vec<&'a str> {
self.find_all_nodes(NodeKind::is_import_path)
.into_iter()
.filter_map(|n| {
n.child_by_field_name("path").map(|c| {
c.utf8_text(content)
.expect("utf-8 parse error")
.trim_matches('"')
})
})
.collect()
}
}

#[cfg(test)]
Expand All @@ -150,7 +162,7 @@ mod test {

assert!(parsed.is_some());
let tree = parsed.unwrap();
let nodes = tree.filter_nodes(NodeKind::is_message_name);
let nodes = tree.find_all_nodes(NodeKind::is_message_name);

assert_eq!(nodes.len(), 2);

Expand All @@ -163,5 +175,7 @@ mod test {

let package_name = tree.get_package_name(contents.as_ref());
assert_yaml_snapshot!(package_name);
let imports = tree.get_import_path(contents.as_ref());
assert_yaml_snapshot!(imports);
}
}
30 changes: 28 additions & 2 deletions src/server.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
use async_lsp::{router::Router, ClientSocket};
use std::ops::ControlFlow;
use async_lsp::{
lsp_types::{NumberOrString, ProgressParams, ProgressParamsValue},
router::Router,
ClientSocket, LanguageClient,
};
use std::{
ops::ControlFlow,
sync::{mpsc, mpsc::Sender},
thread,
};

use crate::{config::workspace::WorkspaceProtoConfigs, state::ProtoLanguageState};

Expand Down Expand Up @@ -27,4 +35,22 @@ impl ProtoLanguageServer {
self.counter += 1;
ControlFlow::Continue(())
}

fn with_report_progress(&self, token: NumberOrString) -> Sender<ProgressParamsValue> {
let (tx, rx) = mpsc::channel();
let mut socket = self.client.clone();

thread::spawn(move || {
while let Ok(value) = rx.recv() {
if let Err(e) = socket.progress(ProgressParams {
token: token.clone(),
value,
}) {
tracing::error!(error=%e, "failed to report parse progress");
}
}
});

tx
}
}
2 changes: 1 addition & 1 deletion src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ impl ProtoLanguageState {
.into_iter()
.fold(vec![], |mut v, tree| {
let content = self.get_content(&tree.uri);
let t = tree.filter_nodes(f).into_iter().map(|n| CompletionItem {
let t = tree.find_all_nodes(f).into_iter().map(|n| CompletionItem {
label: n.utf8_text(content.as_bytes()).unwrap().to_string(),
kind: Some(k),
..Default::default()
Expand Down
Loading

0 comments on commit f5cda97

Please sign in to comment.