Skip to content

Commit

Permalink
新增对doris的支持 (#2536)
Browse files Browse the repository at this point in the history
* 新增对doris的支持

支持doris的查询、上线审核
相关讨论:#2175

* 调整继承关系

改为从MysqlEngine类继承

* 删去重复方法

* 用black处理

* 删除重复函数

* reuse get_all_databases

---------

Co-authored-by: grainyu <[email protected]>
Co-authored-by: Leo Q <[email protected]>
  • Loading branch information
3 people authored Mar 17, 2024
1 parent 033964c commit 2d272f2
Show file tree
Hide file tree
Showing 7 changed files with 257 additions and 6 deletions.
2 changes: 2 additions & 0 deletions archery/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
"phoenix",
"odps",
"cassandra",
"doris",
],
),
ENABLED_NOTIFIERS=(
Expand Down Expand Up @@ -99,6 +100,7 @@
"mongo": {"path": "sql.engines.mongo:MongoEngine"},
"phoenix": {"path": "sql.engines.phoenix:PhoenixEngine"},
"odps": {"path": "sql.engines.odps:ODPSEngine"},
"doris": {"path": "sql.engines.doris:DorisEngine"},
}

ENABLED_NOTIFIERS = env("ENABLED_NOTIFIERS")
Expand Down
188 changes: 188 additions & 0 deletions sql/engines/doris.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
# -*- coding: UTF-8 -*-
from sql.utils.sql_utils import get_syntax_type, remove_comments
from sql.engines.mysql import MysqlEngine
from .models import ResultSet, ReviewResult, ReviewSet
from common.utils.timer import FuncTimer
from common.config import SysConfig
from MySQLdb.constants import FIELD_TYPE
import traceback
import MySQLdb
import pymysql
import sqlparse
import logging
import re


logger = logging.getLogger("default")


