Skip to content

Commit

Permalink
fix: MySQL type compatibility (apecloud#201)
Browse files Browse the repository at this point in the history
  • Loading branch information
ddh-5230 authored Nov 25, 2024
1 parent ee071fd commit fe90074
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 32 deletions.
68 changes: 58 additions & 10 deletions backend/iter.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ package backend
import (
stdsql "database/sql"
"io"
"math/big"
"reflect"
"strings"

"github.com/apecloud/myduckserver/charset"
Expand All @@ -28,17 +30,23 @@ import (

var _ sql.RowIter = (*SQLRowIter)(nil)

// typeConversion records that the value scanned for one result column must be
// converted to a different Go kind before the row is handed back to the caller.
type typeConversion struct {
// idx is the column position; it is used as an index into the scan buffer.
idx int
// kind is the target kind. The visible code only ever sets it to
// reflect.Float64 or reflect.Int64.
kind reflect.Kind
}

// SQLRowIter wraps a standard sql.Rows as a RowIter.
//
// NOTE(review): this listing comes from a rendered diff, so the field list
// below appears twice — presumably the pre-change fields followed by the
// re-aligned post-change fields, which add conversions. Confirm against the
// actual file before relying on it.
type SQLRowIter struct {
rows *stdsql.Rows
columns []*stdsql.ColumnType
schema sql.Schema
buffer []any // pre-allocated buffer for scanning values
pointers []any // pointers to the buffer
decimals []int
intervals []int
nonUTF8 []int
charsets []sql.CharacterSetID
rows *stdsql.Rows
columns []*stdsql.ColumnType
schema sql.Schema
buffer []any // pre-allocated buffer for scanning values
pointers []any // pointers to the buffer
decimals []int
intervals []int
nonUTF8 []int
charsets []sql.CharacterSetID
// conversions lists the columns whose buffered values are rewritten to
// float64 or int64 in Next; it is built in NewSQLRowIter for columns whose
// database type name is HUGEINT, DOUBLE or FLOAT.
conversions []typeConversion
}

func NewSQLRowIter(rows *stdsql.Rows, schema sql.Schema) (*SQLRowIter, error) {
Expand Down Expand Up @@ -72,14 +80,32 @@ func NewSQLRowIter(rows *stdsql.Rows, schema sql.Schema) (*SQLRowIter, error) {
}
}

var conversions []typeConversion
for i, c := range columns {
if c.DatabaseTypeName() == "HUGEINT" {
expectedType := schema[i].Type
if ok := types.IsFloat(expectedType); ok {
conversions = append(conversions, typeConversion{idx: i, kind: reflect.Float64})
} else {
conversions = append(conversions, typeConversion{idx: i, kind: reflect.Int64})
}
}
if c.DatabaseTypeName() == "DOUBLE" || c.DatabaseTypeName() == "FLOAT" {
expectedType := schema[i].Type
if ok := types.IsInteger(expectedType); ok {
conversions = append(conversions, typeConversion{idx: i, kind: reflect.Int64})
}
}
}

width := max(len(columns), len(schema))
buf := make([]any, width)
ptrs := make([]any, width)
for i := range buf {
ptrs[i] = &buf[i]
}

return &SQLRowIter{rows, columns, schema, buf, ptrs, decimals, intervals, nonUTF8, charsets}, nil
return &SQLRowIter{rows, columns, schema, buf, ptrs, decimals, intervals, nonUTF8, charsets, conversions}, nil
}

// Next retrieves the next row. It will return io.EOF if it's the last row.
Expand Down Expand Up @@ -115,6 +141,28 @@ func (iter *SQLRowIter) Next(ctx *sql.Context) (sql.Row, error) {
}
}

// Process type conversions
for _, targetType := range iter.conversions {
idx := targetType.idx
rawValue := iter.buffer[idx]
if targetType.kind == reflect.Float64 {
switch v := rawValue.(type) {
case *big.Int:
iter.buffer[idx], _ = v.Float64()
}
}
if targetType.kind == reflect.Int64 {
switch v := rawValue.(type) {
case float64:
iter.buffer[idx] = int64(v)
case float32:
iter.buffer[idx] = int64(v)
case *big.Int:
iter.buffer[idx] = v.Int64()
}
}
}

// Prune or fill the values to match the schema
width := len(iter.schema) // the desired width
if width == 0 {
Expand Down
23 changes: 1 addition & 22 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,32 +222,21 @@ func TestQueriesSimple(t *testing.T) {

// auto-generated by dev/extract_queries_to_skip.py
waitForFixQueries := []string{
"SELECT_SUM(i),_i_FROM_mytable_GROUP_BY_i_ORDER_BY_1+SUM(i)_ASC",
"SELECT_SUM(i)_as_sum,_i_FROM_mytable_GROUP_BY_i_ORDER_BY_1+SUM(i)_ASC",
"select_sum(10)_from_mytable",
"_Select_x_from_(select_*_from_xy)_sq1_union_all_select_u_from_(select_*_from_uv)_sq2_limit_1_offset_1;",
"select_count(*)_from_mytable_where_s_in_(1,_'first_row');",
"SELECT_count(*),_i,_concat(i,_i),_123,_'abc',_concat('abc',_'def')_FROM_emptytable;",
"SELECT_pk_FROM_one_pk_WHERE_(pk,_123)_IN_(SELECT_count(*)_AS_u,_123_AS_v_FROM_emptytable);",
"SELECT_pk_FROM_one_pk_WHERE_(pk,_123)_IN_(SELECT_count(*)_AS_u,_123_AS_v_FROM_mytable_WHERE_false);",
"SELECT_pk_FROM_one_pk_WHERE_(pk,_123)_NOT_IN_(SELECT_count(*)_AS_u,_123_AS_v_FROM_emptytable);",
"SELECT_pk_FROM_one_pk_WHERE_(pk,_123)_NOT_IN_(SELECT_count(*)_AS_u,_123_AS_v_FROM_mytable_WHERE_false);",
"SELECT_pk_DIV_2,_SUM(c3)_FROM_one_pk_GROUP_BY_1_ORDER_BY_1",
"SELECT_pk_DIV_2,_SUM(c3)_as_sum_FROM_one_pk_GROUP_BY_1_ORDER_BY_1",
"SELECT_pk_DIV_2,_SUM(c3)_+_sum(c3)_as_sum_FROM_one_pk_GROUP_BY_1_ORDER_BY_1",
"SELECT_pk_DIV_2,_SUM(c3)_+_min(c3)_as_sum_and_min_FROM_one_pk_GROUP_BY_1_ORDER_BY_1",
"SELECT_pk_DIV_2,_SUM(`c3`)_+____min(_c3_)_FROM_one_pk_GROUP_BY_1_ORDER_BY_1",
"SELECT_pk1,_SUM(c1)_FROM_two_pk_GROUP_BY_pk1_ORDER_BY_pk1;",
"SELECT_pk1,_SUM(c1)_FROM_two_pk_WHERE_pk1_=_0",
"SELECT_floor(i),_s_FROM_mytable_mt_ORDER_BY_floor(i)_DESC",
"SELECT_floor(i),_avg(char_length(s))_FROM_mytable_mt_group_by_1_ORDER_BY_floor(i)_DESC",
"SELECT_FORMAT(i,_3)_FROM_mytable;",
"SELECT_FORMAT(i,_3,_'da_DK')_FROM_mytable;",
"SELECT_JSON_OVERLAPS(c3,_'{\"a\":_2,_\"d\":_2}')_FROM_jsontable",
"SELECT_JSON_MERGE(c3,_'{\"a\":_1}')_FROM_jsontable",
"SELECT_JSON_MERGE_PRESERVE(c3,_'{\"a\":_1}')_FROM_jsontable",
"select_json_pretty(c3)_from_jsontable",
"SELECT_i,_sum(i)_FROM_mytable_group_by_1_having_avg(i)_>_1_order_by_1",
"SELECT_a.column_0,_mt.s_from_(values_row(1,\"1\"),_row(2,\"2\"),_row(4,\"4\"))_a____left_join_mytable_mt_on_column_0_=_mt.i____order_by_1",
"WITH_mt_(s,i)_as_(select_char_length(s),_sum(i)_FROM_mytable_group_by_1)_SELECT_s,i_FROM_mt_order_by_1",
"select_i+0.0/(lag(i)_over_(order_by_s))_from_mytable_order_by_1;",
Expand All @@ -256,9 +245,6 @@ func TestQueriesSimple(t *testing.T) {
"WITH_mytable_as_(select_*_FROM_mytable)_SELECT_s,i_FROM_mytable;",
"WITH_mytable_as_(select_*_FROM_mytable_where_i_>_2)_SELECT_*_FROM_mytable;",
"WITH_mytable_as_(select_*_FROM_mytable_where_i_>_2)_SELECT_*_FROM_mytable_union_SELECT_*_from_mytable;",
"____WITH_RECURSIVE_included_parts(sub_part,_part,_quantity)_AS_(_____SELECT_sub_part,_part,_quantity_FROM_parts_WHERE_part_=_'pie'______UNION_ALL_____SELECT_p.sub_part,_p.part,_p.quantity_____FROM_included_parts_AS_pr,_parts_AS_p_____WHERE_p.part_=_pr.sub_part____)____SELECT_sub_part,_sum(quantity)_as_total_quantity____FROM_included_parts____GROUP_BY_sub_part",
"____WITH_RECURSIVE_included_parts(sub_part,_part,_quantity)_AS_(_____SELECT_sub_part,_part,_quantity_FROM_parts_WHERE_lower(part)_=_'pie'______UNION_ALL_____SELECT_p.sub_part,_p.part,_p.quantity_____FROM_included_parts_AS_pr,_parts_AS_p_____WHERE_p.part_=_pr.sub_part____)____SELECT_sub_part,_sum(quantity)_as_total_quantity____FROM_included_parts____GROUP_BY_sub_part",
"____WITH_RECURSIVE_included_parts(sub_part,_part,_quantity)_AS_(_____SELECT_sub_part,_part,_quantity_FROM_parts_WHERE_part_=_(select_part_from_parts_where_part_=_'pie'_and_sub_part_=_'crust')______UNION_ALL_____SELECT_p.sub_part,_p.part,_p.quantity_____FROM_included_parts_AS_pr,_parts_AS_p_____WHERE_p.part_=_pr.sub_part____)____SELECT_sub_part,_sum(quantity)_as_total_quantity____FROM_included_parts____GROUP_BY_sub_part",
"SELECT_i,_1_AS_foo,_2_AS_bar_FROM_MyTable_HAVING_bar_=_2_ORDER_BY_foo,_i;",
"SELECT_i,_1_AS_foo,_2_AS_bar_FROM_MyTable_HAVING_bar_=_1_ORDER_BY_foo,_i;",
"SELECT_reservedWordsTable.AND,_reservedWordsTABLE.Or,_reservedwordstable.SEleCT_FROM_reservedWordsTable;",
Expand Down Expand Up @@ -305,6 +291,7 @@ func TestQueriesSimple(t *testing.T) {
"SELECT_i,v_from_stringandtable_WHERE_v_XOR_NOT_v",
"select_pk,_________row_number()_over_(order_by_pk_desc),_________sum(v1)_over_(partition_by_v2_order_by_pk),_________percent_rank()_over(partition_by_v2_order_by_pk)_____from_one_pk_three_idx_order_by_pk",
"select_pk,____________________percent_rank()_over(partition_by_v2_order_by_pk),____________________dense_rank()_over(partition_by_v2_order_by_pk),____________________rank()_over(partition_by_v2_order_by_pk)_____from_one_pk_three_idx_order_by_pk",
"select_pk,_________first_value(pk)_over_(order_by_pk_desc),_________lag(pk,_1)_over_(order_by_pk_desc),_________count(pk)_over(partition_by_v1_order_by_pk),_________max(pk)_over(partition_by_v1_order_by_pk_desc),_________avg(v2)_over_(partition_by_v1_order_by_pk)_____from_one_pk_three_idx_order_by_pk",
"SELECT_CAST(-3_AS_UNSIGNED)_FROM_mytable",
"SELECT_CONVERT(-3,_UNSIGNED)_FROM_mytable",
"SELECT_s_>_2_FROM_tabletest",
Expand All @@ -323,11 +310,6 @@ func TestQueriesSimple(t *testing.T) {
"select_date_format(datetime_col,_'%D')_from_datetime_table_order_by_1",
"select_time_format(time_col,_'%h%p')_from_datetime_table_order_by_1",
"select_from_unixtime(i)_from_mytable_order_by_1",
"SELECT_SUM(i)_+_1,_i_FROM_mytable_GROUP_BY_i_ORDER_BY_i",
"SELECT_SUM(i)_as_sum,_i_FROM_mytable_GROUP_BY_i_ORDER_BY_sum_ASC",
"SELECT_i,_SUM(i)_FROM_mytable_GROUP_BY_i_ORDER_BY_sum(i)_DESC",
"SELECT_i,_SUM(i)_as_b_FROM_mytable_GROUP_BY_i_ORDER_BY_b_DESC",
"SELECT_i,_SUM(i)_as_`sum(i)`_FROM_mytable_GROUP_BY_i_ORDER_BY_sum(i)_DESC",
"SELECT_CASE_WHEN_i_>_2_THEN_i_WHEN_i_<_2_THEN_i_ELSE_'two'_END_FROM_mytable",
"SELECT_CASE_WHEN_i_>_2_THEN_'more_than_two'_WHEN_i_<_2_THEN_'less_than_two'_ELSE_2_END_FROM_mytable",
"SELECT_substring(mytable.s,_1,_5)_AS_s_FROM_mytable_INNER_JOIN_othertable_ON_(substring(mytable.s,_1,_5)_=_SUBSTRING(othertable.s2,_1,_5))_GROUP_BY_1_HAVING_s_=_\"secon\"",
Expand All @@ -351,9 +333,6 @@ func TestQueriesSimple(t *testing.T) {
"SELECT_pk,_(SELECT_max(pk)_FROM_one_pk_WHERE_pk_<_opk.pk)_AS_x_FROM_one_pk_opk_GROUP_BY_x_ORDER_BY_x",
"SELECT_pk,_(SELECT_max(pk)_FROM_one_pk_WHERE_pk_<_opk.pk)_AS_x_______FROM_one_pk_opk_WHERE_(SELECT_max(pk)_FROM_one_pk_WHERE_pk_<_opk.pk)_>_0_______GROUP_BY_x_ORDER_BY_x",
"SELECT_pk,_(SELECT_max(pk)_FROM_one_pk_WHERE_pk_<_opk.pk)_AS_x_______FROM_one_pk_opk_WHERE_(SELECT_max(pk)_FROM_one_pk_WHERE_pk_<_opk.pk)_>_0_______GROUP_BY_(SELECT_max(pk)_FROM_one_pk_WHERE_pk_<_opk.pk)_ORDER_BY_x",
"SELECT_pk,_______(SELECT_sum(pk1+pk2)_FROM_two_pk_WHERE_pk1+pk2_IN_(SELECT_pk1+pk2_FROM_two_pk_WHERE_pk1+pk2_=_pk))_AS_sum,_______(SELECT_min(pk2)_FROM_two_pk_WHERE_pk2_IN_(SELECT_pk2_FROM_two_pk_WHERE_pk2_=_pk))_AS_equal_______FROM_one_pk_ORDER_BY_pk;",
"SELECT_pk,_______(SELECT_sum(c1)_FROM_two_pk_WHERE_c1_+_3_IN_(SELECT_c4_FROM_two_pk_WHERE_c3_>_opk.c5))_AS_sum,_______(SELECT_sum(c1)_FROM_two_pk_WHERE_pk2_IN_(SELECT_pk2_FROM_two_pk_WHERE_c1_+_1_<_opk.c2))_AS_sum2______FROM_one_pk_opk_ORDER_BY_pk",
"SELECT_DISTINCT_n_FROM_bigtable_ORDER_BY_t",
"SELECT_GREATEST(CAST(i_AS_CHAR),_CAST(b_AS_CHAR))_FROM_niltable_order_by_i",
"SELECT_count(*)_FROM_people_WHERE_last_name='doe'_and_first_name='jane'_order_by_dob",
"SELECT_VALUES(i)_FROM_mytable",
Expand Down

0 comments on commit fe90074

Please sign in to comment.