Skip to content

Commit

Permalink
Issue #32 - escape table and field names
Browse files Browse the repository at this point in the history
Quite often, there is a need to create dynamic queries where the table or column name is only known
at run time. Until now, one had to resort to the potentially dangerous | sqlsafe filter and had to
ensure that the table / column name did not have any sql injection.

Most databases provide a way to quote identifiers. Most databases uses double quotes as a way to
quote table / column names. Notable exception is MySql, which by default uses backticks as the escape character

With this commit, we add a new jinja2 filter call identifier. This filter will automatically quote and escape
the table/column names that are injected at run time.

Typical usage:
template = 'SELECT {{colname|identifier}} FROM {{tablename|identifier}}'

will generate a query like
'SELECT somecol FROM myschema.sometable
  • Loading branch information
sripathikrishnan committed Dec 29, 2021
1 parent fb58b0e commit d7fdc7a
Show file tree
Hide file tree
Showing 5 changed files with 122 additions and 39 deletions.
27 changes: 25 additions & 2 deletions jinjasql/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from jinja2.ext import Extension
from jinja2.lexer import Token
from jinja2.utils import Markup
from collections.abc import Iterable

try:
from collections import OrderedDict
Expand Down Expand Up @@ -136,6 +137,23 @@ def _bind_param(already_bound, key, value):
else:
raise AssertionError("Invalid param_style - %s" % param_style)

def build_escape_identifier_filter(identifier_quote_character):
def quote_and_escape(value):
# Escape double quote with 2 double quotes,
# or escape backtick with 2 backticks
return identifier_quote_character + \
value.replace(identifier_quote_character, identifier_quote_character*2) + \
identifier_quote_character

def identifier_filter(raw_identifier):
if isinstance(raw_identifier, str):
raw_identifier = (raw_identifier, )
if not isinstance(raw_identifier, Iterable):
raise ValueError("identifier filter expects a string or an Iterable")
return Markup('.'.join(quote_and_escape(s) for s in raw_identifier))

return identifier_filter

def requires_in_clause(obj):
return isinstance(obj, (list, tuple))

Expand All @@ -151,10 +169,14 @@ class JinjaSql(object):
# pyformat "where name = %(name)s"
# asyncpg "where name = $1"
VALID_PARAM_STYLES = ('qmark', 'numeric', 'named', 'format', 'pyformat', 'asyncpg')
def __init__(self, env=None, param_style='format'):
VALID_ID_QUOTE_CHARS = ('`', '"')
def __init__(self, env=None, param_style='format', identifier_quote_character='"'):
self.param_style = param_style
if identifier_quote_character not in self.VALID_ID_QUOTE_CHARS:
raise ValueError("identifier_quote_characters must be one of " + VALID_ID_QUOTE_CHARS)
self.identifier_quote_character = identifier_quote_character
self.env = env or Environment()
self._prepare_environment()
self.param_style = param_style

def _prepare_environment(self):
self.env.autoescape=True
Expand All @@ -163,6 +185,7 @@ def _prepare_environment(self):
self.env.filters["bind"] = bind
self.env.filters["sqlsafe"] = sql_safe
self.env.filters["inclause"] = bind_in_clause
self.env.filters["identifier"] = build_escape_identifier_filter(self.identifier_quote_character)

def prepare_query(self, source, data):
if isinstance(source, Template):
Expand Down
3 changes: 2 additions & 1 deletion tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import sys
import unittest
from tests.test_jinjasql import JinjaSqlTest
from tests.test_postgres import PostgresTest
from tests.test_real_database import PostgresTest, MySqlTest

def all_tests():
suite = unittest.TestSuite()
suite.addTest(unittest.makeSuite(JinjaSqlTest))
suite.addTest(unittest.makeSuite(PostgresTest))
suite.addTest(unittest.makeSuite(MySqlTest))

if sys.version_info <= (3, 4):
from tests.test_django import DjangoTest
Expand Down
27 changes: 27 additions & 0 deletions tests/test_jinjasql.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,33 @@ def test_large_inclause(self):
self.assertEqual(len(bind_params), num_of_params)
self.assertEqual(query, "SELECT 'x' WHERE 'A' in (" + "%s," * (num_of_params - 1) + "%s)")

def test_identifier_filter(self):
j = JinjaSql()
template = 'select * from {{table_name | identifier}}'

tests = [
('users', 'select * from "users"'),
(('myschema', 'users'), 'select * from "myschema"."users"'),
('a"b', 'select * from "a""b"'),
(('users',), 'select * from "users"'),
]
for test in tests:
query, _ = j.prepare_query(template, {'table_name': test[0]})
self.assertEqual(query, test[1])


def test_identifier_filter_backtick(self):
j = JinjaSql(identifier_quote_character='`')
template = 'select * from {{table_name | identifier}}'

tests = [
('users', 'select * from `users`'),
(('myschema', 'users'), 'select * from `myschema`.`users`'),
('a`b', 'select * from `a``b`'),
]
for test in tests:
query, _ = j.prepare_query(template, {'table_name': test[0]})
self.assertEqual(query, test[1])

def generate_yaml_tests():
file_path = join(YAML_TESTS_ROOT, "macros.yaml")
Expand Down
36 changes: 0 additions & 36 deletions tests/test_postgres.py

This file was deleted.

68 changes: 68 additions & 0 deletions tests/test_real_database.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from testcontainers.postgres import PostgresContainer
from testcontainers.mysql import MySqlContainer
import sqlalchemy
import unittest
from jinjasql import JinjaSql

class PostgresTest(unittest.TestCase):

# Core idea inspired from
# https://stackoverflow.com/questions/8416208/in-python-is-there-a-good-idiom-for-using-context-managers-in-setup-teardown
#
# Override the run method to automatically
# a. launch a postgres docker container
# b. create a sqlalchemy connection
# c. at the end of the test, kill the docker container
def run(self, result=None):
with PostgresContainer("postgres:9.5") as postgres:
self.engine = sqlalchemy.create_engine(postgres.get_connection_url())
super(PostgresTest, self).run(result)

def test_bind_array(self):
'It should be possible to bind arrays in a query'
j = JinjaSql()
data = {
"some_num": 1,
"some_array": [1,2,3]
}
template = """
SELECT {{some_num}} = ANY({{some_array}})
"""
query, params = j.prepare_query(template, data)
result = self.engine.execute(query, params).fetchone()
self.assertTrue(result[0])

def test_quoted_tables(self):
j = JinjaSql()
data = {
"all_tables": ("information_schema", "tables")
}
template = """
select table_name from {{all_tables|identifier}}
where table_name = 'pg_user'
"""
query, params = j.prepare_query(template, data)
result = self.engine.execute(query, params).fetchall()
self.assertEqual(len(result), 1)

class MySqlTest(unittest.TestCase):
def run(self, result=None):
with MySqlContainer("mysql:5.7.17") as mysql:
self.engine = sqlalchemy.create_engine(mysql.get_connection_url())
super(MySqlTest, self).run(result)

def test_quoted_tables(self):
j = JinjaSql(identifier_quote_character='`')
data = {
"all_tables": ("information_schema", "tables")
}
template = """
select table_name from {{all_tables|identifier}}
where table_name = 'SESSION_STATUS'
"""
query, params = j.prepare_query(template, data)
result = self.engine.execute(query, params).fetchall()
self.assertEqual(len(result), 1)

if __name__ == '__main__':
unittest.main()

0 comments on commit d7fdc7a

Please sign in to comment.