class DorisEngine(MysqlEngine):
    """Engine for Doris instances, built on top of ``MysqlEngine``.

    Connection handling and metadata helpers are inherited; this class
    overrides version detection, querying, the pre-execution review and
    statement execution.
    """

    name = "Doris"
    info = "Doris engine"

    # NOTE(review): presumably tells the workflow machinery that no automatic
    # backup is available for this engine - confirm against the callers.
    auto_backup = False

    # Databases that must not be exposed; overrides the MysqlEngine default list.
    forbidden_databases = [
        "__internal_schema",
        "INFORMATION_SCHEMA",
        "information_schema",
    ]

    # Statements that belong to the query feature, not to a change workflow.
    _QUERY_STATEMENT_RE = re.compile(r"^select|^show|^explain")
    # UPDATE/DELETE that never mentions "where". DOTALL is required so that a
    # statement spread over several lines is still inspected as a whole.
    _MISSING_WHERE_RE = re.compile(
        r"^update((?!where).)*$|^delete((?!where).)*$", re.DOTALL
    )

    @property
    def server_version(self):
        """Return the frontend version as a tuple of ints, e.g. ``(2, 1, 0)``.

        The version text is taken from the last column of the first row
        returned by ``show frontends``. The first dotted run of digits found
        in it is used, so both ``2.1.0-doris`` and a prefixed form such as
        ``doris-2.0.3-rc05`` are understood. At most three components are
        returned.

        Returns:
            tuple: the numeric version, or an empty tuple when the query
            produced no rows or the text contains no digits.
        """
        result = self.query(sql="show frontends")
        if not result.rows:
            # query() has already logged the failure, if there was one.
            return tuple()
        # NOTE(review): assumes the version is the last column - confirm
        # against the `show frontends` output of the targeted Doris release.
        raw_version = str(result.rows[0][-1])
        found = re.search(r"\d+(?:\.\d+)*", raw_version)
        if found is None:
            return tuple()
        return tuple(int(part) for part in found.group(0).split(".")[:3])

    def query(self, db_name=None, sql="", limit_num=0, close_conn=True, **kwargs):
        """Run ``sql`` and return the outcome as a ``ResultSet``.

        Args:
            db_name: database passed to ``get_connection``.
            sql: statement to execute.
            limit_num: when greater than zero, fetch at most this many rows;
                otherwise fetch everything.
            close_conn: close the engine connection before returning.
            **kwargs: accepted for interface compatibility, unused here.

        Returns:
            ResultSet: column names, rows and row count on success. On any
            exception ``error`` holds the message, truncated at the first
            ``"Stack trace"`` marker, and the exception is not re-raised.
        """
        result_set = ResultSet(full_sql=sql)
        try:
            conn = self.get_connection(db_name=db_name)
            cursor = conn.cursor()
            cursor.execute(sql)
            if int(limit_num) > 0:
                rows = cursor.fetchmany(size=int(limit_num))
            else:
                rows = cursor.fetchall()
            fields = cursor.description

            result_set.column_list = [i[0] for i in fields] if fields else []
            result_set.rows = rows
            result_set.affected_rows = len(rows)
        except Exception as e:
            logger.warning(f"Doris语句执行报错,语句:{sql},错误信息{e}")
            result_set.error = str(e).split("Stack trace")[0]
        finally:
            if close_conn:
                self.close()
        return result_set

    def execute_check(self, db_name=None, sql=""):
        """Review a workflow's SQL before execution and return a ``ReviewSet``.

        Each statement produced by ``sqlparse.split`` is stripped of comments
        and classified, in this order:

        1. ``select`` / ``show`` / ``explain`` statements are rejected.
        2. Statements matching the configured ``critical_ddl_regex`` are
           rejected.
        3. ``update`` / ``delete`` statements without a ``where`` are rejected
           (use an explicit ``where 1=1`` to touch every row on purpose).
        4. Everything else passes.

        Matching is done on the stripped, lower-cased statement, while the
        reported ``sql`` keeps the comment-stripped text as is.

        Args:
            db_name: unused by the checks below; kept for the engine interface.
            sql: the full SQL text of the workflow.

        Returns:
            ReviewSet: one ``ReviewResult`` per statement plus the warning and
            error counters. ``syntax_type`` is 1 when any accepted statement
            is DDL, otherwise 2.
        """
        check_result = ReviewSet(full_sql=sql)
        critical_ddl_regex = self.config.get("critical_ddl_regex", "")
        critical_pattern = (
            re.compile(critical_ddl_regex) if critical_ddl_regex else None
        )
        check_result.syntax_type = 2  # TODO workflow type: 0 other, 1 DDL, 2 DML

        for line, raw_statement in enumerate(sqlparse.split(sql), start=1):
            statement = sqlparse.format(raw_statement, strip_comments=True)
            # Removing comments can leave leading whitespace behind, which
            # would otherwise defeat the anchored patterns below.
            normalized = statement.strip().lower()

            if self._QUERY_STATEMENT_RE.match(normalized):
                # Unsupported statement
                result = ReviewResult(
                    id=line,
                    errlevel=2,
                    stagestatus="驳回不支持语句",
                    errormessage="仅支持DML和DDL语句,查询语句请使用SQL查询功能!",
                    sql=statement,
                )
            elif critical_pattern is not None and critical_pattern.match(normalized):
                # High-risk statement
                result = ReviewResult(
                    id=line,
                    errlevel=2,
                    stagestatus="驳回高危SQL",
                    errormessage="禁止提交匹配" + critical_ddl_regex + "条件的语句!",
                    sql=statement,
                )
            elif self._MISSING_WHERE_RE.match(normalized):
                # Data modification without a where clause
                result = ReviewResult(
                    id=line,
                    errlevel=2,
                    stagestatus="驳回未带where数据修改",
                    errormessage="数据修改需带where条件!",
                    sql=statement,
                )
            else:
                # Accepted statement
                result = ReviewResult(
                    id=line,
                    errlevel=0,
                    stagestatus="Audit completed",
                    errormessage="None",
                    sql=statement,
                    affected_rows=0,
                    execute_time=0,
                )
                # Only accepted statements decide the workflow type.
                if get_syntax_type(statement) == "DDL":
                    check_result.syntax_type = 1
            check_result.rows.append(result)

        # Tally warnings and errors.
        for reviewed in check_result.rows:
            if reviewed.errlevel == 1:
                check_result.warning_count += 1
            if reviewed.errlevel == 2:
                check_result.error_count += 1
        return check_result

    def execute_workflow(self, workflow):
        """Execute the SQL content of ``workflow`` against its database."""
        return self.execute(
            db_name=workflow.db_name, sql=workflow.sqlworkflowcontent.sql_content
        )

    def execute(self, db_name=None, sql="", close_conn=True):
        """Execute the statements in ``sql`` one by one and return a ``ReviewSet``.

        Execution stops at the first failing statement. The failure is
        recorded with ``errlevel=2`` and every statement after it is added as
        "not executed".

        Args:
            db_name: database passed to ``get_connection``.
            sql: one or more statements, split with ``sqlparse.split``.
            close_conn: close the engine connection before returning.

        Returns:
            ReviewSet: one ``ReviewResult`` per statement; ``error`` holds the
            message of the failing statement, if any.
        """
        execute_result = ReviewSet(full_sql=sql)
        conn = self.get_connection(db_name=db_name)
        sql_list = sqlparse.split(sql)
        failed_at = None

        for rowid, statement in enumerate(sql_list, start=1):
            # Stays None when the failure happens before timing starts, so
            # the handler below never touches an unbound name.
            timer = None
            try:
                cursor = conn.cursor()
                with FuncTimer() as timer:
                    effect_row = cursor.execute(statement)
                cursor.close()
            except Exception as e:
                logger.warning(
                    f"{self.name} 命令执行报错,语句:{sql}, 错误信息:{traceback.format_exc()}"
                )
                execute_result.error = str(e)
                execute_result.rows.append(
                    ReviewResult(
                        id=rowid,
                        errlevel=2,
                        stagestatus="Execute Failed",
                        errormessage=f"异常信息:{e}",
                        sql=statement,
                        # The statement failed, so do not report the row
                        # count left over from the previous statement.
                        affected_rows=0,
                        execute_time=timer.cost if timer is not None else 0,
                    )
                )
                failed_at = rowid
                break
            else:
                execute_result.rows.append(
                    ReviewResult(
                        id=rowid,
                        errlevel=0,
                        stagestatus="Execute Successfully",
                        errormessage="None",
                        sql=statement,
                        affected_rows=effect_row,
                        execute_time=timer.cost,
                    )
                )

        if failed_at is not None:
            # failed_at is the 1-based id of the failing statement, which is
            # also the 0-based index of the first statement that never ran.
            for skipped_id, statement in enumerate(
                sql_list[failed_at:], start=failed_at + 1
            ):
                execute_result.rows.append(
                    ReviewResult(
                        id=skipped_id,
                        errlevel=2,
                        stagestatus="Audit Completed",
                        errormessage="前序语句失败, 未执行",
                        sql=statement,
                        affected_rows=0,
                        execute_time=0,
                    )
                )

        if close_conn:
            self.close()
        return execute_result
