feat: warehouse transformer #5205

Open · wants to merge 1 commit into master
2 changes: 1 addition & 1 deletion go.mod
@@ -42,6 +42,7 @@ require (
github.com/databricks/databricks-sql-go v1.6.1
github.com/denisenkom/go-mssqldb v0.12.3
github.com/dgraph-io/badger/v4 v4.5.0
github.com/dlclark/regexp2 v1.11.4
github.com/docker/docker v27.5.0+incompatible
github.com/go-chi/chi/v5 v5.2.0
github.com/go-redis/redis v6.15.9+incompatible
@@ -193,7 +194,6 @@ require (
github.com/danieljoos/wincred v1.2.2 // indirect
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/dlclark/regexp2 v1.11.4 // indirect
github.com/dnephin/pflag v1.0.7 // indirect
github.com/docker/cli v27.2.1+incompatible // indirect
github.com/docker/cli-docs-tool v0.8.0 // indirect
97 changes: 93 additions & 4 deletions processor/processor.go
@@ -5,16 +5,21 @@
"encoding/json"
"errors"
"fmt"
"reflect"
"runtime/trace"
"slices"
"strconv"
"strings"
"sync"
"time"

obskit "github.com/rudderlabs/rudder-observability-kit/go/labels"

"github.com/google/uuid"

"github.com/rudderlabs/rudder-server/enterprise/trackedusers"
"github.com/rudderlabs/rudder-server/utils/timeutil"
whutils "github.com/rudderlabs/rudder-server/warehouse/utils"

"golang.org/x/sync/errgroup"

@@ -57,6 +62,7 @@
. "github.com/rudderlabs/rudder-server/utils/tx" //nolint:staticcheck
"github.com/rudderlabs/rudder-server/utils/types"
"github.com/rudderlabs/rudder-server/utils/workerpool"
wtrans "github.com/rudderlabs/rudder-server/warehouse/transformer"
)

const (
@@ -84,12 +90,18 @@
GenerateReportsFromJobs(jobs []*jobsdb.JobT, sourceIdFilter map[string]bool) []*trackedusers.UsersReport
}

type warehouseTransformation interface {
transformer.DestinationTransformer
Log(events []types.SingularEventT, metadata *transformer.Metadata) error
}

// Handle is a handle to the processor module
type Handle struct {
conf *config.Config
tracer stats.Tracer
backendConfig backendconfig.BackendConfig
transformer transformer.Transformer
conf *config.Config
tracer stats.Tracer
backendConfig backendconfig.BackendConfig
transformer transformer.Transformer
warehouseTransformer warehouseTransformation

gatewayDB jobsdb.JobsDB
routerDB jobsdb.JobsDB
@@ -159,6 +171,13 @@
eventAuditEnabled map[string]bool
credentialsMap map[string][]transformer.Credential
nonEventStreamSources map[string]bool
enableWarehouseTransformations config.ValueLoader[bool]
}

warehouseTransformerStats struct {
responseTime stats.Timer
mismatches stats.Counter
logTime stats.Timer
}

drainConfig struct {
@@ -618,6 +637,12 @@
"partition": partition,
})
}

proc.warehouseTransformer = wtrans.New(proc.conf, proc.logger, proc.statsFactory)
proc.warehouseTransformerStats.responseTime = proc.statsFactory.NewStat("proc_warehouse_transformations_time", stats.TimerType)
proc.warehouseTransformerStats.mismatches = proc.statsFactory.NewStat("proc_warehouse_transformations_mismatches", stats.CountType)
proc.warehouseTransformerStats.logTime = proc.statsFactory.NewStat("proc_warehouse_transformations_log_time", stats.TimerType)

if proc.config.enableDedup {
var err error
proc.dedup, err = dedup.New(proc.conf, proc.statsFactory)
@@ -819,6 +844,7 @@
proc.config.archivalEnabled = config.GetReloadableBoolVar(true, "archival.Enabled")
// Capture event name as a tag in event level stats
proc.config.captureEventNameStats = config.GetReloadableBoolVar(false, "Processor.Stats.captureEventName")
proc.config.enableWarehouseTransformations = config.GetReloadableBoolVar(false, "Processor.enableWarehouseTransformations")
}

type connection struct {
@@ -3215,6 +3241,7 @@
proc.logger.Debug("Dest Transform input size", len(eventsToTransform))
s := time.Now()
response = proc.transformer.Transform(ctx, eventsToTransform, proc.config.transformBatchSize.Load())
proc.handleResponseForWarehouseTransformation(ctx, eventsToTransform, response, commonMetaData, eventsByMessageID)

destTransformationStat := proc.newDestinationTransformationStat(sourceID, workspaceID, transformAt, destination)
destTransformationStat.transformTime.Since(s)
@@ -3373,6 +3400,68 @@
}
}

func (proc *Handle) handleResponseForWarehouseTransformation(
ctx context.Context,
eventsToTransform []transformer.TransformerEvent,
pResponse transformer.Response,
commonMetaData *transformer.Metadata,
eventsByMessageID map[string]types.SingularEventWithReceivedAt,
) {
if _, ok := whutils.WarehouseDestinationMap[commonMetaData.DestinationType]; !ok {
return
}
if len(eventsToTransform) == 0 || !proc.config.enableWarehouseTransformations.Load() {
return
}

transformStartAt := timeutil.Now()
wResponse := proc.warehouseTransformer.Transform(ctx, eventsToTransform, proc.config.transformBatchSize.Load())
proc.warehouseTransformerStats.responseTime.Since(transformStartAt)

logStartAt := timeutil.Now()
differingEvents := proc.warehouseTransDifferEvents(eventsToTransform, pResponse, wResponse, eventsByMessageID)
if err := proc.warehouseTransformer.Log(differingEvents, commonMetaData); err != nil {
proc.logger.Warnn("Failed to log events for warehouse transformation debugging", obskit.Error(err))
}
proc.warehouseTransformerStats.logTime.Since(logStartAt)

}

func (proc *Handle) warehouseTransDifferEvents(
eventsToTransform []transformer.TransformerEvent,
pResponse, wResponse transformer.Response,
eventsByMessageID map[string]types.SingularEventWithReceivedAt,
) []types.SingularEventT {
// If the event counts differ, return all events in the transformation
if len(pResponse.Events) != len(wResponse.Events) || len(pResponse.FailedEvents) != len(wResponse.FailedEvents) {
events := lo.Map(eventsToTransform, func(e transformer.TransformerEvent, _ int) types.SingularEventT {
return eventsByMessageID[e.Metadata.MessageID].SingularEvent
})
proc.warehouseTransformerStats.mismatches.Count(len(events))
return events

}

var (
differedSampleEvents []types.SingularEventT
differedEventsCount int
)

for i := range pResponse.Events {
if reflect.DeepEqual(pResponse.Events[i], wResponse.Events[i]) {
continue

}

differedEventsCount++
if len(differedSampleEvents) == 0 {
// Collect the mismatched messages only once, as a sample
differedSampleEvents = append(differedSampleEvents, lo.Map(pResponse.Events[i].Metadata.GetMessagesIDs(), func(msgID string, _ int) types.SingularEventT {
return eventsByMessageID[msgID].SingularEvent
})...)

}
}
proc.warehouseTransformerStats.mismatches.Count(differedEventsCount)
return differedSampleEvents

}

func (proc *Handle) saveDroppedJobs(ctx context.Context, droppedJobs []*jobsdb.JobT, tx *Tx) error {
if len(droppedJobs) > 0 {
for i := range droppedJobs { // each dropped job should have a unique jobID in the scope of the batch
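handleResponseForWarehouseTransformation above implements a shadow comparison: when Processor.enableWarehouseTransformations is enabled and the destination is a warehouse, the in-process warehouse transformer runs on the same input as the regular destination transformer, the two responses are diffed, mismatches are counted, and a sample of differing events is logged for debugging. A minimal, self-contained sketch of that comparison idea (hypothetical Event type and function names, not the PR's exact code):

```go
package main

import (
	"fmt"
	"reflect"
)

// Event is a stand-in for a transformed event payload (hypothetical type).
type Event map[string]any

// diffResponses mirrors the comparison in warehouseTransDifferEvents: if the
// result counts differ, every input is treated as mismatched; otherwise the
// results are compared pairwise, only differing pairs are counted, and just
// the first differing event is kept as a sample.
func diffResponses(legacy, inProcess []Event) (mismatches int, sample []Event) {
	if len(legacy) != len(inProcess) {
		return len(legacy), legacy
	}
	for i := range legacy {
		if reflect.DeepEqual(legacy[i], inProcess[i]) {
			continue
		}
		mismatches++
		if len(sample) == 0 { // sample only, to bound what gets logged
			sample = append(sample, legacy[i])
		}
	}
	return mismatches, sample
}

func main() {
	legacy := []Event{{"id": 1, "type": "track"}, {"id": 2, "type": "identify"}}
	inProcess := []Event{{"id": 1, "type": "track"}, {"id": 2, "type": "page"}}
	n, sample := diffResponses(legacy, inProcess)
	fmt.Println(n, sample) // 1 [map[id:2 type:identify]]
}
```

Counting all mismatches while logging only a sample keeps the proc_warehouse_transformations_mismatches metric accurate without unbounded debug-log volume.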
22 changes: 17 additions & 5 deletions processor/transformer/transformer.go
@@ -164,13 +164,25 @@ func WithClient(client HTTPDoer) Opt {
}
}

// Transformer provides methods to transform events
type Transformer interface {
Transform(ctx context.Context, clientEvents []TransformerEvent, batchSize int) Response
type UserTransformer interface {
UserTransform(ctx context.Context, clientEvents []TransformerEvent, batchSize int) Response
}

type DestinationTransformer interface {
Transform(ctx context.Context, clientEvents []TransformerEvent, batchSize int) Response
}

type TrackingPlanValidator interface {
Validate(ctx context.Context, clientEvents []TransformerEvent, batchSize int) Response
}

// Transformer provides methods to transform events
type Transformer interface {
UserTransformer
DestinationTransformer
TrackingPlanValidator
}

type HTTPDoer interface {
Do(req *http.Request) (*http.Response, error)
}
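This split keeps the full Transformer interface as a composition of UserTransformer, DestinationTransformer, and TrackingPlanValidator, so callers can depend only on the methods they actually use; the processor's new warehouseTransformation interface, for example, embeds just DestinationTransformer plus its own Log method. A small sketch of consuming the narrowed interface (stand-in types for illustration, not the real package):

```go
package main

import (
	"context"
	"fmt"
)

// Stand-ins for the real transformer types (hypothetical, for illustration only).
type TransformerEvent struct{ Name string }
type Response struct{ Events []TransformerEvent }

// DestinationTransformer mirrors the narrowed interface introduced above.
type DestinationTransformer interface {
	Transform(ctx context.Context, clientEvents []TransformerEvent, batchSize int) Response
}

// echoTransformer is a trivial implementation showing that any concrete
// transformer with a matching Transform method satisfies the narrow interface.
type echoTransformer struct{}

func (echoTransformer) Transform(_ context.Context, events []TransformerEvent, _ int) Response {
	return Response{Events: events}
}

// shadowRun depends only on DestinationTransformer, the same way the
// processor's warehouseTransformation interface embeds it.
func shadowRun(ctx context.Context, t DestinationTransformer, events []TransformerEvent) Response {
	return t.Transform(ctx, events, len(events))
}

func main() {
	resp := shadowRun(context.Background(), echoTransformer{}, []TransformerEvent{{Name: "track"}})
	fmt.Println(len(resp.Events)) // 1
}
```

Under this split, both the remote transformer client and the new in-process warehouse transformer satisfy DestinationTransformer, which is what makes the side-by-side comparison in processor.go possible.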
@@ -591,7 +603,7 @@ func (trans *handle) destTransformURL(destType string) string {
destinationEndPoint := fmt.Sprintf("%s/v0/destinations/%s", trans.config.destTransformationURL, strings.ToLower(destType))

if _, ok := warehouseutils.WarehouseDestinationMap[destType]; ok {
whSchemaVersionQueryParam := fmt.Sprintf("whSchemaVersion=%s&whIDResolve=%v", trans.conf.GetString("Warehouse.schemaVersion", "v1"), warehouseutils.IDResolutionEnabled())
whSchemaVersionQueryParam := fmt.Sprintf("whIDResolve=%t", trans.conf.GetBool("Warehouse.enableIDResolution", false))
switch destType {
case warehouseutils.RS:
return destinationEndPoint + "?" + whSchemaVersionQueryParam
@@ -603,7 +615,7 @@ func (trans *handle) destTransformURL(destType string) string {
}
}
if destType == warehouseutils.SnowpipeStreaming {
return fmt.Sprintf("%s?whSchemaVersion=%s&whIDResolve=%t", destinationEndPoint, trans.conf.GetString("Warehouse.schemaVersion", "v1"), warehouseutils.IDResolutionEnabled())
return fmt.Sprintf("%s?whIDResolve=%t", destinationEndPoint, trans.conf.GetBool("Warehouse.enableIDResolution", false))
}
return destinationEndPoint
}
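With the whSchemaVersion parameter dropped, the warehouse transformer URL carries only whIDResolve, read from Warehouse.enableIDResolution (default false). A rough sketch of the resulting URL shape (hypothetical base URL; the real function appends the parameter only for warehouse destinations and Snowpipe streaming):

```go
package main

import (
	"fmt"
	"strings"
)

// destTransformURLSketch mirrors the simplified construction: the endpoint is
// /v0/destinations/<desttype> and the only query parameter left is whIDResolve.
func destTransformURLSketch(baseURL, destType string, idResolutionEnabled bool) string {
	endpoint := fmt.Sprintf("%s/v0/destinations/%s", baseURL, strings.ToLower(destType))
	return fmt.Sprintf("%s?whIDResolve=%t", endpoint, idResolutionEnabled)
}

func main() {
	// Prints: https://transformer.example.com/v0/destinations/rs?whIDResolve=false
	fmt.Println(destTransformURLSketch("https://transformer.example.com", "RS", false))
}
```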
2 changes: 1 addition & 1 deletion warehouse/internal/model/schema.go
@@ -17,7 +17,7 @@ const (
JSONDataType SchemaType = "json"
TextDataType SchemaType = "text"
DateTimeDataType SchemaType = "datetime"
ArrayOfBooleanDatatype SchemaType = "array(boolean)"
ArrayOfBooleanDataType SchemaType = "array(boolean)"
)

type WHSchema struct {
2 changes: 1 addition & 1 deletion warehouse/slave/worker.go
@@ -319,7 +319,7 @@ func (w *worker) processStagingFile(ctx context.Context, job payload) ([]uploadR
}

columnVal = newColumnVal
case model.ArrayOfBooleanDatatype:
case model.ArrayOfBooleanDataType:
if boolValue, ok := columnVal.([]interface{}); ok {
newColumnVal := make([]interface{}, len(boolValue))

89 changes: 89 additions & 0 deletions warehouse/transformer/datatype.go
@@ -0,0 +1,89 @@
package transformer

import (
"github.com/rudderlabs/rudder-server/warehouse/internal/model"
"github.com/rudderlabs/rudder-server/warehouse/transformer/internal/utils"
whutils "github.com/rudderlabs/rudder-server/warehouse/utils"
)

func dataTypeFor(destType, key string, val any, isJSONKey bool) string {
if typeName := primitiveType(val); typeName != "" {
return typeName
}
if strVal, ok := val.(string); ok && utils.ValidTimestamp(strVal) {
return model.DateTimeDataType
}
if override := dataTypeOverride(destType, key, val, isJSONKey); override != "" {
return override
}
return model.StringDataType
}

func primitiveType(val any) string {
switch v := val.(type) {
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
return model.IntDataType
case float64:
return getFloatType(v)
case float32:
return getFloatType(float64(v))

case bool:
return model.BooleanDataType
default:
return ""
}
}

func getFloatType(v float64) string {
if v == float64(int64(v)) {
return model.IntDataType
}
return model.FloatDataType
}

func dataTypeOverride(destType, key string, val any, isJSONKey bool) string {
switch destType {
case whutils.POSTGRES:
return overrideForPostgres(key, isJSONKey)
case whutils.SNOWFLAKE, whutils.SnowpipeStreaming:
return overrideForSnowflake(key, isJSONKey)
case whutils.RS:
return overrideForRedshift(val, isJSONKey)
default:
return ""
}
}

func overrideForPostgres(key string, isJSONKey bool) string {
if key == violationErrors || isJSONKey {
return model.JSONDataType
}
return model.StringDataType
}

func overrideForSnowflake(key string, isJSONKey bool) string {
if key == violationErrors || isJSONKey {
return model.JSONDataType
}
return model.StringDataType

}

func overrideForRedshift(val any, isJSONKey bool) string {
if isJSONKey {
return model.JSONDataType
}
if val == nil {
return model.StringDataType
}
if jsonVal, _ := json.Marshal(val); len(jsonVal) > redshiftStringLimit {
return model.TextDataType
}
return model.StringDataType
}

func convertValIfDateTime(val any, colType string) any {
if colType == model.DateTimeDataType {
return utils.ToTimestamp(val)
}
return val
}
55 changes: 55 additions & 0 deletions warehouse/transformer/datatype_test.go
@@ -0,0 +1,55 @@
package transformer

import (
"testing"

"github.com/stretchr/testify/require"

whutils "github.com/rudderlabs/rudder-server/warehouse/utils"
)

func TestGetDataType(t *testing.T) {
testCases := []struct {
name, destType, key string
val any
isJSONKey bool
expected string
}{
// Primitive types
{"Primitive Type Int", whutils.POSTGRES, "someKey", 42, false, "int"},
{"Primitive Type Float", whutils.POSTGRES, "someKey", 42.0, false, "int"},
{"Primitive Type Float (non-int)", whutils.POSTGRES, "someKey", 42.5, false, "float"},
{"Primitive Type Bool", whutils.POSTGRES, "someKey", true, false, "boolean"},

// Valid timestamp
{"Valid Timestamp String", whutils.POSTGRES, "someKey", "2022-10-05T14:48:00.000Z", false, "datetime"},

// JSON Key cases for different destinations
{"Postgres JSON Key", whutils.POSTGRES, "someKey", "someValue", true, "json"},
{"Snowflake JSON Key", whutils.SNOWFLAKE, "someKey", "someValue", true, "json"},
{"Redshift JSON Key", whutils.RS, "someKey", "someValue", true, "json"},

// Redshift with text and string types
{"Redshift Text Type", whutils.RS, "someKey", string(make([]byte, 513)), false, "text"},
{"Redshift String Type", whutils.RS, "someKey", "shortValue", false, "string"},
{"Redshift String Type", whutils.RS, "someKey", nil, false, "string"},

// Empty string values
{"Empty String Value", whutils.POSTGRES, "someKey", "", false, "string"},
{"Empty String with JSON Key", whutils.POSTGRES, "someKey", "", true, "json"},

// Unsupported types (should default to string)
{"Unsupported Type Struct", whutils.POSTGRES, "someKey", struct{}{}, false, "string"},
{"Unsupported Type Map", whutils.POSTGRES, "someKey", map[string]any{"key": "value"}, false, "string"},

// Special string values
{"Special Timestamp-like String", whutils.POSTGRES, "someKey", "not-a-timestamp", false, "string"},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
actual := dataTypeFor(tc.destType, tc.key, tc.val, tc.isJSONKey)
require.Equal(t, tc.expected, actual)
})
}
}