Skip to content

Commit

Permalink
Merge pull request #23 from Photoroom/ben/rank_filesystem
Browse files Browse the repository at this point in the history
[feat] Rank support for filesystem + more DB API support
  • Loading branch information
blefaudeux authored Nov 6, 2024
2 parents 1d7a98d + a026a0a commit 290e106
Show file tree
Hide file tree
Showing 5 changed files with 145 additions and 23 deletions.
4 changes: 3 additions & 1 deletion pkg/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ const (

// Nested configuration structures for the client
// DataSourceConfig holds the options shared by every data source feeding
// the client (DB or filesystem): paging plus distributed-sharding layout.
type DataSourceConfig struct {
	// PageSize is the number of samples requested per page.
	PageSize int `json:"page_size"`
	// Rank is this worker's index within the distributed job,
	// expected in [0, WorldSize) (used as `hash % WorldSize == Rank`).
	Rank int `json:"rank"`
	// WorldSize is the total number of workers sharding the dataset.
	WorldSize int `json:"world_size"`
}

type ImageTransformConfig struct {
Expand Down
68 changes: 48 additions & 20 deletions pkg/generator_db.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,20 @@ type dbRequest struct {
hasMasks string
lacksMasks string

hasLatents string
lacksLatents string
hasLatents string
lacksLatents string
returnLatents string

minShortEdge string
maxShortEdge string

minPixelCount string
maxPixelCount string

randomSampling bool

partitionsCount string
partition string
}

// -- Define the front end goroutine ---------------------------------------------------------------------------------------------------------------------------------------------------------------
Expand All @@ -74,11 +86,20 @@ type GeneratorDBConfig struct {
LacksMasks string `json:"lacks_masks"`
HasLatents string `json:"has_latents"`
LacksLatents string `json:"lacks_latents"`
Rank uint32 `json:"rank"`
WorldSize uint32 `json:"world_size"`
ReturnLatents string `json:"return_latents"`

MinShortEdge int `json:"min_short_edge"`
MaxShortEdge int `json:"max_short_edge"`
MinPixelCount int `json:"min_pixel_count"`
MaxPixelCount int `json:"max_pixel_count"`
RandomSampling bool `json:"random_sampling"`
}

func (c *GeneratorDBConfig) SetDefaults() {
c.PageSize = 512
c.Rank = -1
c.WorldSize = -1

c.Sources = ""
c.RequireImages = true
c.RequireEmbeddings = false
Expand All @@ -90,9 +111,13 @@ func (c *GeneratorDBConfig) SetDefaults() {
c.LacksMasks = ""
c.HasLatents = ""
c.LacksLatents = ""
c.Rank = 0
c.WorldSize = 1
c.PageSize = 512
c.ReturnLatents = ""

c.MinShortEdge = -1
c.MaxShortEdge = -1
c.MinPixelCount = -1
c.MaxPixelCount = -1
c.RandomSampling = false
}

func (c *GeneratorDBConfig) getDbRequest() dbRequest {
Expand Down Expand Up @@ -121,6 +146,13 @@ func (c *GeneratorDBConfig) getDbRequest() dbRequest {
fmt.Println("Rank | World size:", c.Rank, c.WorldSize)
fmt.Println("Sources:", c.Sources, "| Fields:", fields)

sanitizeInt := func(val int) string {
if val == -1 {
return ""
}
return fmt.Sprintf("%d", val)
}

return dbRequest{
fields: fields,
sources: c.Sources,
Expand All @@ -133,6 +165,14 @@ func (c *GeneratorDBConfig) getDbRequest() dbRequest {
lacksMasks: c.LacksMasks,
hasLatents: c.HasLatents,
lacksLatents: c.LacksLatents,
returnLatents: c.HasLatents, // Could be exposed as it's done internally
minShortEdge: sanitizeInt(c.MinShortEdge),
maxShortEdge: sanitizeInt(c.MaxShortEdge),
minPixelCount: sanitizeInt(c.MinPixelCount),
maxPixelCount: sanitizeInt(c.MaxPixelCount),
randomSampling: c.RandomSampling,
partitionsCount: sanitizeInt(c.WorldSize),
partition: sanitizeInt(c.Rank),
}
}

Expand All @@ -157,19 +197,7 @@ func newDatagoGeneratorDB(config GeneratorDBConfig) datagoGeneratorDB {
fmt.Println("Dataroom API URL:", api_url)
fmt.Println("Dataroom API KEY last characters:", getLast5Chars(api_key))

generatorDBConfig := GeneratorDBConfig{
RequireImages: config.RequireImages,
RequireEmbeddings: config.RequireEmbeddings,
HasMasks: config.HasMasks,
LacksMasks: config.LacksMasks,
HasLatents: config.HasLatents,
LacksLatents: config.LacksLatents,
Sources: config.Sources,
Rank: config.Rank,
WorldSize: config.WorldSize,
}

return datagoGeneratorDB{baseRequest: *getHTTPRequest(api_url, api_key, request), config: generatorDBConfig}
return datagoGeneratorDB{baseRequest: *getHTTPRequest(api_url, api_key, request), config: config}
}

func (f datagoGeneratorDB) generatePages(ctx context.Context, chanPages chan Pages) {
Expand Down
17 changes: 15 additions & 2 deletions pkg/generator_filesystem.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package datago

import (
"context"
"crypto/sha256"
"fmt"
"os"
"path/filepath"
Expand All @@ -24,6 +25,9 @@ type GeneratorFileSystemConfig struct {

// SetDefaults restores the filesystem generator configuration to its
// baseline: a single worker (rank 0 of a world of one), 512 samples per
// page, and the root directory taken from the DATAROOM_TEST_FILESYSTEM
// environment variable.
func (c *GeneratorFileSystemConfig) SetDefaults() {
	// Single-process layout unless the caller overrides rank/world size.
	c.Rank = 0
	c.WorldSize = 1

	c.PageSize = 512
	c.RootPath = os.Getenv("DATAROOM_TEST_FILESYSTEM")
}

Expand All @@ -44,6 +48,12 @@ func newDatagoGeneratorFileSystem(config GeneratorFileSystemConfig) datagoGenera
return datagoGeneratorFileSystem{config: config, extensions: extensionsMap}
}

// hash maps a file path to a non-negative integer used to assign the file
// to a rank (callers compare hash(path) % WorldSize against Rank). It folds
// the first four bytes of the path's SHA-256 digest, so the value spans far
// more than the 256 buckets a single byte would give, keeping the shard
// assignment well distributed even for large world sizes.
func hash(s string) int {
	digest := sha256.Sum256([]byte(s))
	// Big-endian fold of the first four digest bytes; mask with 0x7fffffff
	// so the result is always non-negative (a negative value would never
	// match any rank under Go's modulo) and identical on 32/64-bit ints.
	v := int(digest[0])<<24 | int(digest[1])<<16 | int(digest[2])<<8 | int(digest[3])
	return v & 0x7fffffff
}

func (f datagoGeneratorFileSystem) generatePages(ctx context.Context, chanPages chan Pages) {
// Walk over the directory and feed the results to the items channel
// This is meant to be run in a goroutine
Expand All @@ -54,9 +64,12 @@ func (f datagoGeneratorFileSystem) generatePages(ctx context.Context, chanPages
if err != nil {
return err
}

if !info.IsDir() && f.extensions.Contains(filepath.Ext(path)) {
new_sample := fsSampleMetadata{FilePath: path, FileName: info.Name()}
samples = append(samples, SampleDataPointers(new_sample))
if f.config.WorldSize > 1 && hash(path)%f.config.WorldSize != f.config.Rank || f.config.WorldSize == 1 {
new_sample := fsSampleMetadata{FilePath: path, FileName: info.Name()}
samples = append(samples, SampleDataPointers(new_sample))
}
}

// Check if we have enough files to send a page
Expand Down
9 changes: 9 additions & 0 deletions pkg/serdes.go
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,15 @@ func getHTTPRequest(api_url string, api_key string, request dbRequest) *http.Req
maybeAddField(&req, "has_latents", request.hasLatents)
maybeAddField(&req, "lacks_latents", request.lacksLatents)
maybeAddField(&req, "return_latents", return_latents)

maybeAddField(&req, "short_edge__gte", request.minShortEdge)
maybeAddField(&req, "short_edge__lte", request.maxShortEdge)
maybeAddField(&req, "pixel_count__gte", request.minPixelCount)
maybeAddField(&req, "pixel_count__lte", request.maxPixelCount)

maybeAddField(&req, "partitions_count", request.partitionsCount)
maybeAddField(&req, "partition", request.partition)

request_url.URL.RawQuery = req.Encode()
fmt.Println("Request URL:", request_url.URL.String())
fmt.Println()
Expand Down
70 changes: 70 additions & 0 deletions tests/client_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package datago_test

import (
"fmt"
"os"
"testing"

Expand Down Expand Up @@ -220,3 +221,72 @@ func TestImageBufferCompression(t *testing.T) {
t.Errorf("Error decoding mask buffer")
}
}

// TestStrings pulls a handful of samples and checks that each one carries a
// non-empty ID.
func TestStrings(t *testing.T) {
	clientConfig := get_default_test_config()
	client := datago.GetClient(clientConfig)
	client.Start()
	// Deferred so the client is torn down even if an assertion aborts early.
	defer client.Stop()

	for i := 0; i < 10; i++ {
		sample := client.GetSample()

		// A blank ID indicates the fetch failed. (The original performed
		// this exact check twice; once is enough, and sample.ID needs no
		// string() conversion to be compared or printed.)
		if sample.ID == "" {
			t.Errorf("GetSample returned an unexpected error")
		}

		fmt.Println(sample.ID)
	}
}

// TestRanks starts two clients configured as rank 0 and rank 1 of a
// two-worker world and verifies that their sample streams are disjoint.
func TestRanks(t *testing.T) {
	clientConfig := get_default_test_config()
	clientConfig.SamplesBufferSize = 1

	dbConfig := clientConfig.SourceConfig.(datago.GeneratorDBConfig)
	dbConfig.WorldSize = 2
	dbConfig.Rank = 0
	clientConfig.SourceConfig = dbConfig

	client0 := datago.GetClient(clientConfig)
	client0.Start()
	defer client0.Stop()

	// Same configuration, second shard.
	dbConfig.Rank = 1
	clientConfig.SourceConfig = dbConfig
	client1 := datago.GetClient(clientConfig)
	client1.Start()
	defer client1.Stop()

	// Sets of observed IDs per rank (map[string]struct{} is the idiomatic
	// zero-width set).
	seen0 := make(map[string]struct{})
	seen1 := make(map[string]struct{})

	for i := 0; i < 10; i++ {
		sample0 := client0.GetSample()
		sample1 := client1.GetSample()

		if sample0.ID == "" || sample1.ID == "" {
			t.Errorf("GetSample returned an unexpected error")
		}

		seen0[sample0.ID] = struct{}{}
		seen1[sample1.ID] = struct{}{}
	}

	// The two rank streams must have no ID in common.
	for id := range seen0 {
		if _, shared := seen1[id]; shared {
			t.Errorf("Samples are not distributed across ranks")
		}
	}
}

// FIXME: Could do with a lot of tests on the filesystem side

0 comments on commit 290e106

Please sign in to comment.