From 16deddc4096fb013a1a555081c85127121b6018a Mon Sep 17 00:00:00 2001 From: Miguel Molina Date: Thu, 28 Jun 2018 12:43:42 +0200 Subject: [PATCH] vendor: upgrade go-mysql-server Signed-off-by: Miguel Molina --- Gopkg.lock | 4 +- Gopkg.toml | 2 +- docs/using-gitbase/functions.md | 2 +- docs/using-gitbase/supported-syntax.md | 2 +- integration_test.go | 4 +- .../src-d/go-mysql-server.v0/SUPPORTED.md | 1 + .../src-d/go-mysql-server.v0/engine_test.go | 126 +++-- .../go-mysql-server.v0/sql/analyzer/rules.go | 336 +++++++++++- .../sql/analyzer/rules_test.go | 400 ++++++++++++++ .../expression/function/aggregation/sum.go | 82 +++ .../function/aggregation/sum_test.go | 75 +++ .../sql/expression/function/registry.go | 3 + .../src-d/go-mysql-server.v0/sql/index.go | 29 +- .../sql/index/pilosa/driver.go | 126 ++++- .../sql/index/pilosa/driver_test.go | 156 +++++- .../sql/index/pilosa/index.go | 509 +++++++++++++++++- .../sql/index/pilosa/index_test.go | 120 +++++ .../sql/index/pilosa/mapping.go | 181 ++++++- .../sql/index/pilosa/mapping_test.go | 4 +- .../go-mysql-server.v0/sql/index_test.go | 18 +- .../go-mysql-server.v0/sql/parse/parse.go | 14 +- .../sql/parse/parse_test.go | 74 +-- .../sql/plan/create_index.go | 155 +++++- .../sql/plan/create_index_test.go | 85 ++- .../go-mysql-server.v0/sql/plan/drop_index.go | 2 +- .../go-mysql-server.v0/sql/plan/group_by.go | 6 + .../src-d/go-mysql-server.v0/sql/plan/sort.go | 2 +- .../go-mysql-server.v0/test/mem_tracer.go | 48 ++ 28 files changed, 2373 insertions(+), 193 deletions(-) create mode 100644 vendor/gopkg.in/src-d/go-mysql-server.v0/sql/expression/function/aggregation/sum.go create mode 100644 vendor/gopkg.in/src-d/go-mysql-server.v0/sql/expression/function/aggregation/sum_test.go create mode 100644 vendor/gopkg.in/src-d/go-mysql-server.v0/sql/index/pilosa/index_test.go create mode 100644 vendor/gopkg.in/src-d/go-mysql-server.v0/test/mem_tracer.go diff --git a/Gopkg.lock b/Gopkg.lock index ee1bbb5a6..c523638c3 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -449,7 +449,7 @@ "sql/parse", "sql/plan" ] - revision = "0084abf48137b4d72c6f948abfde91a00f3f77f0" + revision = "24a5925748096dd3116b0073833a438ecb0c34f9" [[projects]] name = "gopkg.in/src-d/go-siva.v1" @@ -502,6 +502,6 @@ [solve-meta] analyzer-name = "dep" analyzer-version = 1 - inputs-digest = "4bfb6e5854c5ea69987deb46f79302bc899c3828835a0129ee6e6ba9eb0e017d" + inputs-digest = "57fb82c2b8fab7467f2bafad3aecb9e33c5ab0bfd649ccd94dfce4961a52b29c" solver-name = "gps-cdcl" solver-version = 1 diff --git a/Gopkg.toml b/Gopkg.toml index 63205cef6..9d05b8633 100644 --- a/Gopkg.toml +++ b/Gopkg.toml @@ -1,6 +1,6 @@ [[constraint]] name = "gopkg.in/src-d/go-mysql-server.v0" - revision = "0084abf48137b4d72c6f948abfde91a00f3f77f0" + revision = "24a5925748096dd3116b0073833a438ecb0c34f9" [[constraint]] name = "github.com/jessevdk/go-flags" diff --git a/docs/using-gitbase/functions.md b/docs/using-gitbase/functions.md index 3488c83b2..1657c3b36 100644 --- a/docs/using-gitbase/functions.md +++ b/docs/using-gitbase/functions.md @@ -14,4 +14,4 @@ To make some common tasks easier for the user, there are some functions to inter ## Standard functions -You can check standard functions in [`go-mysql-server` documentation](https://github.com/src-d/go-mysql-server/tree/5620932d8b3ca58edd6bfa4c168073d4c1ff665f#custom-functions). \ No newline at end of file +You can check standard functions in [`go-mysql-server` documentation](https://github.com/src-d/go-mysql-server/tree/24a5925748096dd3116b0073833a438ecb0c34f9#custom-functions). \ No newline at end of file diff --git a/docs/using-gitbase/supported-syntax.md b/docs/using-gitbase/supported-syntax.md index 1b98e6812..fa2baf350 100644 --- a/docs/using-gitbase/supported-syntax.md +++ b/docs/using-gitbase/supported-syntax.md @@ -1,3 +1,3 @@ ## Supported syntax -To see the SQL subset currently supported take a look at [this list](https://github.com/src-d/go-mysql-server/blob/5620932d8b3ca58edd6bfa4c168073d4c1ff665f/SUPPORTED.md) from [src-d/go-mysql-server](https://github.com/src-d/go-mysql-server). \ No newline at end of file +To see the SQL subset currently supported take a look at [this list](https://github.com/src-d/go-mysql-server/blob/24a5925748096dd3116b0073833a438ecb0c34f9/SUPPORTED.md) from [src-d/go-mysql-server](https://github.com/src-d/go-mysql-server). \ No newline at end of file diff --git a/integration_test.go b/integration_test.go index d1024e4a8..76ae1073f 100644 --- a/integration_test.go +++ b/integration_test.go @@ -737,7 +737,7 @@ func createIndex( iter, err := data.table.(sql.Indexable).IndexKeyValueIter(ctx, data.columns) require.NoError(err) - require.NoError(driver.Save(context.Background(), idx, iter)) + require.NoError(driver.Save(sql.NewEmptyContext(), idx, iter)) done <- struct{}{} } @@ -748,7 +748,7 @@ func deleteIndex( data indexData, ) { t.Helper() - done, err := e.Catalog.DeleteIndex("foo", data.id) + done, err := e.Catalog.DeleteIndex("foo", data.id, true) require.NoError(t, err) <-done } diff --git a/vendor/gopkg.in/src-d/go-mysql-server.v0/SUPPORTED.md b/vendor/gopkg.in/src-d/go-mysql-server.v0/SUPPORTED.md index 32208a248..370243bda 100644 --- a/vendor/gopkg.in/src-d/go-mysql-server.v0/SUPPORTED.md +++ b/vendor/gopkg.in/src-d/go-mysql-server.v0/SUPPORTED.md @@ -21,6 +21,7 @@ - COUNT - MAX - MIN +- SUM (always returns DOUBLE) ## Standard expressions - ALIAS (AS) diff --git a/vendor/gopkg.in/src-d/go-mysql-server.v0/engine_test.go b/vendor/gopkg.in/src-d/go-mysql-server.v0/engine_test.go index 180720502..393ac3324 100644 --- a/vendor/gopkg.in/src-d/go-mysql-server.v0/engine_test.go +++ b/vendor/gopkg.in/src-d/go-mysql-server.v0/engine_test.go @@ -15,9 +15,8 @@ import ( "gopkg.in/src-d/go-mysql-server.v0/sql/expression" "gopkg.in/src-d/go-mysql-server.v0/sql/index/pilosa" "gopkg.in/src-d/go-mysql-server.v0/sql/parse" + "gopkg.in/src-d/go-mysql-server.v0/test" - opentracing "github.com/opentracing/opentracing-go" - "github.com/opentracing/opentracing-go/log" "github.com/stretchr/testify/require" ) @@ -242,6 +241,40 @@ var queries = []struct { sql.NewRow([]interface{}{"third row"}), }, }, + { + `SELECT SUM(i) FROM mytable`, + []sql.Row{{float64(6)}}, + }, + { + `SELECT * FROM mytable mt INNER JOIN othertable ot ON mt.i = ot.i2 AND mt.i > 2`, + []sql.Row{ + {int64(3), "third row", "first", int64(3)}, + }, + }, + { + `SELECT i as foo FROM mytable ORDER BY i DESC`, + []sql.Row{ + {int64(3)}, + {int64(2)}, + {int64(1)}, + }, + }, + { + `SELECT COUNT(*) c, i as foo FROM mytable GROUP BY i ORDER BY i DESC`, + []sql.Row{ + {int32(1), int64(3)}, + {int32(1), int64(2)}, + {int32(1), int64(1)}, + }, + }, + { + `SELECT COUNT(*) c, i as foo FROM mytable GROUP BY i ORDER BY foo, i DESC`, + []sql.Row{ + {int32(1), int64(3)}, + {int32(1), int64(2)}, + {int32(1), int64(1)}, + }, + }, } func TestQueries(t *testing.T) { @@ -683,11 +716,11 @@ func TestIndexes(t *testing.T) { iter, err := table.IndexKeyValueIter(sql.NewEmptyContext(), []string{"i"}) require.NoError(err) - require.NoError(driver.Save(context.TODO(), idx, iter)) + require.NoError(driver.Save(sql.NewEmptyContext(), idx, iter)) created <- struct{}{} defer func() { - done, err := e.Catalog.DeleteIndex("foo", "myidx") + done, err := e.Catalog.DeleteIndex("foo", "myidx", true) require.NoError(err) <-done }() @@ -720,7 +753,7 @@ func TestCreateIndex(t *testing.T) { defer func() { time.Sleep(1 * time.Second) - done, err := e.Catalog.DeleteIndex("foo", "myidx") + done, err := e.Catalog.DeleteIndex("foo", "myidx", true) require.NoError(err) <-done @@ -728,11 +761,49 @@ func TestCreateIndex(t *testing.T) { }() } +func TestOrderByGroupBy(t *testing.T) { + require := require.New(t) + + table := mem.NewTable("members", sql.Schema{ + {Name: "id", Type: sql.Int64, Source: "members"}, + {Name: "team", Type: sql.Text, Source: "members"}, + }) + require.NoError(table.Insert(sql.NewRow(int64(3), "red"))) + require.NoError(table.Insert(sql.NewRow(int64(4), "red"))) + require.NoError(table.Insert(sql.NewRow(int64(5), "orange"))) + require.NoError(table.Insert(sql.NewRow(int64(6), "orange"))) + require.NoError(table.Insert(sql.NewRow(int64(7), "orange"))) + require.NoError(table.Insert(sql.NewRow(int64(8), "purple"))) + + db := mem.NewDatabase("db") + db.AddTable(table.Name(), table) + + e := sqle.NewDefault() + e.AddDatabase(db) + + _, iter, err := e.Query( + sql.NewEmptyContext(), + "SELECT team, COUNT(*) FROM members GROUP BY team ORDER BY 2", + ) + require.NoError(err) + + rows, err := sql.RowIterToRows(iter) + require.NoError(err) + + expected := []sql.Row{ + {"purple", int32(1)}, + {"red", int32(2)}, + {"orange", int32(3)}, + } + + require.Equal(expected, rows) +} + func TestTracing(t *testing.T) { require := require.New(t) e := newEngine(t) - tracer := new(memTracer) + tracer := new(test.MemTracer) ctx := sql.NewContext(context.TODO(), sql.WithTracer(tracer)) @@ -747,13 +818,12 @@ func TestTracing(t *testing.T) { require.Len(rows, 1) require.NoError(err) - spans := tracer.spans - + spans := tracer.Spans var expectedSpans = []string{ "plan.Limit", + "plan.Sort", "plan.Distinct", "plan.Project", - "plan.Sort", "plan.Filter", "plan.PushdownProjectionAndFiltersTable", "expression.Equals", @@ -774,41 +844,3 @@ func TestTracing(t *testing.T) { require.Equal(expectedSpans, spanOperations) } - -type memTracer struct { - spans []string -} - -type memSpan struct { - opName string -} - -func (t *memTracer) StartSpan(operationName string, opts ...opentracing.StartSpanOption) opentracing.Span { - t.spans = append(t.spans, operationName) - return &memSpan{operationName} -} - -func (t *memTracer) Inject(sm opentracing.SpanContext, format interface{}, carrier interface{}) error { - panic("not implemented") -} - -func (t *memTracer) Extract(format interface{}, carrier interface{}) (opentracing.SpanContext, error) { - panic("not implemented") -} - -func (m memSpan) Context() opentracing.SpanContext { return m } -func (m memSpan) SetBaggageItem(key, val string) opentracing.Span { return m } -func (m memSpan) BaggageItem(key string) string { return "" } -func (m memSpan) SetTag(key string, value interface{}) opentracing.Span { return m } -func (m memSpan) LogFields(fields ...log.Field) {} -func (m memSpan) LogKV(keyVals ...interface{}) {} -func (m memSpan) Finish() {} -func (m memSpan) FinishWithOptions(opts opentracing.FinishOptions) {} -func (m memSpan) SetOperationName(operationName string) opentracing.Span { - return &memSpan{operationName} -} -func (m memSpan) Tracer() opentracing.Tracer { return &memTracer{} } -func (m memSpan) LogEvent(event string) {} -func (m memSpan) LogEventWithPayload(event string, payload interface{}) {} -func (m memSpan) Log(data opentracing.LogData) {} -func (m memSpan) ForeachBaggageItem(handler func(k, v string) bool) {} diff --git a/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/analyzer/rules.go b/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/analyzer/rules.go index 9adce94e8..8bbdf0f7c 100644 --- a/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/analyzer/rules.go +++ b/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/analyzer/rules.go @@ -20,6 +20,7 @@ var DefaultRules = []Rule{ {"resolve_tables", resolveTables}, {"resolve_natural_joins", resolveNaturalJoins}, {"resolve_orderby_literals", resolveOrderByLiterals}, + {"resolve_orderby", resolveOrderBy}, {"qualify_columns", qualifyColumns}, {"resolve_columns", resolveColumns}, {"resolve_database", resolveDatabase}, @@ -28,6 +29,7 @@ var DefaultRules = []Rule{ {"reorder_projection", reorderProjection}, {"assign_indexes", assignIndexes}, {"pushdown", pushdown}, + {"move_join_conds_to_filter", moveJoinConditionsToFilter}, {"optimize_distinct", optimizeDistinct}, {"erase_projection", eraseProjection}, {"index_catalog", indexCatalog}, @@ -48,8 +50,172 @@ var ( // ErrOrderByColumnIndex is returned when in an order clause there is a // column that is unknown. ErrOrderByColumnIndex = errors.NewKind("unknown column %d in order by clause") + // ErrMisusedAlias is returned when a alias is defined and used in the same projection. + ErrMisusedAlias = errors.NewKind("column %q does not exist in scope, but there is an alias defined in" + + " this projection with that name. Aliases cannot be used in the same projection they're defined in") ) +func resolveOrderBy(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { + span, ctx := ctx.Span("resolve_orderby") + defer span.Finish() + + a.Log("resolving order bys, node of type: %T", n) + return n.TransformUp(func(n sql.Node) (sql.Node, error) { + a.Log("transforming node of type: %T", n) + sort, ok := n.(*plan.Sort) + if !ok { + return n, nil + } + + if !sort.Child.Resolved() { + a.Log("child of type %T is not resolved yet, skipping", sort.Child) + return n, nil + } + + childNewCols := columnsDefinedInNode(sort.Child) + var schemaCols []string + for _, col := range sort.Child.Schema() { + schemaCols = append(schemaCols, col.Name) + } + + var colsFromChild []string + var missingCols []string + for _, f := range sort.SortFields { + n, ok := f.Column.(sql.Nameable) + if !ok { + continue + } + + if stringContains(childNewCols, n.Name()) { + colsFromChild = append(colsFromChild, n.Name()) + } else if !stringContains(schemaCols, n.Name()) { + missingCols = append(missingCols, n.Name()) + } + } + + // If all the columns required by the order by are available, do nothing about it. + if len(missingCols) == 0 { + a.Log("no missing columns, skipping") + return n, nil + } + + // If there are no columns required by the order by available, then move the order by + // below its child. + if len(colsFromChild) == 0 && len(missingCols) > 0 { + a.Log("pushing down sort, missing columns: %s", strings.Join(missingCols, ", ")) + return pushSortDown(sort) + } + + a.Log("fixing sort dependencies, missing columns: %s", strings.Join(missingCols, ", ")) + + // If there are some columns required by the order by on the child but some are missing + // we have to do some more complex logic and split the projection in two. + return fixSortDependencies(sort, missingCols) + }) +} + +// fixSortDependencies replaces the sort node by a node with the child projection +// followed by the sort, an intermediate projection or group by with all the missing +// columns required for the sort and then the child of the child projection or group by. +func fixSortDependencies(sort *plan.Sort, missingCols []string) (sql.Node, error) { + var expressions []sql.Expression + switch child := sort.Child.(type) { + case *plan.Project: + expressions = child.Projections + case *plan.GroupBy: + expressions = child.Aggregate + default: + return nil, errSortPushdown.New(child) + } + + var newExpressions = append([]sql.Expression{}, expressions...) + for _, col := range missingCols { + newExpressions = append(newExpressions, expression.NewUnresolvedColumn(col)) + } + + for i, e := range expressions { + var name string + if n, ok := e.(sql.Nameable); ok { + name = n.Name() + } else { + name = e.String() + } + + var table string + if t, ok := e.(sql.Tableable); ok { + table = t.Table() + } + expressions[i] = expression.NewGetFieldWithTable( + i, e.Type(), table, name, e.IsNullable(), + ) + } + + switch child := sort.Child.(type) { + case *plan.Project: + return plan.NewProject( + expressions, + plan.NewSort( + sort.SortFields, + plan.NewProject(newExpressions, child.Child), + ), + ), nil + case *plan.GroupBy: + return plan.NewProject( + expressions, + plan.NewSort( + sort.SortFields, + plan.NewGroupBy(newExpressions, child.Grouping, child.Child), + ), + ), nil + default: + return nil, errSortPushdown.New(child) + } +} + +// columnsDefinedInNode returns the columns that were defined in this node, +// which, by definition, can only be plan.Project or plan.GroupBy. +func columnsDefinedInNode(n sql.Node) []string { + var exprs []sql.Expression + switch n := n.(type) { + case *plan.Project: + exprs = n.Projections + case *plan.GroupBy: + exprs = n.Aggregate + } + + var cols []string + for _, e := range exprs { + alias, ok := e.(*expression.Alias) + if ok { + cols = append(cols, alias.Name()) + } + } + + return cols +} + +var errSortPushdown = errors.NewKind("unable to push plan.Sort node below %T") + +func pushSortDown(sort *plan.Sort) (sql.Node, error) { + switch child := sort.Child.(type) { + case *plan.Project: + return plan.NewProject( + child.Projections, + plan.NewSort(sort.SortFields, child.Child), + ), nil + case *plan.GroupBy: + return plan.NewGroupBy( + child.Aggregate, + child.Grouping, + plan.NewSort(sort.SortFields, child.Child), + ), nil + default: + // Can't do anything here, there should be either a project or a groupby + // below an order by. + return nil, errSortPushdown.New(child) + } +} + func resolveSubqueries(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { span, ctx := ctx.Span("resolve_subqueries") defer span.Finish() @@ -79,6 +245,11 @@ func resolveOrderByLiterals(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node return n, nil } + // wait for the child to be resolved + if !sort.Child.Resolved() { + return n, nil + } + var fields = make([]plan.SortField, len(sort.SortFields)) for i, f := range sort.SortFields { if lit, ok := f.Column.(*expression.Literal); ok && sql.IsNumber(f.Column.Type()) { @@ -437,18 +608,18 @@ func expandStars(exprs []sql.Expression, schema sql.Schema) ([]sql.Expression, e return expressions, nil } -// maybeAlias is a wrapper on UnresolvedColumn used only to defer the -// resolution of the column because it could be an alias and that -// phase of the analyzer has not run yet. -type maybeAlias struct { +// deferredColumn is a wrapper on UnresolvedColumn used only to defer the +// resolution of the column because it may require some work done by +// other analyzer phases. +type deferredColumn struct { *expression.UnresolvedColumn } -func (e maybeAlias) TransformUp(fn sql.TransformExprFunc) (sql.Expression, error) { +func (e deferredColumn) TransformUp(fn sql.TransformExprFunc) (sql.Expression, error) { return fn(e) } -// column is the common interface that groups UnresolvedColumn and maybeAlias. +// column is the common interface that groups UnresolvedColumn and deferredColumn. type column interface { sql.Nameable sql.Tableable @@ -477,11 +648,29 @@ func resolveColumns(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) } } + var aliasMap = map[string]struct{}{} + var exists = struct{}{} + if project, ok := n.(*plan.Project); ok { + for _, e := range project.Projections { + if alias, ok := e.(*expression.Alias); ok { + aliasMap[alias.Name()] = exists + } + } + } + expressioner, ok := n.(sql.Expressioner) if !ok { return n, nil } + // make sure all children are resolved before resolving a node + for _, c := range n.Children() { + if !c.Resolved() { + a.Log("a children with type %T of node %T were not resolved, skipping", c, n) + return n, nil + } + } + return expressioner.TransformExpressions(func(e sql.Expression) (sql.Expression, error) { a.Log("transforming expression of type: %T", e) if e.Resolved() { @@ -495,14 +684,19 @@ func resolveColumns(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) columns, ok := colMap[uc.Name()] if !ok { - if uc.Table() != "" { - return nil, ErrColumnTableNotFound.New(uc.Table(), uc.Name()) - } - switch uc := uc.(type) { case *expression.UnresolvedColumn: - return &maybeAlias{uc}, nil + a.Log("evaluation of column %q was deferred", uc.Name()) + return &deferredColumn{uc}, nil default: + if uc.Table() != "" { + return nil, ErrColumnTableNotFound.New(uc.Table(), uc.Name()) + } + + if _, ok := aliasMap[uc.Name()]; ok { + return nil, ErrMisusedAlias.New(uc.Name()) + } + return nil, ErrColumnNotFound.New(uc.Name()) } } @@ -524,7 +718,7 @@ func resolveColumns(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) switch uc := uc.(type) { case *expression.UnresolvedColumn: - return &maybeAlias{uc}, nil + return &deferredColumn{uc}, nil default: return nil, ErrColumnNotFound.New(uc.Name()) } @@ -1097,6 +1291,124 @@ func assignIndexes(ctx *sql.Context, a *Analyzer, node sql.Node) (sql.Node, erro }) } +func moveJoinConditionsToFilter(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { + if !n.Resolved() { + a.Log("node is not resolved, skip moving join conditions to filter") + return n, nil + } + + a.Log("moving join conditions to filter, node of type: %T", n) + + return n.TransformUp(func(n sql.Node) (sql.Node, error) { + join, ok := n.(*plan.InnerJoin) + if !ok { + return n, nil + } + + leftSources := nodeSources(join.Left) + rightSources := nodeSources(join.Right) + var leftFilters, rightFilters, condFilters []sql.Expression + for _, e := range splitExpression(join.Cond) { + sources := expressionSources(e) + + canMoveLeft := containsSources(leftSources, sources) + if canMoveLeft { + leftFilters = append(leftFilters, e) + } + + canMoveRight := containsSources(rightSources, sources) + if canMoveRight { + rightFilters = append(rightFilters, e) + } + + if !canMoveLeft && !canMoveRight { + condFilters = append(condFilters, e) + } + } + + var left, right sql.Node = join.Left, join.Right + if len(leftFilters) > 0 { + leftFilters, err := fixFieldIndexes(left.Schema(), expression.JoinAnd(leftFilters...)) + if err != nil { + return nil, err + } + + left = plan.NewFilter(leftFilters, left) + } + + if len(rightFilters) > 0 { + rightFilters, err := fixFieldIndexes(right.Schema(), expression.JoinAnd(rightFilters...)) + if err != nil { + return nil, err + } + + right = plan.NewFilter(rightFilters, right) + } + + if len(condFilters) > 0 { + return plan.NewInnerJoin( + left, right, + expression.JoinAnd(condFilters...), + ), nil + } + + // if there are no cond filters left we can just convert it to a cross join + return plan.NewCrossJoin(left, right), nil + }) +} + +// containsSources checks that all `needle` sources are contained inside `haystack`. +func containsSources(haystack, needle []string) bool { + for _, s := range needle { + var found bool + for _, s2 := range haystack { + if s2 == s { + found = true + break + } + } + + if !found { + return false + } + } + + return true +} + +func nodeSources(node sql.Node) []string { + var sources = make(map[string]struct{}) + var result []string + + for _, col := range node.Schema() { + if _, ok := sources[col.Source]; !ok { + sources[col.Source] = struct{}{} + result = append(result, col.Source) + } + } + + return result +} + +func expressionSources(expr sql.Expression) []string { + var sources = make(map[string]struct{}) + var result []string + + expression.Inspect(expr, func(expr sql.Expression) bool { + f, ok := expr.(*expression.GetField) + if ok { + if _, ok := sources[f.Table()]; !ok { + sources[f.Table()] = struct{}{} + result = append(result, f.Table()) + } + } + + return true + }) + + return result +} + func containsColumns(e sql.Expression) bool { var result bool expression.Inspect(e, func(e sql.Expression) bool { diff --git a/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/analyzer/rules_test.go b/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/analyzer/rules_test.go index d06ec2ad4..57dbde5e0 100644 --- a/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/analyzer/rules_test.go +++ b/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/analyzer/rules_test.go @@ -15,6 +15,234 @@ import ( "gopkg.in/src-d/go-mysql-server.v0/sql/plan" ) +func TestResolveOrderBy(t *testing.T) { + rule := getRule("resolve_orderby") + a := NewDefault(nil) + ctx := sql.NewEmptyContext() + + table := mem.NewTable("foo", sql.Schema{ + {Name: "a", Type: sql.Int64, Source: "foo"}, + {Name: "b", Type: sql.Int64, Source: "foo"}, + }) + + t.Run("with project", func(t *testing.T) { + require := require.New(t) + node := plan.NewSort( + []plan.SortField{ + {Column: expression.NewUnresolvedColumn("x")}, + }, + plan.NewProject( + []sql.Expression{ + expression.NewAlias( + expression.NewGetFieldWithTable(0, sql.Int64, "foo", "a", false), + "x", + ), + }, + table, + ), + ) + + result, err := rule.Apply(ctx, a, node) + require.NoError(err) + + require.Equal(node, result) + + node = plan.NewSort( + []plan.SortField{ + {Column: expression.NewUnresolvedColumn("a")}, + }, + plan.NewProject( + []sql.Expression{ + expression.NewAlias( + expression.NewGetFieldWithTable(0, sql.Int64, "foo", "a", false), + "x", + ), + }, + table, + ), + ) + + expected := plan.NewProject( + []sql.Expression{ + expression.NewAlias( + expression.NewGetFieldWithTable(0, sql.Int64, "foo", "a", false), + "x", + ), + }, + plan.NewSort( + []plan.SortField{ + {Column: expression.NewUnresolvedColumn("a")}, + }, + table, + ), + ) + + result, err = rule.Apply(ctx, a, node) + require.NoError(err) + + require.Equal(expected, result) + + node = plan.NewSort( + []plan.SortField{ + {Column: expression.NewUnresolvedColumn("a")}, + {Column: expression.NewUnresolvedColumn("x")}, + }, + plan.NewProject( + []sql.Expression{ + expression.NewAlias( + expression.NewGetFieldWithTable(0, sql.Int64, "foo", "a", false), + "x", + ), + }, + table, + ), + ) + + expected = plan.NewProject( + []sql.Expression{ + expression.NewGetFieldWithTable(0, sql.Int64, "", "x", false), + }, + plan.NewSort( + []plan.SortField{ + {Column: expression.NewUnresolvedColumn("a")}, + {Column: expression.NewUnresolvedColumn("x")}, + }, + plan.NewProject( + []sql.Expression{ + expression.NewAlias( + expression.NewGetFieldWithTable(0, sql.Int64, "foo", "a", false), + "x", + ), + expression.NewUnresolvedColumn("a"), + }, + table, + ), + ), + ) + + result, err = rule.Apply(ctx, a, node) + require.NoError(err) + + require.Equal(expected, result) + }) + + t.Run("with group by", func(t *testing.T) { + require := require.New(t) + node := plan.NewSort( + []plan.SortField{ + {Column: expression.NewUnresolvedColumn("x")}, + }, + plan.NewGroupBy( + []sql.Expression{ + expression.NewAlias( + expression.NewGetFieldWithTable(0, sql.Int64, "foo", "a", false), + "x", + ), + }, + []sql.Expression{ + expression.NewGetFieldWithTable(0, sql.Int64, "foo", "a", false), + }, + table, + ), + ) + + result, err := rule.Apply(ctx, a, node) + require.NoError(err) + + require.Equal(node, result) + + node = plan.NewSort( + []plan.SortField{ + {Column: expression.NewUnresolvedColumn("a")}, + }, + plan.NewGroupBy( + []sql.Expression{ + expression.NewAlias( + expression.NewGetFieldWithTable(0, sql.Int64, "foo", "a", false), + "x", + ), + }, + []sql.Expression{ + expression.NewGetFieldWithTable(0, sql.Int64, "foo", "a", false), + }, + table, + ), + ) + + var expected sql.Node = plan.NewGroupBy( + []sql.Expression{ + expression.NewAlias( + expression.NewGetFieldWithTable(0, sql.Int64, "foo", "a", false), + "x", + ), + }, + []sql.Expression{ + expression.NewGetFieldWithTable(0, sql.Int64, "foo", "a", false), + }, + plan.NewSort( + []plan.SortField{ + {Column: expression.NewUnresolvedColumn("a")}, + }, + table, + ), + ) + + result, err = rule.Apply(ctx, a, node) + require.NoError(err) + + require.Equal(expected, result) + + node = plan.NewSort( + []plan.SortField{ + {Column: expression.NewUnresolvedColumn("a")}, + {Column: expression.NewUnresolvedColumn("x")}, + }, + plan.NewGroupBy( + []sql.Expression{ + expression.NewAlias( + expression.NewGetFieldWithTable(0, sql.Int64, "foo", "a", false), + "x", + ), + }, + []sql.Expression{ + expression.NewGetFieldWithTable(0, sql.Int64, "foo", "a", false), + }, + table, + ), + ) + + expected = plan.NewProject( + []sql.Expression{ + expression.NewGetFieldWithTable(0, sql.Int64, "", "x", false), + }, + plan.NewSort( + []plan.SortField{ + {Column: expression.NewUnresolvedColumn("a")}, + {Column: expression.NewUnresolvedColumn("x")}, + }, + plan.NewGroupBy( + []sql.Expression{ + expression.NewAlias( + expression.NewGetFieldWithTable(0, sql.Int64, "foo", "a", false), + "x", + ), + expression.NewUnresolvedColumn("a"), + }, + []sql.Expression{ + expression.NewGetFieldWithTable(0, sql.Int64, "foo", "a", false), + }, + table, + ), + ), + ) + + result, err = rule.Apply(ctx, a, node) + require.NoError(err) + + require.Equal(expected, result) + }) +} + func TestResolveSubqueries(t *testing.T) { require := require.New(t) @@ -452,6 +680,32 @@ func TestResolveStar(t *testing.T) { } } +func TestMisusedAlias(t *testing.T) { + require := require.New(t) + f := getRule("resolve_columns") + + table := mem.NewTable("mytable", sql.Schema{{Name: "i", Type: sql.Int32}}) + + node := plan.NewProject( + []sql.Expression{ + expression.NewAlias( + expression.NewUnresolvedColumn("i"), + "alias_i", + ), + expression.NewUnresolvedColumn("alias_i"), + }, + table, + ) + + // the first iteration wrap the unresolved column "alias_i" as a maybeAlias + n, err := f.Apply(sql.NewEmptyContext(), nil, node) + require.NoError(err) + + // if maybeAlias is not resolved it fails + _, err = f.Apply(sql.NewEmptyContext(), nil, n) + require.EqualError(err, ErrMisusedAlias.New("alias_i").Error()) +} + func TestQualifyColumns(t *testing.T) { require := require.New(t) f := getRule("qualify_columns") @@ -1763,6 +2017,152 @@ func TestGetMultiColumnIndexes(t *testing.T) { require.Equal(expectedUsed, used) } +func TestContainsSources(t *testing.T) { + testCases := []struct { + name string + haystack []string + needle []string + expected bool + }{ + { + "needle is in haystack", + []string{"a", "b", "c"}, + []string{"c", "b"}, + true, + }, + { + "needle is not in haystack", + []string{"a", "b", "c"}, + []string{"d", "b"}, + false, + }, + { + "no elements in needle", + []string{"a", "b", "c"}, + nil, + true, + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + require.Equal( + t, + containsSources(tt.haystack, tt.needle), + tt.expected, + ) + }) + } +} + +func TestNodeSources(t *testing.T) { + sources := nodeSources(mem.NewTable("foo", sql.Schema{ + {Source: "foo"}, + {Source: "foo"}, + {Source: "bar"}, + {Source: "baz"}, + })) + + expected := []string{"foo", "bar", "baz"} + require.Equal(t, expected, sources) +} + +func TestExpressionSources(t *testing.T) { + sources := expressionSources(expression.JoinAnd( + col(0, "foo", "bar"), + col(0, "foo", "qux"), + and( + eq( + col(0, "bar", "baz"), + lit(1), + ), + eq( + col(0, "baz", "baz"), + lit(2), + ), + ), + )) + + expected := []string{"foo", "bar", "baz"} + require.Equal(t, expected, sources) +} + +func TestMoveJoinConditionsToFilter(t *testing.T) { + t1 := mem.NewTable("t1", sql.Schema{ + {Name: "a", Source: "t1"}, + {Name: "b", Source: "t1"}, + }) + + t2 := mem.NewTable("t2", sql.Schema{ + {Name: "c", Source: "t2"}, + {Name: "d", Source: "t2"}, + }) + + t3 := mem.NewTable("t3", sql.Schema{ + {Name: "e", Source: "t3"}, + {Name: "f", Source: "t3"}, + }) + + rule := getRule("move_join_conds_to_filter") + require := require.New(t) + + node := plan.NewInnerJoin( + t1, + plan.NewCrossJoin(t2, t3), + expression.JoinAnd( + eq(col(0, "t1", "a"), col(2, "t2", "c")), + eq(col(0, "t1", "a"), col(4, "t3", "e")), + eq(col(2, "t2", "c"), col(4, "t3", "e")), + eq(col(0, "t1", "a"), lit(5)), + ), + ) + + result, err := rule.Apply(sql.NewEmptyContext(), NewDefault(nil), node) + require.NoError(err) + + var expected sql.Node = plan.NewInnerJoin( + plan.NewFilter( + eq(col(0, "t1", "a"), lit(5)), + t1, + ), + plan.NewFilter( + eq(col(0, "t2", "c"), col(2, "t3", "e")), + plan.NewCrossJoin(t2, t3), + ), + and( + eq(col(0, "t1", "a"), col(2, "t2", "c")), + eq(col(0, "t1", "a"), col(4, "t3", "e")), + ), + ) + + require.Equal(expected, result) + + node = plan.NewInnerJoin( + t1, + plan.NewCrossJoin(t2, t3), + expression.JoinAnd( + eq(col(0, "t2", "c"), col(0, "t3", "e")), + eq(col(0, "t1", "a"), lit(5)), + ), + ) + + result, err = rule.Apply(sql.NewEmptyContext(), NewDefault(nil), node) + require.NoError(err) + + expected = plan.NewCrossJoin( + plan.NewFilter( + eq(col(0, "t1", "a"), lit(5)), + t1, + ), + plan.NewFilter( + eq(col(0, "t2", "c"), col(2, "t3", "e")), + plan.NewCrossJoin(t2, t3), + ), + ) + + require.Equal(result, expected) +} + func or(left, right sql.Expression) sql.Expression { return expression.NewOr(left, right) } diff --git a/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/expression/function/aggregation/sum.go b/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/expression/function/aggregation/sum.go new file mode 100644 index 000000000..fc7b6ca1b --- /dev/null +++ b/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/expression/function/aggregation/sum.go @@ -0,0 +1,82 @@ +package aggregation + +import ( + "fmt" + + "gopkg.in/src-d/go-mysql-server.v0/sql" + "gopkg.in/src-d/go-mysql-server.v0/sql/expression" +) + +// Sum agregation returns the sum of all values in the selected column. +// It implements the Aggregation interface. +type Sum struct { + expression.UnaryExpression +} + +// NewSum returns a new Sum node. +func NewSum(e sql.Expression) *Sum { + return &Sum{expression.UnaryExpression{Child: e}} +} + +// Type returns the resultant type of the aggregation. +func (m *Sum) Type() sql.Type { + return sql.Float64 +} + +func (m *Sum) String() string { + return fmt.Sprintf("SUM(%s)", m.Child) +} + +// TransformUp implements the Transformable interface. +func (m *Sum) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { + child, err := m.Child.TransformUp(f) + if err != nil { + return nil, err + } + return f(NewSum(child)) +} + +// NewBuffer creates a new buffer to compute the result. +func (m *Sum) NewBuffer() sql.Row { + return sql.NewRow(nil) +} + +// Update implements the Aggregation interface. +func (m *Sum) Update(ctx *sql.Context, buffer, row sql.Row) error { + v, err := m.Child.Eval(ctx, row) + if err != nil { + return err + } + + if v == nil { + return nil + } + + val, err := sql.Float64.Convert(v) + if err != nil { + val = float64(0) + } + + if buffer[0] == nil { + buffer[0] = float64(0) + } + + buffer[0] = buffer[0].(float64) + val.(float64) + + return nil +} + +// Merge implements the Aggregation interface. +func (m *Sum) Merge(ctx *sql.Context, buffer, partial sql.Row) error { + return m.Update(ctx, buffer, partial) +} + +// Eval implements the Aggregation interface. +func (m *Sum) Eval(ctx *sql.Context, buffer sql.Row) (interface{}, error) { + span, ctx := ctx.Span("aggregation.Sum_Eval") + sum := buffer[0] + span.LogKV("sum", sum) + span.Finish() + + return sum, nil +} diff --git a/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/expression/function/aggregation/sum_test.go b/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/expression/function/aggregation/sum_test.go new file mode 100644 index 000000000..f5d68970b --- /dev/null +++ b/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/expression/function/aggregation/sum_test.go @@ -0,0 +1,75 @@ +package aggregation + +import ( + "testing" + + "github.com/stretchr/testify/require" + "gopkg.in/src-d/go-mysql-server.v0/sql" + "gopkg.in/src-d/go-mysql-server.v0/sql/expression" +) + +func TestSum(t *testing.T) { + sum := NewSum(expression.NewGetField(0, nil, "", false)) + + testCases := []struct { + name string + rows []sql.Row + expected interface{} + }{ + { + "string int values", + []sql.Row{{"1"}, {"2"}, {"3"}, {"4"}}, + float64(10), + }, + { + "string float values", + []sql.Row{{"1.5"}, {"2"}, {"3"}, {"4"}}, + float64(10.5), + }, + { + "string non-int values", + []sql.Row{{"a"}, {"b"}, {"c"}, {"d"}}, + float64(0), + }, + { + "float values", + []sql.Row{{1.}, {2.5}, {3.}, {4.}}, + float64(10.5), + }, + { + "no rows", + []sql.Row{}, + nil, + }, + { + "nil values", + []sql.Row{{nil}, {nil}}, + nil, + }, + { + "int64 values", + []sql.Row{{int64(1)}, {int64(3)}}, + float64(4), + }, + { + "int32 values", + []sql.Row{{int32(1)}, {int32(3)}}, + float64(4), + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + + buf := sum.NewBuffer() + for _, row := range tt.rows { + require.NoError(sum.Update(sql.NewEmptyContext(), buf, row)) + } + + result, err := sum.Eval(sql.NewEmptyContext(), buf) + require.NoError(err) + require.Equal(tt.expected, result) + }) + } +} diff --git a/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/expression/function/registry.go b/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/expression/function/registry.go index 6cf51a210..a18ae0e2a 100644 --- a/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/expression/function/registry.go +++ b/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/expression/function/registry.go @@ -19,6 +19,9 @@ var Defaults = sql.Functions{ "avg": sql.Function1(func(e sql.Expression) sql.Expression { return aggregation.NewAvg(e) }), + "sum": sql.Function1(func(e sql.Expression) sql.Expression { + return aggregation.NewSum(e) + }), "is_binary": sql.Function1(NewIsBinary), "substring": sql.FunctionN(NewSubstring), "year": sql.Function1(NewYear), diff --git a/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/index.go b/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/index.go index b7533d484..9a80f025b 100644 --- a/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/index.go +++ b/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/index.go @@ -2,7 +2,6 @@ package sql import ( "bytes" - "context" "encoding/hex" "io" "strings" @@ -52,27 +51,27 @@ type Index interface { // AscendIndex is an index that is sorted in ascending order. type AscendIndex interface { // AscendGreaterOrEqual returns an IndexLookup for keys that are greater - // or equal to the given key. - AscendGreaterOrEqual(key interface{}) (IndexLookup, error) + // or equal to the given keys. + AscendGreaterOrEqual(keys ...interface{}) (IndexLookup, error) // AscendLessThan returns an IndexLookup for keys that are less than the - // given key. - AscendLessThan(key interface{}) (IndexLookup, error) + // given keys. + AscendLessThan(keys ...interface{}) (IndexLookup, error) // AscendRange returns an IndexLookup for keys that are within the given // range. - AscendRange(greaterOrEqual, lessThan interface{}) (IndexLookup, error) + AscendRange(greaterOrEqual, lessThan []interface{}) (IndexLookup, error) } // DescendIndex is an index that is sorted in descending order. type DescendIndex interface { // DescendGreater returns an IndexLookup for keys that are greater - // than the given key. - DescendGreater(key interface{}) (IndexLookup, error) + // than the given keys. + DescendGreater(keys ...interface{}) (IndexLookup, error) // DescendLessOrEqual returns an IndexLookup for keys that are less than or - // equal to the given key. - DescendLessOrEqual(key interface{}) (IndexLookup, error) + // equal to the given keys. + DescendLessOrEqual(keys ...interface{}) (IndexLookup, error) // DescendRange returns an IndexLookup for keys that are within the given // range. - DescendRange(lessOrEqual, greaterThan interface{}) (IndexLookup, error) + DescendRange(lessOrEqual, greaterThan []interface{}) (IndexLookup, error) } // IndexLookup is a subset of an index. More specific interfaces can be @@ -116,7 +115,7 @@ type IndexDriver interface { // LoadAll loads all indexes for given db and table LoadAll(db, table string) ([]Index, error) // Save the given index - Save(ctx context.Context, index Index, iter IndexKeyValueIter) error + Save(ctx *Context, index Index, iter IndexKeyValueIter) error // Delete the given index. Delete(index Index) error } @@ -492,12 +491,14 @@ func (r *IndexRegistry) AddIndex(idx Index) (chan<- struct{}, error) { // the index for deletion but does not remove it, so queries that are using it // may still do so. The returned channel will send a message when the index can // be deleted from disk. -func (r *IndexRegistry) DeleteIndex(db, id string) (<-chan struct{}, error) { +// If force is true, it will delete the index even if it's not ready for usage. +// Only use that parameter if you know what you're doing. +func (r *IndexRegistry) DeleteIndex(db, id string, force bool) (<-chan struct{}, error) { r.mut.RLock() var key indexKey for k, idx := range r.indexes { if strings.ToLower(id) == idx.ID() { - if !r.CanUseIndex(idx) { + if !force && !r.CanUseIndex(idx) { r.mut.RUnlock() return nil, ErrIndexDeleteInvalidStatus.New(id) } diff --git a/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/index/pilosa/driver.go b/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/index/pilosa/driver.go index a706d8448..df53707e3 100644 --- a/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/index/pilosa/driver.go +++ b/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/index/pilosa/driver.go @@ -1,7 +1,6 @@ package pilosa import ( - "context" "crypto/sha1" "fmt" "io" @@ -23,6 +22,8 @@ const ( IndexNamePrefix = "idx" // FrameNamePrefix the pilosa's frames prefix FrameNamePrefix = "frm" + // BatchSize is the number of objects to save when creating indexes. + BatchSize = 10000 ) // Driver implements sql.IndexDriver interface. @@ -136,10 +137,60 @@ func (d *Driver) loadIndex(path string) (sql.Index, error) { return idx, nil } -var errInvalidIndexType = errors.NewKind("expecting a pilosa index, instead got %T") +var ( + errInvalidIndexType = errors.NewKind("expecting a pilosa index, instead got %T") + errDeletePilosaIndex = errors.NewKind("error deleting pilosa index %s: %s") +) + +type bitBatch struct { + size uint64 + bits []pilosa.Bit + pos uint64 +} + +func newBitBatch(size uint64) *bitBatch { + b := &bitBatch{size: size} + b.Clean() + + return b +} + +func (b *bitBatch) Clean() { + b.bits = make([]pilosa.Bit, 0, b.size) + b.pos = 0 +} + +func (b *bitBatch) Add(row, col uint64) { + b.bits = append(b.bits, pilosa.Bit{ + RowID: row, + ColumnID: col, + }) +} + +func (b *bitBatch) NextRecord() (pilosa.Record, error) { + if b.pos >= uint64(len(b.bits)) { + return nil, io.EOF + } + + b.pos++ + return b.bits[b.pos-1], nil +} + +func (b *bitBatch) Send(frame *pilosa.Frame, client *pilosa.Client) error { + return client.ImportFrame(frame, b) +} // Save the given index (mapping and bitmap) -func (d *Driver) Save(ctx context.Context, i sql.Index, iter sql.IndexKeyValueIter) error { +func (d *Driver) Save( + ctx *sql.Context, + i sql.Index, + iter sql.IndexKeyValueIter, +) (err error) { + span, ctx := ctx.Span("pilosa.Save") + span.LogKV("name", i.ID()) + + defer span.Finish() + idx, ok := i.(*pilosaIndex) if !ok { return errInvalidIndexType.New(i) @@ -150,7 +201,7 @@ func (d *Driver) Save(ctx context.Context, i sql.Index, iter sql.IndexKeyValueIt return err } - if err := index.CreateProcessingFile(path); err != nil { + if err = index.CreateProcessingFile(path); err != nil { return err } @@ -166,6 +217,12 @@ func (d *Driver) Save(ctx context.Context, i sql.Index, iter sql.IndexKeyValueIt return err } + // make sure we delete the index in every run before inserting, since there may + // be previous data + if err = d.client.DeleteIndex(pilosaIndex); err != nil { + return errDeletePilosaIndex.New(pilosaIndex.Name(), err) + } + frames := make([]*pilosa.Frame, len(idx.ExpressionHashes())) for i, e := range idx.ExpressionHashes() { frames[i], err = pilosaIndex.Frame(frameName(e)) @@ -180,10 +237,46 @@ func (d *Driver) Save(ctx context.Context, i sql.Index, iter sql.IndexKeyValueIt return err } - idx.mapping.open() - defer idx.mapping.close() + // Open mapping in create mode. After finishing the transaction is rolled + // back unless all goes well and rollback value is changed. + rollback := true + idx.mapping.openCreate(true) + defer func() { + if rollback { + idx.mapping.rollback() + } else { + e := idx.mapping.commit(false) + if e != nil && err == nil { + err = e + } + } + + idx.mapping.close() + }() + + bitBatch := make([]*bitBatch, len(frames)) + for i := range bitBatch { + bitBatch[i] = newBitBatch(BatchSize) + } for colID := uint64(0); err == nil; colID++ { + // commit each batch of objects (pilosa and boltdb) + if colID%BatchSize == 0 && colID != 0 { + for i, frm := range frames { + err = bitBatch[i].Send(frm, d.client) + if err != nil { + return err + } + + bitBatch[i].Clean() + } + + err = idx.mapping.commit(true) + if err != nil { + return err + } + } + select { case <-ctx.Done(): return ctx.Err() @@ -199,18 +292,16 @@ func (d *Driver) Save(ctx context.Context, i sql.Index, iter sql.IndexKeyValueIt } for i, frm := range frames { - rowID, err := idx.mapping.getRowID(frm.Name(), values[i]) - if err != nil { - return err + if values[i] == nil { + continue } - resp, err := d.client.Query(frm.SetBit(rowID, colID)) + rowID, err := idx.mapping.getRowID(frm.Name(), values[i]) if err != nil { return err } - if !resp.Success { - return errPilosaQuery.New(resp.ErrorMessage) - } + + bitBatch[i].Add(rowID, colID) } err = idx.mapping.putLocation(pilosaIndex.Name(), colID, location) } @@ -220,6 +311,15 @@ func (d *Driver) Save(ctx context.Context, i sql.Index, iter sql.IndexKeyValueIt return err } + rollback = false + + for i, frm := range frames { + err = bitBatch[i].Send(frm, d.client) + if err != nil { + return err + } + } + return index.RemoveProcessingFile(path) } diff --git a/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/index/pilosa/driver_test.go b/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/index/pilosa/driver_test.go index 6c52c49e2..2e1f8d0e6 100644 --- a/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/index/pilosa/driver_test.go +++ b/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/index/pilosa/driver_test.go @@ -17,6 +17,7 @@ import ( "github.com/stretchr/testify/require" "gopkg.in/src-d/go-mysql-server.v0/sql" "gopkg.in/src-d/go-mysql-server.v0/sql/index" + "gopkg.in/src-d/go-mysql-server.v0/test" ) // Pilosa tests require running docker. If `docker ps` command returned an error @@ -103,7 +104,9 @@ func TestSaveAndLoad(t *testing.T) { location: randLocation, } - err = d.Save(context.Background(), sqlIdx, it) + tracer := new(test.MemTracer) + ctx := sql.NewContext(context.Background(), sql.WithTracer(tracer)) + err = d.Save(ctx, sqlIdx, it) require.NoError(err) indexes, err := d.LoadAll(db, table) @@ -137,6 +140,16 @@ func TestSaveAndLoad(t *testing.T) { err = lit.Close() require.NoError(err) } + + found := false + for _, span := range tracer.Spans { + if span == "pilosa.Save" { + found = true + break + } + } + + require.True(found) } func TestSaveAndGetAll(t *testing.T) { @@ -162,7 +175,7 @@ func TestSaveAndGetAll(t *testing.T) { location: randLocation, } - err = d.Save(context.Background(), sqlIdx, it) + err = d.Save(sql.NewEmptyContext(), sqlIdx, it) require.NoError(err) indexes, err := d.LoadAll(db, table) @@ -223,7 +236,7 @@ func TestPilosaHiccup(t *testing.T) { // retry save index - if pilosa failed, reset iterator and start over err = retry(ctx, func() error { - if e := d.Save(ctx, sqlIdx, it); e != nil { + if e := d.Save(sql.NewContext(ctx), sqlIdx, it); e != nil { t.Logf("Save err: %s", e) // reset iterator! it.Close() @@ -310,6 +323,141 @@ func TestLoadAllDirectoryDoesNotExist(t *testing.T) { require.Len(drivers, 0) } +func TestAscendDescendIndex(t *testing.T) { + idx, cleanup := setupAscendDescend(t) + defer cleanup() + + must := func(lookup sql.IndexLookup, err error) sql.IndexLookup { + require.NoError(t, err) + return lookup + } + + testCases := []struct { + name string + lookup sql.IndexLookup + expected []string + }{ + { + "ascend range", + must(idx.AscendRange( + []interface{}{int64(1), int64(1)}, + []interface{}{int64(7), int64(10)}, + )), + []string{"1", "5", "6", "7", "8", "9"}, + }, + { + "ascend greater or equal", + must(idx.AscendGreaterOrEqual(int64(7), int64(6))), + []string{"2", "4"}, + }, + { + "ascend less than", + must(idx.AscendLessThan(int64(5), int64(3))), + []string{"1", "10"}, + }, + { + "descend range", + must(idx.DescendRange( + []interface{}{int64(6), int64(9)}, + []interface{}{int64(0), int64(0)}, + )), + []string{"9", "8", "7", "6", "5", "1"}, + }, + { + "descend less or equal", + must(idx.DescendLessOrEqual(int64(4), int64(2))), + []string{"10", "1"}, + }, + { + "descend greater", + must(idx.DescendGreater(int64(6), int64(5))), + []string{"4", "2"}, + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + iter, err := tt.lookup.Values() + require.NoError(err) + + var result []string + for { + k, err := iter.Next() + if err == io.EOF { + break + } + require.NoError(err) + + result = append(result, string(k)) + } + + require.Equal(tt.expected, result) + }) + } +} + +func setupAscendDescend(t *testing.T) (*pilosaIndex, func()) { + t.Helper() + if !dockerIsRunning { + t.Skipf("Skip test: %s", dockerCmdOutput) + } + require := require.New(t) + + db, table, id := "db_name", "table_name", "index_id" + expressions := makeExpressions("a", "b") + path, err := mkdir(os.TempDir(), "indexes") + require.NoError(err) + + d := NewDriver(path, newClientWithTimeout(200*time.Millisecond)) + sqlIdx, err := d.Create(db, table, id, expressions, nil) + require.NoError(err) + + it := &fixtureKeyValueIter{ + fixtures: []kvfixture{ + {"9", []interface{}{int64(2), int64(6)}}, + {"3", []interface{}{int64(7), int64(5)}}, + {"1", []interface{}{int64(1), int64(2)}}, + {"7", []interface{}{int64(1), int64(3)}}, + {"4", []interface{}{int64(7), int64(6)}}, + {"2", []interface{}{int64(10), int64(6)}}, + {"5", []interface{}{int64(5), int64(1)}}, + {"6", []interface{}{int64(6), int64(2)}}, + {"10", []interface{}{int64(4), int64(0)}}, + {"8", []interface{}{int64(3), int64(5)}}, + }, + } + + err = d.Save(sql.NewEmptyContext(), sqlIdx, it) + require.NoError(err) + + return sqlIdx.(*pilosaIndex), func() { + require.NoError(os.RemoveAll(path)) + } +} + +type kvfixture struct { + key string + values []interface{} +} + +type fixtureKeyValueIter struct { + fixtures []kvfixture + pos int +} + +func (i *fixtureKeyValueIter) Next() ([]interface{}, []byte, error) { + if i.pos >= len(i.fixtures) { + return nil, nil, io.EOF + } + + f := i.fixtures[i.pos] + i.pos++ + return f.values, []byte(f.key), nil +} + +func (i *fixtureKeyValueIter) Close() error { return nil } + // test implementation of sql.IndexKeyValueIter interface type testIndexKeyValueIter struct { offset int @@ -391,7 +539,7 @@ func newClientWithTimeout(timeout time.Duration) *pilosa.Client { func retry(ctx context.Context, fn func() error) error { var ( backoffDuration = 200 * time.Millisecond - maxRetries = 5 + maxRetries = 10 err error ) diff --git a/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/index/pilosa/index.go b/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/index/pilosa/index.go index 340667370..6646496cc 100644 --- a/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/index/pilosa/index.go +++ b/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/index/pilosa/index.go @@ -1,7 +1,11 @@ package pilosa import ( + "bytes" + "encoding/gob" "io" + "strings" + "time" "gopkg.in/src-d/go-errors.v1" @@ -106,6 +110,119 @@ func (idx *pilosaIndex) ExpressionHashes() []sql.ExpressionHash { func (pilosaIndex) Driver() string { return DriverID } +func (idx *pilosaIndex) AscendGreaterOrEqual(keys ...interface{}) (sql.IndexLookup, error) { + if len(keys) != len(idx.expressions) { + return nil, errInvalidKeys.New(len(idx.expressions), idx.ID(), len(keys)) + } + + return &ascendLookup{ + filteredLookup: &filteredLookup{ + indexName: indexName(idx.Database(), idx.Table(), idx.ID()), + mapping: idx.mapping, + client: idx.client, + expressions: idx.expressions, + }, + gte: keys, + lt: nil, + }, nil +} + +func (idx *pilosaIndex) AscendLessThan(keys ...interface{}) (sql.IndexLookup, error) { + if len(keys) != len(idx.expressions) { + return nil, errInvalidKeys.New(len(idx.expressions), idx.ID(), len(keys)) + } + + return &ascendLookup{ + filteredLookup: &filteredLookup{ + indexName: indexName(idx.Database(), idx.Table(), idx.ID()), + mapping: idx.mapping, + client: idx.client, + expressions: idx.expressions, + }, + gte: nil, + lt: keys, + }, nil +} + +func (idx *pilosaIndex) AscendRange(greaterOrEqual, lessThan []interface{}) (sql.IndexLookup, error) { + if len(greaterOrEqual) != len(idx.expressions) { + return nil, errInvalidKeys.New(len(idx.expressions), idx.ID(), len(greaterOrEqual)) + } + + if len(lessThan) != len(idx.expressions) { + return nil, errInvalidKeys.New(len(idx.expressions), idx.ID(), len(lessThan)) + } + + return &ascendLookup{ + filteredLookup: &filteredLookup{ + indexName: indexName(idx.Database(), idx.Table(), idx.ID()), + mapping: idx.mapping, + client: idx.client, + expressions: idx.expressions, + }, + gte: greaterOrEqual, + lt: lessThan, + }, nil +} + +func (idx *pilosaIndex) DescendGreater(keys ...interface{}) (sql.IndexLookup, error) { + if len(keys) != len(idx.expressions) { + return nil, errInvalidKeys.New(len(idx.expressions), idx.ID(), len(keys)) + } + + return &descendLookup{ + filteredLookup: &filteredLookup{ + indexName: indexName(idx.Database(), idx.Table(), idx.ID()), + mapping: idx.mapping, + client: idx.client, + expressions: idx.expressions, + reverse: true, + }, + gt: keys, + lte: nil, + }, nil +} + +func (idx *pilosaIndex) DescendLessOrEqual(keys ...interface{}) (sql.IndexLookup, error) { + if len(keys) != len(idx.expressions) { + return nil, errInvalidKeys.New(len(idx.expressions), idx.ID(), len(keys)) + } + + return &descendLookup{ + filteredLookup: &filteredLookup{ + indexName: indexName(idx.Database(), idx.Table(), idx.ID()), + mapping: idx.mapping, + client: idx.client, + expressions: idx.expressions, + reverse: true, + }, + gt: nil, + lte: keys, + }, nil +} + +func (idx *pilosaIndex) DescendRange(lessOrEqual, greaterThan []interface{}) (sql.IndexLookup, error) { + if len(lessOrEqual) != len(idx.expressions) { + return nil, errInvalidKeys.New(len(idx.expressions), idx.ID(), len(lessOrEqual)) + } + + if len(greaterThan) != len(idx.expressions) { + return nil, errInvalidKeys.New(len(idx.expressions), idx.ID(), len(greaterThan)) + } + + return &descendLookup{ + filteredLookup: &filteredLookup{ + indexName: indexName(idx.Database(), idx.Table(), idx.ID()), + mapping: idx.mapping, + client: idx.client, + expressions: idx.expressions, + reverse: true, + }, + gt: greaterThan, + lte: lessOrEqual, + }, nil +} + type indexLookup struct { keys []interface{} indexName string @@ -135,7 +252,7 @@ func (l *indexLookup) Values() (sql.IndexValueIter, error) { return nil, err } - rowID, err := l.mapping.getRowID(frm.Name(), l.keys[i]) + rowID, err := l.mapping.rowID(frm.Name(), l.keys[i]) if err != nil { return nil, err } @@ -165,6 +282,189 @@ func (l *indexLookup) Values() (sql.IndexValueIter, error) { }, nil } +type filteredLookup struct { + indexName string + mapping *mapping + client *pilosa.Client + expressions []sql.ExpressionHash + reverse bool +} + +func (l *filteredLookup) values(filter func(int, []byte) (bool, error)) (sql.IndexValueIter, error) { + l.mapping.open() + defer l.mapping.close() + + schema, err := l.client.Schema() + if err != nil { + return nil, err + } + + index, err := schema.Index(l.indexName) + if err != nil { + return nil, err + } + + // Compute Intersection of bitmaps + var bitmaps []*pilosa.PQLBitmapQuery + for i := 0; i < len(l.expressions); i++ { + frm, err := index.Frame(frameName(l.expressions[i])) + if err != nil { + return nil, err + } + + rows, err := l.mapping.filter(frm.Name(), func(b []byte) (bool, error) { + return filter(i, b) + }) + + if err != nil { + return nil, err + } + + var bs []*pilosa.PQLBitmapQuery + for _, row := range rows { + bs = append(bs, frm.Bitmap(row)) + } + + bitmaps = append(bitmaps, index.Union(bs...)) + } + + resp, err := l.client.Query(index.Intersect(bitmaps...)) + if err != nil { + return nil, err + } + + if !resp.Success { + return nil, errPilosaQuery.New(resp.ErrorMessage) + } + + if resp.Result() == nil { + return &indexValueIter{mapping: l.mapping, indexName: l.indexName}, nil + } + + bits := resp.Result().Bitmap().Bits + locations, err := l.mapping.sortedLocations(l.indexName, bits, l.reverse) + if err != nil { + return nil, err + } + + return &locationValueIter{locations: locations}, nil +} + +type locationValueIter struct { + locations [][]byte + pos int +} + +func (i *locationValueIter) Next() ([]byte, error) { + if i.pos >= len(i.locations) { + return nil, io.EOF + } + + i.pos++ + return i.locations[i.pos-1], nil +} + +func (i *locationValueIter) Close() error { + i.locations = nil + return nil +} + +type ascendLookup struct { + *filteredLookup + gte []interface{} + lt []interface{} +} + +func (l *ascendLookup) Values() (sql.IndexValueIter, error) { + return l.values(func(i int, value []byte) (bool, error) { + var v interface{} + var err error + if len(l.gte) > 0 { + v, err = decodeGob(value, l.gte[i]) + if err != nil { + return false, err + } + + cmp, err := compare(v, l.gte[i]) + if err != nil { + return false, err + } + + if cmp < 0 { + return false, nil + } + } + + if len(l.lt) > 0 { + if v == nil { + v, err = decodeGob(value, l.lt[i]) + if err != nil { + return false, err + } + } + + cmp, err := compare(v, l.lt[i]) + if err != nil { + return false, err + } + + if cmp >= 0 { + return false, nil + } + } + + return true, nil + }) +} + +type descendLookup struct { + *filteredLookup + gt []interface{} + lte []interface{} +} + +func (l *descendLookup) Values() (sql.IndexValueIter, error) { + return l.values(func(i int, value []byte) (bool, error) { + var v interface{} + var err error + if len(l.gt) > 0 { + v, err = decodeGob(value, l.gt[i]) + if err != nil { + return false, err + } + + cmp, err := compare(v, l.gt[i]) + if err != nil { + return false, err + } + + if cmp <= 0 { + return false, nil + } + } + + if len(l.lte) > 0 { + if v == nil { + v, err = decodeGob(value, l.lte[i]) + if err != nil { + return false, err + } + } + + cmp, err := compare(v, l.lte[i]) + if err != nil { + return false, err + } + + if cmp > 0 { + return false, nil + } + } + + return true, nil + }) +} + type indexValueIter struct { offset uint64 total uint64 @@ -195,3 +495,210 @@ func (it *indexValueIter) Next() ([]byte, error) { } func (it *indexValueIter) Close() error { return it.mapping.close() } + +var ( + errUnknownType = errors.NewKind("unknown type %T received as value") + errTypeMismatch = errors.NewKind("cannot compare type %T with type %T") +) + +func decodeGob(k []byte, value interface{}) (interface{}, error) { + decoder := gob.NewDecoder(bytes.NewBuffer(k)) + + switch value.(type) { + case string: + var v string + err := decoder.Decode(&v) + return v, err + case int32: + var v int32 + err := decoder.Decode(&v) + return v, err + case int64: + var v int64 + err := decoder.Decode(&v) + return v, err + case uint32: + var v uint32 + err := decoder.Decode(&v) + return v, err + case uint64: + var v uint64 + err := decoder.Decode(&v) + return v, err + case float64: + var v float64 + err := decoder.Decode(&v) + return v, err + case time.Time: + var v time.Time + err := decoder.Decode(&v) + return v, err + case []byte: + var v []byte + err := decoder.Decode(&v) + return v, err + case bool: + var v bool + err := decoder.Decode(&v) + return v, err + case []interface{}: + var v []interface{} + err := decoder.Decode(&v) + return v, err + default: + return nil, errUnknownType.New(value) + } +} + +// compare two values of the same underlying type. The values MUST be of the +// same type. +func compare(a, b interface{}) (int, error) { + switch a := a.(type) { + case bool: + v, ok := b.(bool) + if !ok { + return 0, errTypeMismatch.New(a, b) + } + + if a == v { + return 0, nil + } + + if a == false { + return -1, nil + } + + return 1, nil + case string: + v, ok := b.(string) + if !ok { + return 0, errTypeMismatch.New(a, b) + } + + return strings.Compare(a, v), nil + case int32: + v, ok := b.(int32) + if !ok { + return 0, errTypeMismatch.New(a, b) + } + + if a == v { + return 0, nil + } + + if a < v { + return -1, nil + } + + return 1, nil + case int64: + v, ok := b.(int64) + if !ok { + return 0, errTypeMismatch.New(a, b) + } + + if a == v { + return 0, nil + } + + if a < v { + return -1, nil + } + + return 1, nil + case uint32: + v, ok := b.(uint32) + if !ok { + return 0, errTypeMismatch.New(a, b) + } + + if a == v { + return 0, nil + } + + if a < v { + return -1, nil + } + + return 1, nil + case uint64: + v, ok := b.(uint64) + if !ok { + return 0, errTypeMismatch.New(a, b) + } + + if a == v { + return 0, nil + } + + if a < v { + return -1, nil + } + + return 1, nil + case float64: + v, ok := b.(float64) + if !ok { + return 0, errTypeMismatch.New(a, b) + } + + if a == v { + return 0, nil + } + + if a < v { + return -1, nil + } + + return 1, nil + case []byte: + v, ok := b.([]byte) + if !ok { + return 0, errTypeMismatch.New(a, b) + } + return bytes.Compare(a, v), nil + case []interface{}: + v, ok := b.([]interface{}) + if !ok { + return 0, errTypeMismatch.New(a, b) + } + + if len(a) < len(v) { + return -1, nil + } + + if len(a) > len(v) { + return 1, nil + } + + for i := range a { + cmp, err := compare(a[i], v[i]) + if err != nil { + return 0, err + } + + if cmp != 0 { + return cmp, nil + } + } + + return 0, nil + case time.Time: + v, ok := b.(time.Time) + if !ok { + return 0, errTypeMismatch.New(a, b) + } + + if a.Equal(v) { + return 0, nil + } + + if a.Before(v) { + return -1, nil + } + + return 1, nil + default: + return 0, errUnknownType.New(a) + } +} diff --git a/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/index/pilosa/index_test.go b/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/index/pilosa/index_test.go new file mode 100644 index 000000000..e4ea7364d --- /dev/null +++ b/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/index/pilosa/index_test.go @@ -0,0 +1,120 @@ +package pilosa + +import ( + "bytes" + "encoding/gob" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/require" + errors "gopkg.in/src-d/go-errors.v1" +) + +func TestCompare(t *testing.T) { + now := time.Now() + testCases := []struct { + a, b interface{} + err *errors.Kind + expected int + }{ + {true, true, nil, 0}, + {false, true, nil, -1}, + {true, false, nil, 1}, + {false, false, nil, 0}, + {true, 0, errTypeMismatch, 0}, + + {"a", "b", nil, -1}, + {"b", "a", nil, 1}, + {"a", "a", nil, 0}, + {"a", 1, errTypeMismatch, 0}, + + {int32(1), int32(2), nil, -1}, + {int32(2), int32(1), nil, 1}, + {int32(2), int32(2), nil, 0}, + {int32(1), "", errTypeMismatch, 0}, + + {int64(1), int64(2), nil, -1}, + {int64(2), int64(1), nil, 1}, + {int64(2), int64(2), nil, 0}, + {int64(1), "", errTypeMismatch, 0}, + + {uint32(1), uint32(2), nil, -1}, + {uint32(2), uint32(1), nil, 1}, + {uint32(2), uint32(2), nil, 0}, + {uint32(1), "", errTypeMismatch, 0}, + + {uint64(1), uint64(2), nil, -1}, + {uint64(2), uint64(1), nil, 1}, + {uint64(2), uint64(2), nil, 0}, + {uint64(1), "", errTypeMismatch, 0}, + + {float64(1), float64(2), nil, -1}, + {float64(2), float64(1), nil, 1}, + {float64(2), float64(2), nil, 0}, + {float64(1), "", errTypeMismatch, 0}, + + {now.Add(-1 * time.Hour), now, nil, -1}, + {now, now.Add(-1 * time.Hour), nil, 1}, + {now, now, nil, 0}, + {now, 1, errTypeMismatch, -1}, + + {[]interface{}{"a", "a"}, []interface{}{"a", "b"}, nil, -1}, + {[]interface{}{"a", "b"}, []interface{}{"a", "a"}, nil, 1}, + {[]interface{}{"a", "a"}, []interface{}{"a", "a"}, nil, 0}, + {[]interface{}{"b"}, []interface{}{"a", "b"}, nil, -1}, + {[]interface{}{"b"}, 1, errTypeMismatch, -1}, + + {[]byte{0, 1}, []byte{1, 1}, nil, -1}, + {[]byte{1, 1}, []byte{0, 1}, nil, 1}, + {[]byte{1, 1}, []byte{1, 1}, nil, 0}, + {[]byte{1}, []byte{0, 1}, nil, 1}, + {[]byte{0, 1}, 1, errTypeMismatch, -1}, + + {time.Duration(0), nil, errUnknownType, -1}, + } + + for _, tt := range testCases { + name := fmt.Sprintf("(%T)(%v) and (%T)(%v)", tt.a, tt.a, tt.b, tt.b) + t.Run(name, func(t *testing.T) { + require := require.New(t) + cmp, err := compare(tt.a, tt.b) + if tt.err != nil { + require.Error(err) + require.True(tt.err.Is(err)) + } else { + require.NoError(err) + require.Equal(tt.expected, cmp) + } + }) + } +} + +func TestDecodeGob(t *testing.T) { + testCases := []interface{}{ + "foo", + int32(1), + int64(1), + uint32(1), + uint64(1), + float64(1), + true, + time.Date(2018, time.August, 1, 1, 1, 1, 1, time.Local), + []byte("foo"), + []interface{}{1, 3, 3, 7}, + } + + for _, tt := range testCases { + name := fmt.Sprintf("(%T)(%v)", tt, tt) + t.Run(name, func(t *testing.T) { + require := require.New(t) + + var buf bytes.Buffer + require.NoError(gob.NewEncoder(&buf).Encode(tt)) + + result, err := decodeGob(buf.Bytes(), tt) + require.NoError(err) + require.Equal(tt, result) + }) + } +} diff --git a/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/index/pilosa/mapping.go b/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/index/pilosa/mapping.go index dd640d182..9f3bb46c2 100644 --- a/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/index/pilosa/mapping.go +++ b/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/index/pilosa/mapping.go @@ -6,6 +6,7 @@ import ( "encoding/gob" "fmt" "path/filepath" + "sort" "sync" "github.com/boltdb/bolt" @@ -25,6 +26,11 @@ type mapping struct { mut sync.RWMutex db *bolt.DB + // in create mode there's only one transaction closed explicitly by + // commit function + create bool + tx *bolt.Tx + clientMut sync.Mutex clients int } @@ -34,9 +40,15 @@ func newMapping(dir string) *mapping { } func (m *mapping) open() { + m.openCreate(false) +} + +// openCreate opens and sets creation mode in the database. +func (m *mapping) openCreate(create bool) { m.clientMut.Lock() defer m.clientMut.Unlock() m.clients++ + m.create = create } func (m *mapping) close() error { @@ -80,6 +92,87 @@ func (m *mapping) query(fn func() error) error { return fn() } +func (m *mapping) rowID(frameName string, value interface{}) (uint64, error) { + val, err := m.get(frameName, value) + if err != nil { + return 0, err + } + if val == nil { + return 0, fmt.Errorf("id is nil") + } + + return binary.LittleEndian.Uint64(val), err +} + +// commit saves current transaction, if cont is true a new transaction will be +// created again in the next query. Only for create mode. +func (m *mapping) commit(cont bool) error { + m.clientMut.Lock() + defer m.clientMut.Unlock() + + var err error + if m.create && m.tx != nil { + err = m.tx.Commit() + } + + m.create = cont + m.tx = nil + + return err +} + +func (m *mapping) rollback() error { + m.clientMut.Lock() + defer m.clientMut.Unlock() + + var err error + if m.create && m.tx != nil { + err = m.tx.Rollback() + } + + m.create = false + m.tx = nil + + return err +} + +func (m *mapping) transaction(writable bool, f func(*bolt.Tx) error) error { + var tx *bolt.Tx + var err error + if m.create { + m.clientMut.Lock() + if m.tx == nil { + m.tx, err = m.db.Begin(true) + if err != nil { + m.clientMut.Unlock() + return err + } + } + + m.clientMut.Unlock() + + tx = m.tx + } else { + tx, err = m.db.Begin(writable) + if err != nil { + return err + } + } + + err = f(tx) + + if m.create { + return err + } + + if err != nil { + tx.Rollback() + return err + } + + return tx.Commit() +} + func (m *mapping) getRowID(frameName string, value interface{}) (uint64, error) { var id uint64 err := m.query(func() error { @@ -90,7 +183,7 @@ func (m *mapping) getRowID(frameName string, value interface{}) (uint64, error) return err } - err = m.db.Update(func(tx *bolt.Tx) error { + err = m.transaction(true, func(tx *bolt.Tx) error { b, err := tx.CreateBucketIfNotExists([]byte(frameName)) if err != nil { return err @@ -103,7 +196,10 @@ func (m *mapping) getRowID(frameName string, value interface{}) (uint64, error) return nil } - id = uint64(b.Stats().KeyN) + // the first NextSequence is 1 so the first id will be 1 + // this can only fail if the transaction is closed + id, _ = b.NextSequence() + val = make([]byte, 8) binary.LittleEndian.PutUint64(val, id) err = b.Put(key, val) @@ -118,7 +214,7 @@ func (m *mapping) getRowID(frameName string, value interface{}) (uint64, error) func (m *mapping) putLocation(indexName string, colID uint64, location []byte) error { return m.query(func() error { - return m.db.Update(func(tx *bolt.Tx) error { + return m.transaction(true, func(tx *bolt.Tx) error { b, err := tx.CreateBucketIfNotExists([]byte(indexName)) if err != nil { return err @@ -132,14 +228,57 @@ func (m *mapping) putLocation(indexName string, colID uint64, location []byte) e }) } +func (m *mapping) sortedLocations(indexName string, cols []uint64, reverse bool) ([][]byte, error) { + var result [][]byte + err := m.query(func() error { + return m.db.View(func(tx *bolt.Tx) error { + b := tx.Bucket([]byte(indexName)) + if b == nil { + return fmt.Errorf("bucket %s not found", indexName) + } + + for _, col := range cols { + key := make([]byte, 8) + binary.LittleEndian.PutUint64(key, col) + val := b.Get(key) + + // val will point to mmap addresses, so we need to copy the slice + dst := make([]byte, len(val)) + copy(dst, val) + result = append(result, dst) + } + + return nil + }) + }) + + if err != nil { + return nil, err + } + + if reverse { + sort.Stable(sort.Reverse(byBytes(result))) + } else { + sort.Stable(byBytes(result)) + } + + return result, nil +} + +type byBytes [][]byte + +func (b byBytes) Len() int { return len(b) } +func (b byBytes) Swap(i, j int) { b[i], b[j] = b[j], b[i] } +func (b byBytes) Less(i, j int) bool { return bytes.Compare(b[i], b[j]) < 0 } + func (m *mapping) getLocation(indexName string, colID uint64) ([]byte, error) { var location []byte err := m.query(func() error { - err := m.db.View(func(tx *bolt.Tx) error { + err := m.transaction(true, func(tx *bolt.Tx) error { b := tx.Bucket([]byte(indexName)) if b == nil { - return fmt.Errorf("Bucket %s not found", indexName) + return fmt.Errorf("bucket %s not found", indexName) } key := make([]byte, 8) @@ -158,7 +297,7 @@ func (m *mapping) getLocation(indexName string, colID uint64) ([]byte, error) { func (m *mapping) getLocationN(indexName string) (int, error) { var n int err := m.query(func() error { - err := m.db.View(func(tx *bolt.Tx) error { + err := m.transaction(false, func(tx *bolt.Tx) error { b := tx.Bucket([]byte(indexName)) if b == nil { return fmt.Errorf("Bucket %s not found", indexName) @@ -184,7 +323,7 @@ func (m *mapping) get(name string, key interface{}) ([]byte, error) { return err } - err = m.db.View(func(tx *bolt.Tx) error { + err = m.transaction(true, func(tx *bolt.Tx) error { b := tx.Bucket([]byte(name)) if b != nil { value = b.Get(buf.Bytes()) @@ -198,3 +337,31 @@ func (m *mapping) get(name string, key interface{}) ([]byte, error) { }) return value, err } + +func (m *mapping) filter(name string, fn func([]byte) (bool, error)) ([]uint64, error) { + var result []uint64 + + err := m.query(func() error { + return m.db.View(func(tx *bolt.Tx) error { + b := tx.Bucket([]byte(name)) + if b == nil { + return nil + } + + return b.ForEach(func(k, v []byte) error { + ok, err := fn(k) + if err != nil { + return err + } + + if ok { + result = append(result, binary.LittleEndian.Uint64(v)) + } + + return nil + }) + }) + }) + + return result, err +} diff --git a/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/index/pilosa/mapping_test.go b/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/index/pilosa/mapping_test.go index 27a5a3e7a..6d07da5d5 100644 --- a/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/index/pilosa/mapping_test.go +++ b/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/index/pilosa/mapping_test.go @@ -20,7 +20,7 @@ func TestRowID(t *testing.T) { defer m.close() cases := []int{0, 1, 2, 3, 4, 5, 5, 0, 3, 2, 1, 5} - expected := []uint64{0, 1, 2, 3, 4, 5, 5, 0, 3, 2, 1, 5} + expected := []uint64{1, 2, 3, 4, 5, 6, 6, 1, 4, 3, 2, 6} for i, c := range cases { rowID, err := m.getRowID("frame name", c) @@ -72,7 +72,7 @@ func TestGet(t *testing.T) { defer m.close() cases := []int{0, 1, 2, 3, 4, 5, 5, 0, 3, 2, 1, 5} - expected := []uint64{0, 1, 2, 3, 4, 5, 5, 0, 3, 2, 1, 5} + expected := []uint64{1, 2, 3, 4, 5, 6, 6, 1, 4, 3, 2, 6} for i, c := range cases { m.getRowID("frame name", c) diff --git a/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/index_test.go b/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/index_test.go index e1c2ad08b..9dd22b5da 100644 --- a/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/index_test.go +++ b/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/index_test.go @@ -1,7 +1,6 @@ package sql import ( - "context" "crypto/sha1" "fmt" "testing" @@ -85,15 +84,20 @@ func TestDeleteIndex(t *testing.T) { r := NewIndexRegistry() idx := &dummyIdx{"foo", nil, "foo", "foo"} + idx2 := &dummyIdx{"foo", nil, "foo", "bar"} r.indexes[indexKey{"foo", "foo"}] = idx + r.indexes[indexKey{"foo", "bar"}] = idx2 - _, err := r.DeleteIndex("foo", "foo") + _, err := r.DeleteIndex("foo", "foo", false) require.Error(err) require.True(ErrIndexDeleteInvalidStatus.Is(err)) - r.setStatus(idx, IndexReady) + _, err = r.DeleteIndex("foo", "foo", true) + require.NoError(err) + + r.setStatus(idx2, IndexReady) - _, err = r.DeleteIndex("foo", "foo") + _, err = r.DeleteIndex("foo", "foo", false) require.NoError(err) require.Len(r.indexes, 0) @@ -109,7 +113,7 @@ func TestDeleteIndex_InUse(t *testing.T) { r.setStatus(idx, IndexReady) r.retainIndex("foo", "foo") - done, err := r.DeleteIndex("foo", "foo") + done, err := r.DeleteIndex("foo", "foo", false) require.NoError(err) require.Len(r.indexes, 1) @@ -270,8 +274,8 @@ func (d loadDriver) LoadAll(db, table string) ([]Index, error) { } return result, nil } -func (loadDriver) Save(ctx context.Context, index Index, iter IndexKeyValueIter) error { return nil } -func (loadDriver) Delete(index Index) error { return nil } +func (loadDriver) Save(ctx *Context, index Index, iter IndexKeyValueIter) error { return nil } +func (loadDriver) Delete(index Index) error { return nil } type dummyIdx struct { id string diff --git a/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/parse/parse.go b/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/parse/parse.go index 1436547ac..f6cbb5e1a 100644 --- a/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/parse/parse.go +++ b/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/parse/parse.go @@ -114,13 +114,6 @@ func convertSelect(ctx *sql.Context, s *sqlparser.Select) (sql.Node, error) { } } - if len(s.OrderBy) != 0 { - node, err = orderByToSort(s.OrderBy, node) - if err != nil { - return nil, err - } - } - node, err = selectToProjectOrGroupBy(s.SelectExprs, s.GroupBy, node) if err != nil { return nil, err @@ -130,6 +123,13 @@ func convertSelect(ctx *sql.Context, s *sqlparser.Select) (sql.Node, error) { node = plan.NewDistinct(node) } + if len(s.OrderBy) != 0 { + node, err = orderByToSort(s.OrderBy, node) + if err != nil { + return nil, err + } + } + if s.Limit != nil { node, err = limitToLimit(ctx, s.Limit.Rowcount, node) if err != nil { diff --git a/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/parse/parse_test.go b/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/parse/parse_test.go index 962fa2856..54e2b62ec 100644 --- a/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/parse/parse_test.go +++ b/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/parse/parse_test.go @@ -113,13 +113,13 @@ var fixtures = map[string]sql.Node{ plan.NewUnresolvedTable("foo"), ), ), - `SELECT foo, bar FROM foo ORDER BY baz DESC;`: plan.NewProject( - []sql.Expression{ - expression.NewUnresolvedColumn("foo"), - expression.NewUnresolvedColumn("bar"), - }, - plan.NewSort( - []plan.SortField{{Column: expression.NewUnresolvedColumn("baz"), Order: plan.Descending, NullOrdering: plan.NullsFirst}}, + `SELECT foo, bar FROM foo ORDER BY baz DESC;`: plan.NewSort( + []plan.SortField{{Column: expression.NewUnresolvedColumn("baz"), Order: plan.Descending, NullOrdering: plan.NullsFirst}}, + plan.NewProject( + []sql.Expression{ + expression.NewUnresolvedColumn("foo"), + expression.NewUnresolvedColumn("bar"), + }, plan.NewUnresolvedTable("foo"), ), ), @@ -139,25 +139,25 @@ var fixtures = map[string]sql.Node{ ), ), `SELECT foo, bar FROM foo ORDER BY baz DESC LIMIT 1;`: plan.NewLimit(1, - plan.NewProject( - []sql.Expression{ - expression.NewUnresolvedColumn("foo"), - expression.NewUnresolvedColumn("bar"), - }, - plan.NewSort( - []plan.SortField{{Column: expression.NewUnresolvedColumn("baz"), Order: plan.Descending, NullOrdering: plan.NullsFirst}}, + plan.NewSort( + []plan.SortField{{Column: expression.NewUnresolvedColumn("baz"), Order: plan.Descending, NullOrdering: plan.NullsFirst}}, + plan.NewProject( + []sql.Expression{ + expression.NewUnresolvedColumn("foo"), + expression.NewUnresolvedColumn("bar"), + }, plan.NewUnresolvedTable("foo"), ), ), ), `SELECT foo, bar FROM foo WHERE qux = 1 ORDER BY baz DESC LIMIT 1;`: plan.NewLimit(1, - plan.NewProject( - []sql.Expression{ - expression.NewUnresolvedColumn("foo"), - expression.NewUnresolvedColumn("bar"), - }, - plan.NewSort( - []plan.SortField{{Column: expression.NewUnresolvedColumn("baz"), Order: plan.Descending, NullOrdering: plan.NullsFirst}}, + plan.NewSort( + []plan.SortField{{Column: expression.NewUnresolvedColumn("baz"), Order: plan.Descending, NullOrdering: plan.NullsFirst}}, + plan.NewProject( + []sql.Expression{ + expression.NewUnresolvedColumn("foo"), + expression.NewUnresolvedColumn("bar"), + }, plan.NewFilter( expression.NewEquals( expression.NewUnresolvedColumn("qux"), @@ -485,23 +485,23 @@ var fixtures = map[string]sql.Node{ plan.NewUnresolvedTable("foo"), ), ), - `SELECT a, b FROM t ORDER BY 2, 1`: plan.NewProject( - []sql.Expression{ - expression.NewUnresolvedColumn("a"), - expression.NewUnresolvedColumn("b"), + `SELECT a, b FROM t ORDER BY 2, 1`: plan.NewSort( + []plan.SortField{ + { + Column: expression.NewLiteral(int64(2), sql.Int64), + Order: plan.Ascending, + NullOrdering: plan.NullsFirst, + }, + { + Column: expression.NewLiteral(int64(1), sql.Int64), + Order: plan.Ascending, + NullOrdering: plan.NullsFirst, + }, }, - plan.NewSort( - []plan.SortField{ - { - Column: expression.NewLiteral(int64(2), sql.Int64), - Order: plan.Ascending, - NullOrdering: plan.NullsFirst, - }, - { - Column: expression.NewLiteral(int64(1), sql.Int64), - Order: plan.Ascending, - NullOrdering: plan.NullsFirst, - }, + plan.NewProject( + []sql.Expression{ + expression.NewUnresolvedColumn("a"), + expression.NewUnresolvedColumn("b"), }, plan.NewUnresolvedTable("t"), ), diff --git a/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/plan/create_index.go b/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/plan/create_index.go index 25085180b..217e6d859 100644 --- a/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/plan/create_index.go +++ b/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/plan/create_index.go @@ -3,7 +3,10 @@ package plan import ( "fmt" "strings" + "time" + opentracing "github.com/opentracing/opentracing-go" + otlog "github.com/opentracing/opentracing-go/log" "github.com/sirupsen/logrus" errors "gopkg.in/src-d/go-errors.v1" "gopkg.in/src-d/go-mysql-server.v0/sql" @@ -90,7 +93,7 @@ func (c *CreateIndex) RowIter(ctx *sql.Context) (sql.RowIter, error) { return nil, ErrInvalidIndexDriver.New(c.Driver) } - columns, exprs, err := getColumnsAndPrepareExpressions(c.Exprs) + columns, exprs, exprHashes, err := getColumnsAndPrepareExpressions(c.Exprs) if err != nil { return nil, err } @@ -99,14 +102,14 @@ func (c *CreateIndex) RowIter(ctx *sql.Context) (sql.RowIter, error) { c.CurrentDatabase, nameable.Name(), c.Name, - exprs, + exprHashes, c.Config, ) if err != nil { return nil, err } - iter, err := table.IndexKeyValueIter(ctx, columns) + iter, err := getIndexKeyValueIter(ctx, table, columns, exprs) if err != nil { return nil, err } @@ -121,27 +124,57 @@ func (c *CreateIndex) RowIter(ctx *sql.Context) (sql.RowIter, error) { "driver": index.Driver(), }) - go func() { - err := driver.Save(ctx, index, &loggingKeyValueIter{log: log, iter: iter}) - close(done) - if err != nil { - logrus.WithField("err", err).Error("unable to save the index") - deleted, err := c.Catalog.DeleteIndex(index.Database(), index.ID()) - if err != nil { - logrus.WithField("err", err).Error("unable to delete the index") - } else { - <-deleted - } - } else { - log.Info("index successfully created") - } - }() + go c.backgroundIndexCreate(ctx, log, driver, index, iter, done) log.Info("starting to save the index") return sql.RowsToRowIter(), nil } +func (c *CreateIndex) backgroundIndexCreate( + ctx *sql.Context, + log *logrus.Entry, + driver sql.IndexDriver, + index sql.Index, + iter sql.IndexKeyValueIter, + done chan<- struct{}, +) { + span, ctx := ctx.Span("plan.backgroundIndexCreate") + span.LogKV( + "index", index.ID(), + "table", index.Table(), + "driver", index.Driver(), + ) + + err := driver.Save(ctx, index, newLoggingKeyValueIter(span, log, iter)) + close(done) + + if err != nil { + span.FinishWithOptions(opentracing.FinishOptions{ + LogRecords: []opentracing.LogRecord{ + { + Timestamp: time.Now(), + Fields: []otlog.Field{ + otlog.String("error", err.Error()), + }, + }, + }, + }) + + logrus.WithField("err", err).Error("unable to save the index") + + deleted, err := c.Catalog.DeleteIndex(index.Database(), index.ID(), true) + if err != nil { + logrus.WithField("err", err).Error("unable to delete the index") + } else { + <-deleted + } + } else { + span.Finish() + log.Info("index successfully created") + } +} + // Schema implements the Node interface. func (c *CreateIndex) Schema() sql.Schema { return nil } @@ -223,10 +256,11 @@ func (c *CreateIndex) TransformUp(fn sql.TransformNodeFunc) (sql.Node, error) { // to match a row with only the returned columns in that same order. func getColumnsAndPrepareExpressions( exprs []sql.Expression, -) ([]string, []sql.ExpressionHash, error) { +) ([]string, []sql.Expression, []sql.ExpressionHash, error) { var columns []string var seen = make(map[string]int) - var expressions = make([]sql.ExpressionHash, len(exprs)) + var expressions = make([]sql.Expression, len(exprs)) + var expressionHashes = make([]sql.ExpressionHash, len(exprs)) for i, e := range exprs { ex, err := e.TransformUp(func(e sql.Expression) (sql.Expression, error) { @@ -253,29 +287,96 @@ func getColumnsAndPrepareExpressions( ), nil }) + if err != nil { + return nil, nil, nil, err + } + + expressions[i] = ex + expressionHashes[i] = sql.NewExpressionHash(ex) + } + + return columns, expressions, expressionHashes, nil +} + +type evalKeyValueIter struct { + ctx *sql.Context + iter sql.IndexKeyValueIter + exprs []sql.Expression +} + +func (eit *evalKeyValueIter) Next() ([]interface{}, []byte, error) { + vals, loc, err := eit.iter.Next() + if err != nil { + return nil, nil, err + } + row := sql.NewRow(vals...) + evals := make([]interface{}, len(eit.exprs)) + for i, ex := range eit.exprs { + eval, err := ex.Eval(eit.ctx, row) if err != nil { return nil, nil, err } - expressions[i] = sql.NewExpressionHash(ex) + evals[i] = eval } - return columns, expressions, nil + return evals, loc, nil +} + +func (eit *evalKeyValueIter) Close() error { + return eit.iter.Close() +} + +func getIndexKeyValueIter(ctx *sql.Context, table sql.Indexable, columns []string, exprs []sql.Expression) (*evalKeyValueIter, error) { + iter, err := table.IndexKeyValueIter(ctx, columns) + if err != nil { + return nil, err + } + + return &evalKeyValueIter{ctx, iter, exprs}, nil } type loggingKeyValueIter struct { - log *logrus.Entry - iter sql.IndexKeyValueIter - rows uint64 + span opentracing.Span + log *logrus.Entry + iter sql.IndexKeyValueIter + rows uint64 + start time.Time +} + +func newLoggingKeyValueIter( + span opentracing.Span, + log *logrus.Entry, + iter sql.IndexKeyValueIter, +) sql.IndexKeyValueIter { + return &loggingKeyValueIter{ + span: span, + log: log, + iter: iter, + start: time.Now(), + } } func (i *loggingKeyValueIter) Next() ([]interface{}, []byte, error) { i.rows++ if i.rows%100 == 0 { - i.log.Debugf("still creating index: %d rows saved so far", i.rows) + duration := time.Since(i.start) + + i.log.WithField("duration", duration). + Debugf("still creating index: %d rows saved so far", i.rows) + + i.span.LogFields( + otlog.String("event", "saved rows"), + otlog.Uint64("rows", i.rows), + otlog.String("duration", duration.String()), + ) + + i.start = time.Now() } return i.iter.Next() } -func (i *loggingKeyValueIter) Close() error { return i.iter.Close() } +func (i *loggingKeyValueIter) Close() error { + return i.iter.Close() +} diff --git a/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/plan/create_index_test.go b/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/plan/create_index_test.go index c6709ca04..2032aef54 100644 --- a/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/plan/create_index_test.go +++ b/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/plan/create_index_test.go @@ -2,14 +2,16 @@ package plan import ( "context" + "math" "testing" "time" + "gopkg.in/src-d/go-mysql-server.v0/mem" + "gopkg.in/src-d/go-mysql-server.v0/sql" "gopkg.in/src-d/go-mysql-server.v0/sql/expression" + "gopkg.in/src-d/go-mysql-server.v0/test" "github.com/stretchr/testify/require" - "gopkg.in/src-d/go-mysql-server.v0/mem" - "gopkg.in/src-d/go-mysql-server.v0/sql" ) func TestCreateIndex(t *testing.T) { @@ -37,7 +39,9 @@ func TestCreateIndex(t *testing.T) { ci.Catalog = catalog ci.CurrentDatabase = "foo" - _, err := ci.RowIter(sql.NewEmptyContext()) + tracer := new(test.MemTracer) + ctx := sql.NewContext(context.Background(), sql.WithTracer(tracer)) + _, err := ci.RowIter(ctx) require.NoError(err) time.Sleep(50 * time.Millisecond) @@ -50,6 +54,70 @@ func TestCreateIndex(t *testing.T) { sql.NewExpressionHash(expression.NewGetFieldWithTable(0, sql.Int64, "foo", "c", true)), sql.NewExpressionHash(expression.NewGetFieldWithTable(1, sql.Int64, "foo", "a", true)), }}, idx) + + found := false + for _, span := range tracer.Spans { + if span == "plan.backgroundIndexCreate" { + found = true + break + } + } + + require.True(found) +} + +func TestCreateIndexWithIter(t *testing.T) { + require := require.New(t) + foo := mem.NewTable("foo", sql.Schema{ + {Name: "one", Source: "foo", Type: sql.Int64}, + {Name: "two", Source: "foo", Type: sql.Int64}, + }) + + rows := [][2]int64{ + {1, 2}, + {-1, -2}, + {0, 0}, + {math.MaxInt64, math.MinInt64}, + } + for _, r := range rows { + err := foo.Insert(sql.NewRow(r[0], r[1])) + require.NoError(err) + } + + table := &indexableTable{foo} + exprs := []sql.Expression{expression.NewPlus( + expression.NewGetField(0, sql.Int64, "one", false), + expression.NewGetField(0, sql.Int64, "two", false)), + } + + driver := new(mockDriver) + catalog := sql.NewCatalog() + catalog.RegisterIndexDriver(driver) + db := mem.NewDatabase("foo") + db.AddTable("foo", table) + catalog.Databases = append(catalog.Databases, db) + + ci := NewCreateIndex("idx", table, exprs, "mock", make(map[string]string)) + ci.Catalog = catalog + ci.CurrentDatabase = "foo" + + columns, exprs, _, err := getColumnsAndPrepareExpressions(ci.Exprs) + require.NoError(err) + + iter, err := getIndexKeyValueIter(sql.NewEmptyContext(), table, columns, exprs) + require.NoError(err) + + var ( + vals []interface{} + ) + for i := 0; err == nil; i++ { + vals, _, err = iter.Next() + if err == nil { + require.Equal(1, len(vals)) + require.Equal(rows[i][0]+rows[i][1], vals[0]) + } + } + require.NoError(iter.Close()) } type mockIndex struct { @@ -87,7 +155,7 @@ func (*mockDriver) Create(db, table, id string, exprs []sql.ExpressionHash, conf func (*mockDriver) LoadAll(db, table string) ([]sql.Index, error) { panic("not implemented") } -func (d *mockDriver) Save(ctx context.Context, index sql.Index, iter sql.IndexKeyValueIter) error { +func (d *mockDriver) Save(ctx *sql.Context, index sql.Index, iter sql.IndexKeyValueIter) error { d.saved = append(d.saved, index.ID()) return nil } @@ -106,8 +174,13 @@ func (indexableTable) HandledFilters([]sql.Expression) []sql.Expression { panic("not implemented") } -func (indexableTable) IndexKeyValueIter(_ *sql.Context, colNames []string) (sql.IndexKeyValueIter, error) { - return nil, nil +func (it *indexableTable) IndexKeyValueIter(ctx *sql.Context, colNames []string) (sql.IndexKeyValueIter, error) { + t, ok := it.Table.(*mem.Table) + if !ok { + return nil, nil + } + + return t.IndexKeyValueIter(ctx, colNames) } func (indexableTable) WithProjectAndFilters(ctx *sql.Context, columns, filters []sql.Expression) (sql.RowIter, error) { diff --git a/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/plan/drop_index.go b/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/plan/drop_index.go index c8bed73be..e31b97f5b 100644 --- a/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/plan/drop_index.go +++ b/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/plan/drop_index.go @@ -52,7 +52,7 @@ func (d *DropIndex) RowIter(ctx *sql.Context) (sql.RowIter, error) { } d.Catalog.ReleaseIndex(index) - done, err := d.Catalog.DeleteIndex(db.Name(), d.Name) + done, err := d.Catalog.DeleteIndex(db.Name(), d.Name, false) if err != nil { return nil, err } diff --git a/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/plan/group_by.go b/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/plan/group_by.go index 30ecfccd1..77aee2c76 100644 --- a/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/plan/group_by.go +++ b/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/plan/group_by.go @@ -53,10 +53,16 @@ func (p *GroupBy) Schema() sql.Schema { name = e.String() } + var table string + if t, ok := e.(sql.Tableable); ok { + table = t.Table() + } + s[i] = &sql.Column{ Name: name, Type: e.Type(), Nullable: e.IsNullable(), + Source: table, } } diff --git a/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/plan/sort.go b/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/plan/sort.go index b7923172b..f33f8a837 100644 --- a/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/plan/sort.go +++ b/vendor/gopkg.in/src-d/go-mysql-server.v0/sql/plan/sort.go @@ -215,7 +215,7 @@ func (i *sortIter) computeSortedRows() error { rows: rows, lastError: nil, } - sort.Sort(sorter) + sort.Stable(sorter) if sorter.lastError != nil { return sorter.lastError } diff --git a/vendor/gopkg.in/src-d/go-mysql-server.v0/test/mem_tracer.go b/vendor/gopkg.in/src-d/go-mysql-server.v0/test/mem_tracer.go new file mode 100644 index 000000000..f11e00542 --- /dev/null +++ b/vendor/gopkg.in/src-d/go-mysql-server.v0/test/mem_tracer.go @@ -0,0 +1,48 @@ +package test + +import ( + opentracing "github.com/opentracing/opentracing-go" + "github.com/opentracing/opentracing-go/log" +) + +// MemTracer implements a simple tracer in memory for testing. +type MemTracer struct { + Spans []string +} + +type memSpan struct { + opName string +} + +// StartSpan implements opentracing.Tracer interface. +func (t *MemTracer) StartSpan(operationName string, opts ...opentracing.StartSpanOption) opentracing.Span { + t.Spans = append(t.Spans, operationName) + return &memSpan{operationName} +} + +// Inject implements opentracing.Tracer interface. +func (t *MemTracer) Inject(sm opentracing.SpanContext, format interface{}, carrier interface{}) error { + panic("not implemented") +} + +// Extract implements opentracing.Tracer interface. +func (t *MemTracer) Extract(format interface{}, carrier interface{}) (opentracing.SpanContext, error) { + panic("not implemented") +} + +func (m memSpan) Context() opentracing.SpanContext { return m } +func (m memSpan) SetBaggageItem(key, val string) opentracing.Span { return m } +func (m memSpan) BaggageItem(key string) string { return "" } +func (m memSpan) SetTag(key string, value interface{}) opentracing.Span { return m } +func (m memSpan) LogFields(fields ...log.Field) {} +func (m memSpan) LogKV(keyVals ...interface{}) {} +func (m memSpan) Finish() {} +func (m memSpan) FinishWithOptions(opts opentracing.FinishOptions) {} +func (m memSpan) SetOperationName(operationName string) opentracing.Span { + return &memSpan{operationName} +} +func (m memSpan) Tracer() opentracing.Tracer { return &MemTracer{} } +func (m memSpan) LogEvent(event string) {} +func (m memSpan) LogEventWithPayload(event string, payload interface{}) {} +func (m memSpan) Log(data opentracing.LogData) {} +func (m memSpan) ForeachBaggageItem(handler func(k, v string) bool) {}