diff --git a/pkg/generator_db.go b/pkg/generator_db.go index 5f7907b..ec64046 100644 --- a/pkg/generator_db.go +++ b/pkg/generator_db.go @@ -70,6 +70,8 @@ type dbRequest struct { minPixelCount string maxPixelCount string + duplicateState string + randomSampling bool partitionsCount string @@ -95,10 +97,13 @@ type SourceDBConfig struct { ReturnLatents string `json:"return_latents"` ReturnDuplicateState bool `json:"return_duplicate_state"` - MinShortEdge int `json:"min_short_edge"` - MaxShortEdge int `json:"max_short_edge"` - MinPixelCount int `json:"min_pixel_count"` - MaxPixelCount int `json:"max_pixel_count"` + MinShortEdge int `json:"min_short_edge"` + MaxShortEdge int `json:"max_short_edge"` + MinPixelCount int `json:"min_pixel_count"` + MaxPixelCount int `json:"max_pixel_count"` + + DuplicateState int `json:"duplicate_state"` + RandomSampling bool `json:"random_sampling"` } @@ -127,7 +132,7 @@ func (c *SourceDBConfig) setDefaults() { c.MinPixelCount = -1 c.MaxPixelCount = -1 c.RandomSampling = false - + c.DuplicateState = -1 } func (c *SourceDBConfig) getDbRequest() dbRequest { @@ -175,6 +180,11 @@ func (c *SourceDBConfig) getDbRequest() dbRequest { c.Rank = -1 } + duplicateState := sanitizeInt(c.DuplicateState) + if duplicateState == "0" { + duplicateState = "None" + } + return dbRequest{ fields: fields, sources: c.Sources, @@ -194,6 +204,7 @@ func (c *SourceDBConfig) getDbRequest() dbRequest { minPixelCount: sanitizeInt(c.MinPixelCount), maxPixelCount: sanitizeInt(c.MaxPixelCount), randomSampling: c.RandomSampling, + duplicateState: duplicateState, partitionsCount: sanitizeInt(c.WorldSize), partition: sanitizeInt(c.Rank), } @@ -367,6 +378,8 @@ func getHTTPRequest(api_url string, api_key string, request dbRequest) *http.Req maybeAddField(&req, "pixel_count__gte", request.minPixelCount) maybeAddField(&req, "pixel_count__lte", request.maxPixelCount) + maybeAddField(&req, "duplicate_state", request.duplicateState) + maybeAddField(&req, "partitions_count", request.partitionsCount) maybeAddField(&req, "partition", request.partition) diff --git a/pkg/serdes.go b/pkg/serdes.go index 745837a..d7e0f60 100644 --- a/pkg/serdes.go +++ b/pkg/serdes.go @@ -254,6 +254,7 @@ func fetchSample(config *SourceDBConfig, http_client *http.Client, sample_result return &Sample{ID: sample_result.Id, Source: sample_result.Source, Attributes: sample_result.Attributes, + DuplicateState: sample_result.DuplicateState, Image: *img_payload, Latents: latents, Masks: masks, diff --git a/tests/client_db_test.go b/tests/client_db_test.go index d15b831..8dd5379 100644 --- a/tests/client_db_test.go +++ b/tests/client_db_test.go @@ -454,3 +454,22 @@ func TestRandomSampling(t *testing.T) { t.Error("Random sampling is not working") } } + +func TestDuplicateStateFiltering(t *testing.T) { + clientConfig := get_default_test_config() + clientConfig.SamplesBufferSize = 1 + dbConfig := clientConfig.SourceConfig.(datago.SourceDBConfig) + dbConfig.DuplicateState = 1 + dbConfig.ReturnDuplicateState = true + clientConfig.SourceConfig = dbConfig + + client := datago.GetClient(clientConfig) + + for i := 0; i < 10; i++ { + sample := client.GetSample() + if sample.DuplicateState != 1 { + t.Errorf("Expected duplicate state to be 1") + } + } + client.Stop() +}