18 changes: 13 additions & 5 deletions sql/engines/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,24 +178,32 @@ def kill_connection(self, thread_id):
"""终止数据库连接"""
self.query(sql=f"kill {thread_id}")

# 禁止查询的数据库
forbidden_databases = [
"information_schema",
"performance_schema",
"mysql",
"test",
"sys",
]

def get_all_databases(self):
"""获取数据库列表, 返回一个ResultSet"""
sql = "show databases"
result = self.query(sql=sql)
db_list = [
row[0]
for row in result.rows
if row[0]
not in ("information_schema", "performance_schema", "mysql", "test", "sys")
row[0] for row in result.rows if row[0] not in self.forbidden_databases
]
result.rows = db_list
return result

forbidden_tables = ["test"]

def get_all_tables(self, db_name, **kwargs):
"""获取table 列表, 返回一个ResultSet"""
sql = "show tables"
result = self.query(db_name=db_name, sql=sql)
tb_list = [row[0] for row in result.rows if row[0] not in ["test"]]
tb_list = [row[0] for row in result.rows if row[0] not in self.forbidden_tables]
result.rows = tb_list
return result

Expand Down
47 changes: 47 additions & 0 deletions sql/engines/test_doris.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from pytest_mock import MockFixture

from sql.engines.doris import DorisEngine
from sql.engines.models import ResultSet


def test_doris_server_info(db_instance, mocker: MockFixture):
    """server_version turns the last `show frontends` column into an int tuple."""
    db_instance.db_type = "doris"
    frontends = ResultSet(
        full_sql="show frontends", rows=[["foo", "bar", "2.1.0-doris"]]
    )
    # Replace query() so that no real connection is needed.
    mocker.patch.object(DorisEngine, "query", return_value=frontends)

    doris = DorisEngine(instance=db_instance)

    assert doris.server_version == (2, 1, 0)


def test_doris_query(db_instance, mocker: MockFixture):
    """query() copies column names and rows from the cursor into the ResultSet."""

    class DummyCursor:
        # Minimal cursor stand-in: two columns, one row.
        description = [("foo",), ("bar",)]

        def execute(self, sql):
            pass

        def fetchall(self):
            return [("baz", "qux")]

    fake_connection = mocker.patch.object(DorisEngine, "get_connection")
    fake_connection.return_value.cursor.return_value = DummyCursor()
    db_instance.db_type = "doris"

    outcome = DorisEngine(instance=db_instance).query(sql="select * from foo")

    assert outcome.column_list == ["foo", "bar"]
    assert outcome.rows == [("baz", "qux")]
    assert outcome.affected_rows == 1


def test_forbidden_db(db_instance, mocker: MockFixture):
    """get_all_databases() drops databases listed in forbidden_databases."""
    db_instance.db_type = "doris"
    listing = ResultSet(full_sql="show databases", rows=[["__internal_schema"]])
    mocker.patch.object(DorisEngine, "query", return_value=listing)

    doris = DorisEngine(instance=db_instance)
    visible = doris.get_all_databases()

    assert visible.rows == []
1 change: 1 addition & 0 deletions sql/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ class Meta:
("clickhouse", "ClickHouse"),
("goinception", "goInception"),
("cassandra", "Cassandra"),
("doris", "Doris"),
)


Expand Down
5 changes: 5 additions & 0 deletions sql/templates/sqlquery.html
Original file line number Diff line number Diff line change
Expand Up @@ -979,6 +979,11 @@ <h4 class="modal-title text-danger">收藏语句</h4>
if (sql === 'explain') {
sqlContent = 'explain ' + sqlContent
}
} else if (optgroup === "Doris") {
//查看执行计划
if (sql === 'explain') {
sqlContent = 'explain ' + sqlContent
}
}
//提交请求
$.ajax({
Expand Down
2 changes: 1 addition & 1 deletion sql/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@ def config(request):
# 获取所有实例标签
instance_tags = InstanceTag.objects.all()
# 支持自动审核的数据库类型
db_type = ["mysql", "oracle", "mongo", "clickhouse", "redis"]
db_type = ["mysql", "oracle", "mongo", "clickhouse", "redis", "doris"]
# 获取所有配置项
all_config = Config.objects.all().values("item", "value")
sys_config = {}
Expand Down

0 comments on commit 2d272f2

Please sign in to comment.