Skip to content

Commit

Permalink
refactor: add schema column to the scripts table (#868)
Browse files Browse the repository at this point in the history
  • Loading branch information
ShenJunkun authored Feb 7, 2023
1 parent 5d62e19 commit afac885
Show file tree
Hide file tree
Showing 11 changed files with 115 additions and 39 deletions.
15 changes: 11 additions & 4 deletions src/datanode/src/instance/script.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,20 @@ use crate::metric;

#[async_trait]
impl ScriptHandler for Instance {
async fn insert_script(&self, name: &str, script: &str) -> servers::error::Result<()> {
async fn insert_script(
&self,
schema: &str,
name: &str,
script: &str,
) -> servers::error::Result<()> {
let _timer = timer!(metric::METRIC_HANDLE_SCRIPTS_ELAPSED);
self.script_executor.insert_script(name, script).await
self.script_executor
.insert_script(schema, name, script)
.await
}

async fn execute_script(&self, name: &str) -> servers::error::Result<Output> {
async fn execute_script(&self, schema: &str, name: &str) -> servers::error::Result<Output> {
let _timer = timer!(metric::METRIC_RUN_SCRIPT_ELAPSED);
self.script_executor.execute_script(name).await
self.script_executor.execute_script(schema, name).await
}
}
17 changes: 13 additions & 4 deletions src/datanode/src/script.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,15 @@ mod python {
})
}

pub async fn insert_script(&self, name: &str, script: &str) -> servers::error::Result<()> {
pub async fn insert_script(
&self,
schema: &str,
name: &str,
script: &str,
) -> servers::error::Result<()> {
let _s = self
.script_manager
.insert_and_compile(name, script)
.insert_and_compile(schema, name, script)
.await
.map_err(|e| {
error!(e; "Instance failed to insert script");
Expand All @@ -85,9 +90,13 @@ mod python {
Ok(())
}

pub async fn execute_script(&self, name: &str) -> servers::error::Result<Output> {
pub async fn execute_script(
&self,
schema: &str,
name: &str,
) -> servers::error::Result<Output> {
self.script_manager
.execute(name)
.execute(schema, name)
.await
.map_err(|e| {
error!(e; "Instance failed to execute script");
Expand Down
13 changes: 9 additions & 4 deletions src/frontend/src/instance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -501,9 +501,14 @@ impl SqlQueryHandler for Instance {

#[async_trait]
impl ScriptHandler for Instance {
async fn insert_script(&self, name: &str, script: &str) -> server_error::Result<()> {
async fn insert_script(
&self,
schema: &str,
name: &str,
script: &str,
) -> server_error::Result<()> {
if let Some(handler) = &self.script_handler {
handler.insert_script(name, script).await
handler.insert_script(schema, name, script).await
} else {
server_error::NotSupportedSnafu {
feat: "Script execution in Frontend",
Expand All @@ -512,9 +517,9 @@ impl ScriptHandler for Instance {
}
}

async fn execute_script(&self, script: &str) -> server_error::Result<Output> {
async fn execute_script(&self, schema: &str, script: &str) -> server_error::Result<Output> {
if let Some(handler) = &self.script_handler {
handler.execute_script(script).await
handler.execute_script(schema, script).await
} else {
server_error::NotSupportedSnafu {
feat: "Script execution in Frontend",
Expand Down
25 changes: 18 additions & 7 deletions src/script/src/manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,20 +65,25 @@ impl ScriptManager {
Ok(script)
}

pub async fn insert_and_compile(&self, name: &str, script: &str) -> Result<Arc<PyScript>> {
pub async fn insert_and_compile(
&self,
schema: &str,
name: &str,
script: &str,
) -> Result<Arc<PyScript>> {
let compiled_script = self.compile(name, script).await?;
self.table.insert(name, script).await?;
self.table.insert(schema, name, script).await?;
Ok(compiled_script)
}

pub async fn execute(&self, name: &str) -> Result<Output> {
pub async fn execute(&self, schema: &str, name: &str) -> Result<Output> {
let script = {
let s = self.compiled.read().unwrap().get(name).cloned();

if s.is_some() {
s
} else {
self.try_find_script_and_compile(name).await?
self.try_find_script_and_compile(schema, name).await?
}
};

Expand All @@ -90,8 +95,12 @@ impl ScriptManager {
.context(ExecutePythonSnafu { name })
}

async fn try_find_script_and_compile(&self, name: &str) -> Result<Option<Arc<PyScript>>> {
let script = self.table.find_script_by_name(name).await?;
async fn try_find_script_and_compile(
&self,
schema: &str,
name: &str,
) -> Result<Option<Arc<PyScript>>> {
let script = self.table.find_script_by_name(schema, name).await?;

Ok(Some(self.compile(name, &script).await?))
}
Expand Down Expand Up @@ -149,9 +158,11 @@ mod tests {
.unwrap();
catalog_manager.start().await.unwrap();

let schema = "schema";
let name = "test";
mgr.table
.insert(
schema,
name,
r#"
@copr(sql='select number from numbers limit 10', args=['number'], returns=['n'])
Expand All @@ -168,7 +179,7 @@ def test(n):
}

// try to find and compile
let script = mgr.try_find_script_and_compile(name).await.unwrap();
let script = mgr.try_find_script_and_compile(schema, name).await.unwrap();
assert!(script.is_some());

{
Expand Down
28 changes: 21 additions & 7 deletions src/script/src/table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ impl ScriptsTable {
desc: Some("Scripts table".to_string()),
schema,
region_numbers: vec![0],
// name as primary key
primary_key_indices: vec![0],
// schema and name together form the primary key
primary_key_indices: vec![0, 1],
create_if_not_exists: true,
table_options: HashMap::default(),
};
Expand All @@ -86,8 +86,12 @@ impl ScriptsTable {
})
}

pub async fn insert(&self, name: &str, script: &str) -> Result<()> {
let mut columns_values: HashMap<String, VectorRef> = HashMap::with_capacity(7);
pub async fn insert(&self, schema: &str, name: &str, script: &str) -> Result<()> {
let mut columns_values: HashMap<String, VectorRef> = HashMap::with_capacity(8);
columns_values.insert(
"schema".to_string(),
Arc::new(StringVector::from(vec![schema])) as _,
);
columns_values.insert(
"name".to_string(),
Arc::new(StringVector::from(vec![name])) as _,
Expand Down Expand Up @@ -115,7 +119,6 @@ impl ScriptsTable {
"gmt_modified".to_string(),
Arc::new(TimestampMillisecondVector::from_slice(&[now])) as _,
);

let table = self
.catalog_manager
.table(
Expand All @@ -142,12 +145,18 @@ impl ScriptsTable {
Ok(())
}

pub async fn find_script_by_name(&self, name: &str) -> Result<String> {
pub async fn find_script_by_name(&self, schema: &str, name: &str) -> Result<String> {
// FIXME(dennis): SQL injection
// TODO(dennis): we use sql to find the script, the better way is use a function
// such as `find_record_by_primary_key` in table_engine.
let sql = format!("select script from {} where name='{}'", self.name(), name);
let sql = format!(
"select script from {} where schema='{}' and name='{}'",
self.name(),
schema,
name
);
let stmt = QueryLanguageParser::parse_sql(&sql).unwrap();

let plan = self
.query_engine
.statement_to_plan(stmt, Arc::new(QueryContext::new()))
Expand Down Expand Up @@ -195,6 +204,11 @@ impl ScriptsTable {
/// Build scripts table
fn build_scripts_schema() -> Schema {
let cols = vec![
ColumnSchema::new(
"schema".to_string(),
ConcreteDataType::string_datatype(),
false,
),
ColumnSchema::new(
"name".to_string(),
ConcreteDataType::string_datatype(),
Expand Down
23 changes: 21 additions & 2 deletions src/servers/src/http/script.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,16 +51,26 @@ pub async fn scripts(
RawBody(body): RawBody,
) -> Json<JsonResponse> {
if let Some(script_handler) = &state.script_handler {
let schema = params.schema.as_ref();

if schema.is_none() || schema.unwrap().is_empty() {
json_err!("invalid schema")
}

let name = params.name.as_ref();

if name.is_none() || name.unwrap().is_empty() {
json_err!("invalid name");
}

let bytes = unwrap_or_json_err!(hyper::body::to_bytes(body).await);

let script = unwrap_or_json_err!(String::from_utf8(bytes.to_vec()));

let body = match script_handler.insert_script(name.unwrap(), &script).await {
let body = match script_handler
.insert_script(schema.unwrap(), name.unwrap(), &script)
.await
{
Ok(()) => JsonResponse::with_output(None),
Err(e) => json_err!(format!("Insert script error: {e}"), e.status_code()),
};
Expand All @@ -73,6 +83,7 @@ pub async fn scripts(

// Query-string parameters for the script HTTP endpoints, e.g.
// `?schema=<schema>&name=<name>`. Presumably this is the type behind the
// `params` value read by the `scripts` and `run_script` handlers — confirm
// against their signatures.
//
// Both fields deserialize as optional. The handlers reject a request whose
// `schema` is missing or empty with "invalid schema", and `scripts` rejects a
// missing or empty `name` with "invalid name".
//
// NOTE(review): plain `//` comments are used deliberately here; `///` doc
// comments may be picked up by the `JsonSchema` derive as schema descriptions —
// confirm before converting them.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct ScriptQuery {
// Schema the script is stored under and looked up in.
pub schema: Option<String>,
// Name of the script within that schema.
pub name: Option<String>,
}

Expand All @@ -84,6 +95,12 @@ pub async fn run_script(
) -> Json<JsonResponse> {
if let Some(script_handler) = &state.script_handler {
let start = Instant::now();
let schema = params.schema.as_ref();

if schema.is_none() || schema.unwrap().is_empty() {
json_err!("invalid schema")
}

let name = params.name.as_ref();

if name.is_none() || name.unwrap().is_empty() {
Expand All @@ -92,7 +109,9 @@ pub async fn run_script(

// TODO(sunng87): query_context and db name resolution

let output = script_handler.execute_script(name.unwrap()).await;
let output = script_handler
.execute_script(schema.unwrap(), name.unwrap())
.await;
let resp = JsonResponse::from_output(vec![output]).await;

Json(resp.with_execution_time(start.elapsed().as_millis()))
Expand Down
4 changes: 2 additions & 2 deletions src/servers/src/query_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ pub type ScriptHandlerRef = Arc<dyn ScriptHandler + Send + Sync>;

#[async_trait]
pub trait ScriptHandler {
async fn insert_script(&self, name: &str, script: &str) -> Result<()>;
async fn execute_script(&self, name: &str) -> Result<Output>;
async fn insert_script(&self, schema: &str, name: &str, script: &str) -> Result<()>;
async fn execute_script(&self, schema: &str, name: &str) -> Result<Output>;
}

#[async_trait]
Expand Down
8 changes: 6 additions & 2 deletions src/servers/tests/http/http_handler_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def test(n):
)
.await;
assert!(!json.success(), "{json:?}");
assert_eq!(json.error().unwrap(), "Invalid argument: invalid name");
assert_eq!(json.error().unwrap(), "Invalid argument: invalid schema");

let body = RawBody(Body::from(script));
let exec = create_script_query();
Expand All @@ -124,12 +124,16 @@ def test(n):

/// Builds a valid script query for the handler tests: both `schema` and `name`
/// are set to `"test"` and the result is wrapped in the `Query` extractor.
fn create_script_query() -> Query<script_handler::ScriptQuery> {
    let schema = Some(String::from("test"));
    let name = Some(String::from("test"));

    Query(script_handler::ScriptQuery { schema, name })
}

fn create_invalid_script_query() -> Query<script_handler::ScriptQuery> {
Query(script_handler::ScriptQuery { name: None })
Query(script_handler::ScriptQuery {
schema: None,
name: None,
})
}

fn create_query() -> Query<http_handler::SqlQuery> {
Expand Down
10 changes: 6 additions & 4 deletions src/servers/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ impl SqlQueryHandler for DummyInstance {

#[async_trait]
impl ScriptHandler for DummyInstance {
async fn insert_script(&self, name: &str, script: &str) -> Result<()> {
async fn insert_script(&self, schema: &str, name: &str, script: &str) -> Result<()> {
let script = self
.py_engine
.compile(script, CompileContext::default())
Expand All @@ -115,13 +115,15 @@ impl ScriptHandler for DummyInstance {
self.scripts
.write()
.unwrap()
.insert(name.to_string(), Arc::new(script));
.insert(format!("{schema}_{name}"), Arc::new(script));

Ok(())
}

async fn execute_script(&self, name: &str) -> Result<Output> {
let py_script = self.scripts.read().unwrap().get(name).unwrap().clone();
async fn execute_script(&self, schema: &str, name: &str) -> Result<Output> {
let key = format!("{schema}_{name}");

let py_script = self.scripts.read().unwrap().get(&key).unwrap().clone();

Ok(py_script.execute(EvalContext::default()).await.unwrap())
}
Expand Down
4 changes: 3 additions & 1 deletion src/servers/tests/py_script/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ async fn test_insert_py_udf_and_query() -> Result<()> {
def double_that(col)->vector[u32]:
return col*2
"#;
instance.insert_script("double_that", src).await?;
instance
.insert_script("schema_test", "double_that", src)
.await?;
let res = instance
.do_query("select double_that(uint32s) from numbers", query_ctx)
.await
Expand Down
7 changes: 5 additions & 2 deletions tests-integration/tests/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ pub async fn test_scripts_api(store_type: StorageType) {
let client = TestClient::new(app);

let res = client
.post("/v1/scripts?name=test")
.post("/v1/scripts?schema=schema_test&name=test")
.body(
r#"
@copr(sql='select number from numbers limit 10', args=['number'], returns=['n'])
Expand All @@ -334,7 +334,10 @@ def test(n):
assert!(body.output().is_none());

// call script
let res = client.post("/v1/run-script?name=test").send().await;
let res = client
.post("/v1/run-script?schema=schema_test&name=test")
.send()
.await;
assert_eq!(res.status(), StatusCode::OK);
let body = serde_json::from_str::<JsonResponse>(&res.text().await).unwrap();

Expand Down

0 comments on commit afac885

Please sign in to comment.