Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

新增对doris的支持 #2536

Merged
merged 7 commits into from
Mar 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions archery/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
"phoenix",
"odps",
"cassandra",
"doris",
],
),
ENABLED_NOTIFIERS=(
Expand Down Expand Up @@ -99,6 +100,7 @@
"mongo": {"path": "sql.engines.mongo:MongoEngine"},
"phoenix": {"path": "sql.engines.phoenix:PhoenixEngine"},
"odps": {"path": "sql.engines.odps:ODPSEngine"},
"doris": {"path": "sql.engines.doris:DorisEngine"},
}

ENABLED_NOTIFIERS = env("ENABLED_NOTIFIERS")
Expand Down
188 changes: 188 additions & 0 deletions sql/engines/doris.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
# -*- coding: UTF-8 -*-
from sql.utils.sql_utils import get_syntax_type, remove_comments
from sql.engines.mysql import MysqlEngine
from .models import ResultSet, ReviewResult, ReviewSet
from common.utils.timer import FuncTimer
from common.config import SysConfig
from MySQLdb.constants import FIELD_TYPE
import traceback
import MySQLdb
import pymysql
import sqlparse
import logging
import re


logger = logging.getLogger("default")


class DorisEngine(MysqlEngine):
name = "Doris"
info = "Doris engine"

auto_backup = False

@property
def server_version(self):
sql = "show frontends"
result = self.query(sql=sql)
version = result.rows[0][-1].split("-")[0]
return tuple([int(n) for n in version.split(".")[:3]])

def query(self, db_name=None, sql="", limit_num=0, close_conn=True, **kwargs):
"""返回 ResultSet"""
result_set = ResultSet(full_sql=sql)
try:
conn = self.get_connection(db_name=db_name)
cursor = conn.cursor()
cursor.execute(sql)
if int(limit_num) > 0:
rows = cursor.fetchmany(size=int(limit_num))
else:
rows = cursor.fetchall()
fields = cursor.description

result_set.column_list = [i[0] for i in fields] if fields else []
result_set.rows = rows
result_set.affected_rows = len(rows)
except Exception as e:
logger.warning(f"Doris语句执行报错,语句:{sql},错误信息{e}")
result_set.error = str(e).split("Stack trace")[0]
finally:
if close_conn:
self.close()
return result_set

forbidden_databases = [
"__internal_schema",
"INFORMATION_SCHEMA",
"information_schema",
]

def execute_check(self, db_name=None, sql=""):
"""上线单执行前的检查, 返回Review set"""
check_result = ReviewSet(full_sql=sql)
# 禁用/高危语句检查
line = 1
critical_ddl_regex = self.config.get("critical_ddl_regex", "")
p = re.compile(critical_ddl_regex)
check_result.syntax_type = 2 # TODO 工单类型 0、其他 1、DDL,2、DML
for statement in sqlparse.split(sql):
statement = sqlparse.format(statement, strip_comments=True)
# 禁用语句
if re.match(r"^select|^show|^explain", statement.lower()):
result = ReviewResult(
id=line,
errlevel=2,
stagestatus="驳回不支持语句",
errormessage="仅支持DML和DDL语句,查询语句请使用SQL查询功能!",
sql=statement,
)
# 高危语句
elif critical_ddl_regex and p.match(statement.strip().lower()):
result = ReviewResult(
id=line,
errlevel=2,
stagestatus="驳回高危SQL",
errormessage="禁止提交匹配" + critical_ddl_regex + "条件的语句!",
sql=statement,
)
# 驳回未带where数据修改语句,如确实需做全部删除或更新,显示的带上where 1=1
elif re.match(
r"^update((?!where).)*$|^delete((?!where).)*$", statement.lower()
):
result = ReviewResult(
id=line,
errlevel=2,
stagestatus="驳回未带where数据修改",
errormessage="数据修改需带where条件!",
sql=statement,
)
# 正常语句
else:
result = ReviewResult(
id=line,
errlevel=0,
stagestatus="Audit completed",
errormessage="None",
sql=statement,
affected_rows=0,
execute_time=0,
)
# 判断工单类型
if get_syntax_type(statement) == "DDL":
check_result.syntax_type = 1
check_result.rows += [result]
line += 1
# 统计警告和错误数量
for r in check_result.rows:
if r.errlevel == 1:
check_result.warning_count += 1
if r.errlevel == 2:
check_result.error_count += 1
return check_result

def execute_workflow(self, workflow):
return self.execute(
db_name=workflow.db_name, sql=workflow.sqlworkflowcontent.sql_content
)

