Skip to content

Commit

Permalink
allow spaces in headers/escape SQL (#38)
Browse files Browse the repository at this point in the history
allow spaces in headers and ensure sql injection doesn't happen
  • Loading branch information
hderms authored May 6, 2021
1 parent 4375a4b commit 6b57091
Show file tree
Hide file tree
Showing 7 changed files with 96 additions and 8 deletions.
25 changes: 25 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ md5 = "0.7.0"
statistical = "1.0.0"
tree_magic = "0.2.3"
flate2 = "1.0.20"
format-sql-query="0.4.0"

[dev-dependencies]
assert_cmd="0.10"
Expand Down
14 changes: 9 additions & 5 deletions src/db/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use rusqlite::types::ValueRef;
use rusqlite::{CachedStatement, Connection, Result};

use crate::db::functions::{calculate_md5, calculate_sqrt, Stddev};
use crate::db::utils::repeat_vars;
use crate::db::utils::{escape_fields, escape_table, repeat_vars};

mod functions;
pub mod utils;
Expand Down Expand Up @@ -54,17 +54,20 @@ impl Db {
}

pub fn create_table(&mut self, table_name: &str, fields: &[&str]) -> Result<usize> {
let string = format!("create table {} ({});", table_name, fields.join(", "));
let string = format!(
"create table {} ({});",
escape_table(table_name),
fields.join(", ")
);
self.connection.execute(string.as_str(), [])
}

pub fn insert(&mut self, table_name: &str, fields: &[&str], values: Vec<Vec<&str>>) {
let fields_len = fields.len();
let fields = fields.join(",");
let string = format!(
"INSERT INTO {} ({}) values ({})",
table_name,
fields,
escape_table(table_name),
escape_fields(fields).join(", "),
repeat_vars(fields_len)
);
let mut stmt = self.connection.prepare_cached(string.as_str()).unwrap();
Expand All @@ -82,6 +85,7 @@ impl Db {

pub fn select_statement(&self, query: &str) -> Result<(Header, Rows), Box<dyn Error>> {
debug!("Running select statement: {:?}", query);

let mut statement: CachedStatement = self.connection.prepare_cached(query).unwrap();
let results = statement
.query_map([], move |row| {
Expand Down
42 changes: 39 additions & 3 deletions src/db/utils.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::csv::csv_data::{CsvData, CsvType};
use crate::csv::inference::ColumnInference;
use format_sql_query::{Column, Table};

const INTEGER_STRING: &str = "integer";
const TEXT_STRING: &str = "text";
Expand All @@ -8,16 +9,17 @@ pub fn to_table_parameters(csv_data: &CsvData, column_inference: &ColumnInferenc
let mut vec = Vec::with_capacity(csv_data.headers.len());
for header in csv_data.headers.iter() {
let column_type = column_inference.get_type(header.to_string()).unwrap();
let table_name = escape_table(header);
let string = match column_type {
CsvType::Integer => {
format!("{} {}", header, INTEGER_STRING)
format!("{} {}", table_name, INTEGER_STRING)
}
CsvType::String => {
format!("{} {}", header, TEXT_STRING)
format!("{} {}", table_name, TEXT_STRING)
}

CsvType::Float => {
format!("{} {}", header, FLOAT_STRING)
format!("{} {}", table_name, FLOAT_STRING)
}
};
vec.push(string);
Expand All @@ -39,6 +41,16 @@ pub fn repeat_vars(count: usize) -> String {
s
}

pub fn escape_fields(fields: &[&str]) -> Vec<String> {
fields
.iter()
.map(|&field| format!("{}", Column(field.to_string().as_str().into())))
.collect()
}
pub fn escape_table(table_name: &str) -> String {
format!("{}", Table(table_name.to_string().as_str().into()))
}

#[cfg(test)]
mod tests {
use super::*;
Expand All @@ -61,4 +73,28 @@ mod tests {
fn it_fails_above_1000() {
repeat_vars(1001);
}

#[test]
fn it_escapes_tables() {
assert_eq!(escape_table("foo bar"), String::from("\"foo bar\""));
assert_eq!(
escape_table("bobby\"; drop table foo"),
String::from("\"bobby\"\"; drop table foo\"")
)
}

#[test]
fn it_escapes_fields() {
assert_eq!(
escape_fields(&["foo bar"]),
vec!(String::from("\"foo bar\""))
);
assert_eq!(
escape_fields(&["foo bar", "\"foo; drop table bar;"]),
vec!(
String::from("\"foo bar\""),
String::from("\"\"\"foo; drop table bar;\"")
)
)
}
}
3 changes: 3 additions & 0 deletions testdata/occupations_with_spaces_in_headers.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
occupation,minimum age
Bartender,18
Construction Worker,18
3 changes: 3 additions & 0 deletions testdata/sql_injection.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
occupation,minimumage";select load_extension("foo")
Bartender,18
Construction Worker,18
16 changes: 16 additions & 0 deletions tests/integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,22 @@ mod query_subcommand {
cmd.assert().success();
Ok(())
}

#[test]
fn it_succeeds_with_a_csv_with_spaces_in_headers() -> Result<(), Box<dyn std::error::Error>> {
let mut cmd = build_cmd();
cmd.arg("select \"minimum age\" from testdata/occupations_with_spaces_in_headers.csv");
cmd.assert().success();
Ok(())
}

#[test]
fn it_succeeds_without_running_sql_in_file_headers() -> Result<(), Box<dyn std::error::Error>> {
let mut cmd = build_cmd();
cmd.arg("select \"minimum age\" from testdata/sql_injection.csv");
cmd.assert().success();
Ok(())
}
}
mod analyze_subcommand {
use std::process::Command;
Expand Down

0 comments on commit 6b57091

Please sign in to comment.