diff --git a/warehouse/integrations/bigquery/bigquery.go b/warehouse/integrations/bigquery/bigquery.go
index b75c566610..f34be2ec10 100644
--- a/warehouse/integrations/bigquery/bigquery.go
+++ b/warehouse/integrations/bigquery/bigquery.go
@@ -49,6 +49,7 @@ type BigQuery struct {
 		enableDeleteByJobs                  bool
 		customPartitionsEnabledWorkspaceIDs []string
 		slowQueryThreshold                  time.Duration
+		loadByFolderPath                    bool
 	}
 }
 
@@ -61,7 +62,7 @@ const (
 	tableNameLimit = 127
 )
 
-// maps datatype stored in rudder to datatype in bigquery
+// dataTypesMap maps datatype stored in rudder to datatype in bigquery
 var dataTypesMap = map[string]bigquery.FieldType{
 	"boolean":  bigquery.BooleanFieldType,
 	"int":      bigquery.IntegerFieldType,
@@ -70,7 +71,7 @@ var dataTypesMap = map[string]bigquery.FieldType{
 	"datetime": bigquery.TimestampFieldType,
 }
 
-// maps datatype in bigquery to datatype stored in rudder
+// dataTypesMapToRudder maps datatype in bigquery to datatype stored in rudder
 var dataTypesMapToRudder = map[bigquery.FieldType]string{
 	"BOOLEAN": "boolean",
 	"BOOL":    "boolean",
@@ -130,6 +131,7 @@ func New(conf *config.Config, log logger.Logger) *BigQuery {
 	bq.config.enableDeleteByJobs = conf.GetBool("Warehouse.bigquery.enableDeleteByJobs", false)
 	bq.config.customPartitionsEnabledWorkspaceIDs = conf.GetStringSlice("Warehouse.bigquery.customPartitionsEnabledWorkspaceIDs", nil)
 	bq.config.slowQueryThreshold = conf.GetDuration("Warehouse.bigquery.slowQueryThreshold", 5, time.Minute)
+	bq.config.loadByFolderPath = conf.GetBool("Warehouse.bigquery.loadByFolderPath", false)
 
 	return bq
 }
@@ -140,10 +142,8 @@ func getTableSchema(tableSchema model.TableSchema) []*bigquery.FieldSchema {
 	})
 }
 
-func (bq *BigQuery) DeleteTable(ctx context.Context, tableName string) (err error) {
-	tableRef := bq.db.Dataset(bq.namespace).Table(tableName)
-	err = tableRef.Delete(ctx)
-	return
+func (bq *BigQuery) DeleteTable(ctx context.Context, tableName string) error {
+	return bq.db.Dataset(bq.namespace).Table(tableName).Delete(ctx)
 }
 
 // CreateTable creates a table in BigQuery with the provided schema
@@ -218,13 +218,6 @@ func (bq *BigQuery) CreateTable(ctx context.Context, tableName string, columnMap
 	return nil
 }
 
-func (bq *BigQuery) DropTable(ctx context.Context, tableName string) error {
-	if err := bq.DeleteTable(ctx, tableName); err != nil {
-		return err
-	}
-	return bq.DeleteTable(ctx, tableName+"_view")
-}
-
 // createTableView creates a view for the table to deduplicate the data
 // If custom partition is enabled, it creates a view with the partition column and type. Otherwise, it creates a view with ingestion-time partitioning
 func (bq *BigQuery) createTableView(ctx context.Context, tableName string, columnMap model.TableSchema) error {
@@ -303,6 +296,13 @@ func (bq *BigQuery) createTableView(ctx context.Context, tableName string, colum
 	return bq.db.Dataset(bq.namespace).Table(tableName+"_view").Create(ctx, metaData)
 }
 
+func (bq *BigQuery) DropTable(ctx context.Context, tableName string) error {
+	if err := bq.DeleteTable(ctx, tableName); err != nil {
+		return err
+	}
+	return bq.DeleteTable(ctx, tableName+"_view")
+}
+
 func (bq *BigQuery) schemaExists(ctx context.Context, _, _ string) (exists bool, err error) {
 	ds := bq.db.Dataset(bq.namespace)
 	_, err = ds.Metadata(ctx)
@@ -443,15 +443,12 @@ func (bq *BigQuery) loadTable(ctx context.Context, tableName string) (
 	)
 	log.Infon("started loading")
 
-	loadFileLocations, err := bq.loadFileLocations(ctx, tableName)
+	gcsReferences, err := bq.gcsReferences(ctx, tableName)
 	if err != nil {
-		return nil, nil, fmt.Errorf("getting load file locations: %w", err)
+		return nil, nil, fmt.Errorf("getting gcs references: %w", err)
 	}
 
-	gcsRef := bigquery.NewGCSReference(warehouseutils.GetGCSLocations(
-		loadFileLocations,
-		warehouseutils.GCSLocationOptions{},
-	)...)
+	gcsRef := bigquery.NewGCSReference(gcsReferences...)
 	gcsRef.SourceFormat = bigquery.JSON
 	gcsRef.MaxBadRecords = 0
 	gcsRef.IgnoreUnknownValues = false
@@ -459,10 +456,10 @@
 	return bq.loadTableByAppend(ctx, tableName, gcsRef, log)
 }
 
-func (bq *BigQuery) loadFileLocations(
+func (bq *BigQuery) gcsReferences(
 	ctx context.Context,
 	tableName string,
-) ([]warehouseutils.LoadFile, error) {
+) ([]string, error) {
 	switch tableName {
 	case warehouseutils.IdentityMappingsTable, warehouseutils.IdentityMergeRulesTable:
 		loadfile, err := bq.uploader.GetSingleLoadFile(
@@ -472,12 +469,31 @@
 		if err != nil {
 			return nil, fmt.Errorf("getting single load file for table %s: %w", tableName, err)
 		}
-		return []warehouseutils.LoadFile{loadfile}, nil
+
+		locations := warehouseutils.GetGCSLocations([]warehouseutils.LoadFile{loadfile}, warehouseutils.GCSLocationOptions{})
+		return locations, nil
 	default:
-		return bq.uploader.GetLoadFilesMetadata(
-			ctx,
-			warehouseutils.GetLoadFilesOptions{Table: tableName},
-		)
+		if bq.config.loadByFolderPath {
+			objectLocation, err := bq.uploader.GetSampleLoadFileLocation(ctx, tableName)
+			if err != nil {
+				return nil, fmt.Errorf("getting sample load file location for table %s: %w", tableName, err)
+			}
+			gcsLocation := warehouseutils.GetGCSLocation(objectLocation, warehouseutils.GCSLocationOptions{})
+			gcsLocationFolder := loadFolder(gcsLocation)
+
+			return []string{gcsLocationFolder}, nil
+		} else {
+			loadFilesMetadata, err := bq.uploader.GetLoadFilesMetadata(
+				ctx,
+				warehouseutils.GetLoadFilesOptions{Table: tableName},
+			)
+			if err != nil {
+				return nil, fmt.Errorf("getting load files metadata for table %s: %w", tableName, err)
+			}
+
+			locations := warehouseutils.GetGCSLocations(loadFilesMetadata, warehouseutils.GCSLocationOptions{})
+			return locations, nil
+		}
 	}
 }
@@ -692,15 +708,12 @@ func (bq *BigQuery) LoadUserTables(ctx context.Context) (errorMap map[string]err
 }
 
 func (bq *BigQuery) createAndLoadStagingUsersTable(ctx context.Context, stagingTable string) error {
-	loadFileLocations, err := bq.loadFileLocations(ctx, warehouseutils.UsersTable)
+	gcsReferences, err := bq.gcsReferences(ctx, warehouseutils.UsersTable)
 	if err != nil {
-		return fmt.Errorf("getting load file locations: %w", err)
+		return fmt.Errorf("getting gcs references: %w", err)
 	}
 
-	gcsRef := bigquery.NewGCSReference(warehouseutils.GetGCSLocations(
-		loadFileLocations,
-		warehouseutils.GCSLocationOptions{},
-	)...)
+	gcsRef := bigquery.NewGCSReference(gcsReferences...)
 	gcsRef.SourceFormat = bigquery.JSON
 	gcsRef.MaxBadRecords = 0
 	gcsRef.IgnoreUnknownValues = false
@@ -1178,9 +1191,16 @@ func (bq *BigQuery) Connect(ctx context.Context, warehouse model.Warehouse) (cli
 }
 
 func (bq *BigQuery) LoadTestTable(ctx context.Context, location, tableName string, _ map[string]interface{}, _ string) error {
-	gcsLocations := warehouseutils.GetGCSLocation(location, warehouseutils.GCSLocationOptions{})
+	gcsLocation := warehouseutils.GetGCSLocation(location, warehouseutils.GCSLocationOptions{})
+
+	var gcsReference string
+	if bq.config.loadByFolderPath {
+		gcsReference = loadFolder(gcsLocation)
+	} else {
+		gcsReference = gcsLocation
+	}
 
-	gcsRef := bigquery.NewGCSReference([]string{gcsLocations}...)
+	gcsRef := bigquery.NewGCSReference(gcsReference)
 	gcsRef.SourceFormat = bigquery.JSON
 	gcsRef.MaxBadRecords = 0
 	gcsRef.IgnoreUnknownValues = false
@@ -1213,6 +1233,10 @@ func (bq *BigQuery) LoadTestTable(ctx context.Context, location, tableName strin
 	return nil
 }
 
+func loadFolder(objectLocation string) string {
+	return warehouseutils.GetLocationFolder(objectLocation) + "/*"
+}
+
 func (*BigQuery) SetConnectionTimeout(_ time.Duration) {
 }
diff --git a/warehouse/integrations/bigquery/bigquery_test.go b/warehouse/integrations/bigquery/bigquery_test.go
index adb6364828..d2edb42581 100644
--- a/warehouse/integrations/bigquery/bigquery_test.go
+++ b/warehouse/integrations/bigquery/bigquery_test.go
@@ -4,12 +4,15 @@ import (
 	"context"
 	"errors"
 	"fmt"
+	"io"
 	"os"
 	"slices"
+	"strconv"
 	"testing"
 	"time"
 
 	"cloud.google.com/go/bigquery"
+	"github.com/google/uuid"
 	"github.com/samber/lo"
 	"go.uber.org/mock/gomock"
 	"google.golang.org/api/iterator"
@@ -1095,6 +1098,90 @@ func TestIntegration(t *testing.T) {
 			)
 			require.Equal(t, records, whth.SampleTestRecords())
 		})
+		t.Run("multiple files", func(t *testing.T) {
+			testCases := []struct {
+				name             string
+				loadByFolderPath bool
+			}{
+				{name: "loadByFolderPath = false", loadByFolderPath: false},
+				{name: "loadByFolderPath = true", loadByFolderPath: true},
+			}
+			for i, tc := range testCases {
+				t.Run(tc.name, func(t *testing.T) {
+					tableName := "multiple_files_test_table" + strconv.Itoa(i)
+					repeat := 10
+					loadObjectFolder := "rudder-warehouse-load-objects"
+					sourceID := "test_source_id"
+
+					prefixes := []string{loadObjectFolder, tableName, sourceID, uuid.New().String() + "-" + tableName}
+
+					loadFiles := lo.RepeatBy(repeat, func(int) whutils.LoadFile {
+						sourceFile, err := os.Open("../testdata/load.json.gz")
+						require.NoError(t, err)
+						defer func() { _ = sourceFile.Close() }()
+
+						tempFile, err := os.CreateTemp("", "clone_*.json.gz")
+						require.NoError(t, err)
+						defer func() { _ = tempFile.Close() }()
+
+						_, err = io.Copy(tempFile, sourceFile)
+						require.NoError(t, err)
+
+						f, err := os.Open(tempFile.Name())
+						require.NoError(t, err)
+						defer func() { _ = f.Close() }()
+
+						uploadOutput, err := fm.Upload(context.Background(), f, prefixes...)
+						require.NoError(t, err)
+						return whutils.LoadFile{Location: uploadOutput.Location}
+					})
+					mockUploader := newMockUploader(t, loadFiles, tableName, schemaInUpload, schemaInWarehouse)
+					if tc.loadByFolderPath {
+						mockUploader.EXPECT().GetSampleLoadFileLocation(gomock.Any(), tableName).Return(loadFiles[0].Location, nil).Times(1)
+					} else {
+						mockUploader.EXPECT().GetSampleLoadFileLocation(gomock.Any(), tableName).Times(0)
+					}
+
+					c := config.New()
+					c.Set("Warehouse.bigquery.loadByFolderPath", tc.loadByFolderPath)
+
+					bq := whbigquery.New(c, logger.NOP)
+					require.NoError(t, bq.Setup(ctx, warehouse, mockUploader))
+					require.NoError(t, bq.CreateSchema(ctx))
+					require.NoError(t, bq.CreateTable(ctx, tableName, schemaInWarehouse))
+
+					loadTableStat, err := bq.LoadTable(ctx, tableName)
+					require.NoError(t, err)
+					require.Equal(t, loadTableStat.RowsInserted, int64(repeat*14))
+					require.Equal(t, loadTableStat.RowsUpdated, int64(0))
+
+					records := bqhelper.RetrieveRecordsFromWarehouse(t, db,
+						fmt.Sprintf(`
+							SELECT
+								id,
+								received_at,
+								test_bool,
+								test_datetime,
+								test_float,
+								test_int,
+								test_string
+							FROM %s.%s
+							WHERE _PARTITIONTIME BETWEEN TIMESTAMP('%s') AND TIMESTAMP('%s')
+							ORDER BY id;`,
+							namespace,
+							tableName,
+							time.Now().Add(-24*time.Hour).Format("2006-01-02"),
+							time.Now().Add(+24*time.Hour).Format("2006-01-02"),
+						),
+					)
+					expectedRecords := make([][]string, 0, repeat)
+					for i := 0; i < repeat; i++ {
+						expectedRecords = append(expectedRecords, whth.SampleTestRecords()...)
+					}
+					require.ElementsMatch(t, expectedRecords, records)
+				})
+			}
+		})
 	})
 
 	t.Run("Fetch schema", func(t *testing.T) {
@@ -1457,7 +1544,7 @@ func newMockUploader(
 	tableName string,
 	schemaInUpload model.TableSchema,
 	schemaInWarehouse model.TableSchema,
-) whutils.Uploader {
+) *mockuploader.MockUploader {
 	ctrl := gomock.NewController(t)
 	t.Cleanup(ctrl.Finish)
 
diff --git a/warehouse/utils/utils.go b/warehouse/utils/utils.go
index 3d26559647..848c2c481f 100644
--- a/warehouse/utils/utils.go
+++ b/warehouse/utils/utils.go
@@ -95,7 +95,6 @@ const (
 	WAREHOUSE             = "warehouse"
 	RudderMissingDatatype = "warehouse_rudder_missing_datatype"
-	MissingDatatype       = ""
 )
 
 const (
@@ -290,8 +289,7 @@ func GetObjectFolderForDeltalake(provider, location string) (folder string) {
 		blobUrlParts := azblob.NewBlobURLParts(*blobUrl)
 		accountName := strings.Replace(blobUrlParts.Host, ".blob.core.windows.net", "", 1)
 		blobLocation := fmt.Sprintf("wasbs://%s@%s.blob.core.windows.net/%s", blobUrlParts.ContainerName, accountName, blobUrlParts.BlobName)
-		lastPos := strings.LastIndex(blobLocation, "/")
-		folder = blobLocation[:lastPos]
+		folder = GetLocationFolder(blobLocation)
 	}
 	return
 }
@@ -389,8 +387,7 @@ func GetS3Location(location string) (s3Location, region string) {
 // https://test-bucket.s3.amazonaws.com/myfolder/test-object.csv --> s3://test-bucket/myfolder
 func GetS3LocationFolder(location string) string {
 	s3Location, _ := GetS3Location(location)
-	lastPos := strings.LastIndex(s3Location, "/")
-	return s3Location[:lastPos]
+	return GetLocationFolder(s3Location)
 }
 
 type GCSLocationOptions struct {
@@ -413,9 +410,7 @@ func GetGCSLocation(location string, options GCSLocationOptions) string {
 // GetGCSLocationFolder returns the folder path for a gcs object
 // https://storage.googleapis.com/test-bucket/myfolder/test-object.csv --> gcs://test-bucket/myfolder
 func GetGCSLocationFolder(location string, options GCSLocationOptions) string {
-	s3Location := GetGCSLocation(location, options)
-	lastPos := strings.LastIndex(s3Location, "/")
-	return s3Location[:lastPos]
+	return GetLocationFolder(GetGCSLocation(location, options))
 }
 
 func GetGCSLocations(loadFiles []LoadFile, options GCSLocationOptions) (gcsLocations []string) {
@@ -425,6 +420,10 @@ func GetGCSLocations(loadFiles []LoadFile, options GCSLocationOptions) (gcsLocat
 	return
 }
 
+func GetLocationFolder(location string) string {
+	return location[:strings.LastIndex(location, "/")]
+}
+
 // GetAzureBlobLocation parses path-style location http url to return in azure:// format
 // https://myproject.blob.core.windows.net/test-bucket/test-object.csv --> azure://myproject.blob.core.windows.net/test-bucket/test-object.csv
 func GetAzureBlobLocation(location string) string {
@@ -435,9 +434,7 @@ func GetAzureBlobLocation(location string) string {
 // GetAzureBlobLocationFolder returns the folder path for an azure storage object
 // https://myproject.blob.core.windows.net/test-bucket/myfolder/test-object.csv --> azure://myproject.blob.core.windows.net/myfolder
 func GetAzureBlobLocationFolder(location string) string {
-	s3Location := GetAzureBlobLocation(location)
-	lastPos := strings.LastIndex(s3Location, "/")
-	return s3Location[:lastPos]
+	return GetLocationFolder(GetAzureBlobLocation(location))
 }
 
 func GetS3Locations(loadFiles []LoadFile) []LoadFile {
@@ -819,16 +816,6 @@ func GetLoadFileFormat(loadFileType string) string {
 	}
 }
 
-func GetDateRangeList(start, end time.Time, dateFormat string) (dateRange []string) {
-	if (start == time.Time{} || end == time.Time{}) {
-		return
-	}
-	for d := start; !d.After(end); d = d.AddDate(0, 0, 1) {
-		dateRange = append(dateRange, d.Format(dateFormat))
-	}
-	return
-}
-
 func StagingTablePrefix(provider string) string {
 	return ToProviderCase(provider, stagingTablePrefix)
 }
diff --git a/warehouse/utils/utils_test.go b/warehouse/utils/utils_test.go
index e9a305c1a6..b476578b5d 100644
--- a/warehouse/utils/utils_test.go
+++ b/warehouse/utils/utils_test.go
@@ -1193,14 +1193,6 @@ var _ = Describe("Utils", func() {
 			Entry(nil, json.RawMessage(`{"k1": { "k2": "v2" }}`), model.Schema{"k1": {"k2": "v2"}}),
 		)
 
-		DescribeTable("Get date range list", func(start, end time.Time, format string, expected []string) {
-			Expect(GetDateRangeList(start, end, format)).To(Equal(expected))
-		},
-			Entry("Same day", time.Now(), time.Now(), "2006-01-02", []string{time.Now().Format("2006-01-02")}),
-			Entry("Multiple days", time.Now(), time.Now().AddDate(0, 0, 1), "2006-01-02", []string{time.Now().Format("2006-01-02"), time.Now().AddDate(0, 0, 1).Format("2006-01-02")}),
-			Entry("No days", nil, nil, "2006-01-02", nil),
-		)
-
 		DescribeTable("Staging table prefix", func(provider string) {
 			Expect(StagingTablePrefix(provider)).To(Equal(ToProviderCase(provider, "rudder_staging_")))
 		},
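Reviewer note (not part of the patch): the core idea behind the new `Warehouse.bigquery.loadByFolderPath` toggle is to hand BigQuery a single wildcard folder URI derived from one sample load file, instead of enumerating every load-file URI. Below is a minimal standalone sketch of that derivation; `gcsLocation` and `locationFolder` are simplified stand-ins for the `warehouseutils.GetGCSLocation`/`GetLocationFolder` helpers touched in the diff, and the bucket and object names are made up.

```go
package main

import (
	"fmt"
	"strings"
)

// gcsLocation turns an https object URL into a gs:// URI. Simplified stand-in
// for warehouseutils.GetGCSLocation (the real helper also takes options).
func gcsLocation(location string) string {
	return strings.Replace(location, "https://storage.googleapis.com/", "gs://", 1)
}

// locationFolder drops the object name, keeping everything up to the last "/",
// mirroring the GetLocationFolder helper added in the patch.
func locationFolder(location string) string {
	return location[:strings.LastIndex(location, "/")]
}

// loadFolder appends the "/*" wildcard, so a single GCS reference covers every
// load file under the folder; this is what the loadByFolderPath code path hands
// to bigquery.NewGCSReference instead of one URI per load file.
func loadFolder(objectLocation string) string {
	return locationFolder(objectLocation) + "/*"
}

func main() {
	// Hypothetical sample load-file location for a "users" table.
	sample := "https://storage.googleapis.com/example-bucket/rudder-warehouse-load-objects/users/source-id/abcd-users/load.json.gz"
	fmt.Println(loadFolder(gcsLocation(sample)))
	// Output: gs://example-bucket/rudder-warehouse-load-objects/users/source-id/abcd-users/*
}
```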