Skip to content

Commit

Permalink
Merge pull request #1735 from pyiron/dbescape
Browse files Browse the repository at this point in the history
[patch] Escape database queries
  • Loading branch information
pmrv authored Jan 20, 2025
2 parents 43b68b7 + 5ef7231 commit dd2839c
Showing 1 changed file with 27 additions and 7 deletions.
34 changes: 27 additions & 7 deletions pyiron_base/database/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,10 +312,30 @@ def _job_dict(
'username': u'test'},.......]
"""
if not self._sql_lite:

def escape(s, escape_char="\\", special_chars="_%"):
"""Insert escape_char in front of special_chars, unless present.
Handles the cases where s already contains escaped characters,
including the escape character itself.
Defaults for LIKE in SQL statements."""
for c in special_chars:
if c in s:
s = s.replace(escape_char + c, c)
s = s.replace(c, escape_char + c)
return s

else:

def escape(s, escape_char="\\", special_chars="_%"):
return s

dict_clause = {}
# FOR GET_ITEMS_SQL: clause = []
if user is not None:
dict_clause["username"] = str(user)
dict_clause["username"] = escape(str(user))
# FOR GET_ITEMS_SQL: clause.append("username = '" + self.user + "'")
if sql_query is not None:
# FOR GET_ITEMS_SQL: clause.append(self.sql_query)
Expand All @@ -329,18 +349,18 @@ def _job_dict(
{str(element.split()[0]): element.split()[2] for element in cl_split}
)
if job is not None:
dict_clause["job"] = str(job)
dict_clause["job"] = escape(str(job))

if project_path == "./":
project_path = ""
if recursive:
dict_clause["project"] = str(project_path) + "%"
dict_clause["project"] = escape(str(project_path)) + "%"
else:
dict_clause["project"] = str(project_path)
dict_clause["project"] = escape(str(project_path))
if sub_job_name is None:
dict_clause["subjob"] = None
elif sub_job_name != "%":
dict_clause["subjob"] = str(sub_job_name)
dict_clause["subjob"] = escape(str(sub_job_name))
if element_lst is not None:
dict_clause["element_lst"] = element_lst

Expand Down Expand Up @@ -880,10 +900,10 @@ def get_items_dict(
self.conn.connection.create_function("like", 2, self.regexp)

result = self.conn.execute(query)
row = result.fetchall()
results = [row._asdict() for row in result.fetchall()]
if not self._keep_connection:
self.conn.close()
return [dict(zip(col._mapping.keys(), col._mapping.values())) for col in row]
return results

def get_job_status(self, job_id: int) -> Union[str, None]:
try:
Expand Down

0 comments on commit dd2839c

Please sign in to comment.