def execute(self, db_name=None, sql="", close_conn=True):
"""执行sql语句 返回 Review set"""
execute_result = ReviewSet(full_sql=sql)
conn = self.get_connection(db_name=db_name)
rowid = 1
effect_row = 0
sql_list = sqlparse.split(sql)
for statement in sql_list:
try:
cursor = conn.cursor()
with FuncTimer() as t:
effect_row = cursor.execute(statement)
cursor.close()
execute_result.rows.append(
ReviewResult(
id=rowid,
errlevel=0,
stagestatus="Execute Successfully",
errormessage="None",
sql=statement,
affected_rows=effect_row,
execute_time=t.cost,
)
)
except Exception as e:
logger.warning(
f"{self.name} 命令执行报错,语句:{sql}, 错误信息:{traceback.format_exc()}"
)
execute_result.error = str(e)
execute_result.rows.append(
ReviewResult(
id=rowid,
errlevel=2,
stagestatus="Execute Failed",
errormessage=f"异常信息:{e}",
sql=statement,
affected_rows=effect_row,
execute_time=t.cost,
)
)
break
rowid += 1
if execute_result.error:
for statement in sql_list[rowid:]:
execute_result.rows.append(
ReviewResult(
id=rowid + 1,
errlevel=2,
stagestatus="Audit Completed",
errormessage="前序语句失败, 未执行",
sql=statement,
affected_rows=0,
execute_time=0,
)
)
rowid += 1
if close_conn:
self.close()
return execute_result
18 changes: 13 additions & 5 deletions sql/engines/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,24 +178,32 @@ def kill_connection(self, thread_id):
"""终止数据库连接"""
self.query(sql=f"kill {thread_id}")

# 禁止查询的数据库
forbidden_databases = [
"information_schema",
"performance_schema",
"mysql",
"test",
"sys",
]

def get_all_databases(self):
"""获取数据库列表, 返回一个ResultSet"""
sql = "show databases"
result = self.query(sql=sql)
db_list = [
row[0]
for row in result.rows
if row[0]
not in ("information_schema", "performance_schema", "mysql", "test", "sys")
row[0] for row in result.rows if row[0] not in self.forbidden_databases
]
result.rows = db_list
return result

forbidden_tables = ["test"]

def get_all_tables(self, db_name, **kwargs):
"""获取table 列表, 返回一个ResultSet"""
sql = "show tables"
result = self.query(db_name=db_name, sql=sql)
tb_list = [row[0] for row in result.rows if row[0] not in ["test"]]
tb_list = [row[0] for row in result.rows if row[0] not in self.forbidden_tables]
result.rows = tb_list
return result

Expand Down
47 changes: 47 additions & 0 deletions sql/engines/test_doris.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from pytest_mock import MockFixture

from sql.engines.doris import DorisEngine
from sql.engines.models import ResultSet


def test_doris_server_info(db_instance, mocker: MockFixture):
mock_query = mocker.patch.object(DorisEngine, "query")
mock_query.return_value = ResultSet(
full_sql="show frontends", rows=[["foo", "bar", "2.1.0-doris"]]
)
db_instance.db_type = "doris"
engine = DorisEngine(instance=db_instance)
version = engine.server_version
assert version == (2, 1, 0)


def test_doris_query(db_instance, mocker: MockFixture):
mock_get_connection = mocker.patch.object(DorisEngine, "get_connection")

class DummyCursor:
def __init__(self):
self.description = [("foo",), ("bar",)]
self.fetchall = lambda: [("baz", "qux")]

def execute(self, sql):
pass

mock_get_connection.return_value.cursor.return_value = DummyCursor()
db_instance.db_type = "doris"
engine = DorisEngine(instance=db_instance)
result_set = engine.query(sql="select * from foo")
assert result_set.column_list == ["foo", "bar"]
assert result_set.rows == [("baz", "qux")]
assert result_set.affected_rows == 1


def test_forbidden_db(db_instance, mocker: MockFixture):
db_instance.db_type = "doris"
mock_query = mocker.patch.object(DorisEngine, "query")
mock_query.return_value = ResultSet(
full_sql="show databases", rows=[["__internal_schema"]]
)

engine = DorisEngine(instance=db_instance)
all_db = engine.get_all_databases()
assert all_db.rows == []
1 change: 1 addition & 0 deletions sql/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ class Meta:
("clickhouse", "ClickHouse"),
("goinception", "goInception"),
("cassandra", "Cassandra"),
("doris", "Doris"),
)


Expand Down
5 changes: 5 additions & 0 deletions sql/templates/sqlquery.html
Original file line number Diff line number Diff line change
Expand Up @@ -979,6 +979,11 @@ <h4 class="modal-title text-danger">收藏语句</h4>
if (sql === 'explain') {
sqlContent = 'explain ' + sqlContent
}
} else if (optgroup === "Doris") {
//查看执行计划
if (sql === 'explain') {
sqlContent = 'explain ' + sqlContent
}
}
//提交请求
$.ajax({
Expand Down
2 changes: 1 addition & 1 deletion sql/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@ def config(request):
# 获取所有实例标签
instance_tags = InstanceTag.objects.all()
# 支持自动审核的数据库类型
db_type = ["mysql", "oracle", "mongo", "clickhouse", "redis"]
db_type = ["mysql", "oracle", "mongo", "clickhouse", "redis", "doris"]
# 获取所有配置项
all_config = Config.objects.all().values("item", "value")
sys_config = {}
Expand Down
Loading