Skip to content

Commit

Permalink
Merge pull request #51 from Photoroom/tarek/fix-duplicate-state-add-f…
Browse files Browse the repository at this point in the history
…iltering

fix returnDuplicateState option and add duplicate_state filtering
  • Loading branch information
blefaudeux authored Nov 27, 2024
2 parents 92678e5 + e0deb90 commit 88c7780
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 5 deletions.
23 changes: 18 additions & 5 deletions pkg/generator_db.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ type dbRequest struct {
minPixelCount string
maxPixelCount string

duplicateState string

randomSampling bool

partitionsCount string
Expand All @@ -95,10 +97,13 @@ type SourceDBConfig struct {
ReturnLatents string `json:"return_latents"`
ReturnDuplicateState bool `json:"return_duplicate_state"`

MinShortEdge int `json:"min_short_edge"`
MaxShortEdge int `json:"max_short_edge"`
MinPixelCount int `json:"min_pixel_count"`
MaxPixelCount int `json:"max_pixel_count"`
MinShortEdge int `json:"min_short_edge"`
MaxShortEdge int `json:"max_short_edge"`
MinPixelCount int `json:"min_pixel_count"`
MaxPixelCount int `json:"max_pixel_count"`

DuplicateState int `json:"duplicate_state"`

RandomSampling bool `json:"random_sampling"`
}

Expand Down Expand Up @@ -127,7 +132,7 @@ func (c *SourceDBConfig) setDefaults() {
c.MinPixelCount = -1
c.MaxPixelCount = -1
c.RandomSampling = false

c.DuplicateState = -1
}

func (c *SourceDBConfig) getDbRequest() dbRequest {
Expand Down Expand Up @@ -175,6 +180,11 @@ func (c *SourceDBConfig) getDbRequest() dbRequest {
c.Rank = -1
}

duplicateState := sanitizeInt(c.DuplicateState)
if duplicateState == "0" {
duplicateState = "None"
}

return dbRequest{
fields: fields,
sources: c.Sources,
Expand All @@ -194,6 +204,7 @@ func (c *SourceDBConfig) getDbRequest() dbRequest {
minPixelCount: sanitizeInt(c.MinPixelCount),
maxPixelCount: sanitizeInt(c.MaxPixelCount),
randomSampling: c.RandomSampling,
duplicateState: duplicateState,
partitionsCount: sanitizeInt(c.WorldSize),
partition: sanitizeInt(c.Rank),
}
Expand Down Expand Up @@ -367,6 +378,8 @@ func getHTTPRequest(api_url string, api_key string, request dbRequest) *http.Req
maybeAddField(&req, "pixel_count__gte", request.minPixelCount)
maybeAddField(&req, "pixel_count__lte", request.maxPixelCount)

maybeAddField(&req, "duplicate_state", request.duplicateState)

maybeAddField(&req, "partitions_count", request.partitionsCount)
maybeAddField(&req, "partition", request.partition)

Expand Down
1 change: 1 addition & 0 deletions pkg/serdes.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,7 @@ func fetchSample(config *SourceDBConfig, http_client *http.Client, sample_result
return &Sample{ID: sample_result.Id,
Source: sample_result.Source,
Attributes: sample_result.Attributes,
DuplicateState: sample_result.DuplicateState,
Image: *img_payload,
Latents: latents,
Masks: masks,
Expand Down
19 changes: 19 additions & 0 deletions tests/client_db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -454,3 +454,22 @@ func TestRandomSampling(t *testing.T) {
t.Error("Random sampling is not working")
}
}

// TestDuplicateStateFiltering requests samples with the DuplicateState filter
// set to 1 and checks that each of the first 10 samples returned by the client
// reports a duplicate state of 1.
func TestDuplicateStateFiltering(t *testing.T) {
	clientConfig := get_default_test_config()
	clientConfig.SamplesBufferSize = 1

	// Fail with a readable message instead of panicking if the default test
	// config ever stops carrying a SourceDBConfig.
	dbConfig, ok := clientConfig.SourceConfig.(datago.SourceDBConfig)
	if !ok {
		t.Fatalf("unexpected source config type %T, want datago.SourceDBConfig", clientConfig.SourceConfig)
	}
	dbConfig.DuplicateState = 1
	// NOTE(review): presumably required so that DuplicateState is populated on
	// the returned samples; confirm against the DB backend.
	dbConfig.ReturnDuplicateState = true
	clientConfig.SourceConfig = dbConfig

	client := datago.GetClient(clientConfig)
	// Deferred so the client is stopped even if fetching a sample panics.
	defer client.Stop()

	for i := 0; i < 10; i++ {
		sample := client.GetSample()
		if sample.DuplicateState != 1 {
			t.Errorf("sample %d: duplicate state = %v, want 1", i, sample.DuplicateState)
		}
	}
}

0 comments on commit 88c7780

Please sign in to comment.