Skip to content

Commit

Permalink
Merge pull request #29 from Photoroom/ben/ml-2346-move-everything-to-…
Browse files Browse the repository at this point in the history
…the-public-datago

[refactor] Yet another small set of arch improvements + adding a lot of python tests
  • Loading branch information
blefaudeux authored Nov 13, 2024
2 parents 7c0162b + d1cd2d8 commit 6531f57
Show file tree
Hide file tree
Showing 19 changed files with 353 additions and 119 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/gopy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -64,5 +64,5 @@ jobs:

run: |
ls
python3 -m pip install -r requirements.txt
pytest -xv python/tests/*
python3 -m pip install -r requirements-tests.txt
pytest -xv python/*
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,4 @@ go.work.sum
build

__pycache__
*.pyc
16 changes: 9 additions & 7 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@ import (

func main() {
// Define flags
config := datago.DatagoConfig{}
config.SetDefaults()
config := datago.GetDatagoConfig()

sourceConfig := datago.GeneratorFileSystemConfig{RootPath: os.Getenv("DATAROOM_TEST_FILESYSTEM")}
sourceConfig := datago.SourceFileSystemConfig{RootPath: os.Getenv("DATAROOM_TEST_FILESYSTEM")}
sourceConfig.PageSize = 10
sourceConfig.Rank = 0
sourceConfig.WorldSize = 1

config.ImageConfig = datago.ImageTransformConfig{
DefaultImageSize: 1024,
DownsamplingRatio: 32,
Expand All @@ -26,8 +28,8 @@ func main() {
config.Concurrency = *flag.Int("concurrency", 64, "The number of concurrent http requests to make")
config.PrefetchBufferSize = *flag.Int("item_fetch_buffer", 256, "The number of items to pre-load")
config.SamplesBufferSize = *flag.Int("item_ready_buffer", 128, "The number of items ready to be served")
config.Limit = *flag.Int("limit", 2000, "The number of items to fetch")

limit := flag.Int("limit", 2000, "The number of items to fetch")
profile := flag.Bool("profile", false, "Whether to profile the code")

// Parse the flags and instantiate the client
Expand Down Expand Up @@ -65,10 +67,10 @@ func main() {

// Fetch all of the binary payloads as they become available
// NOTE: This is useless, just making sure that we empty the payloads channel
for i := 0; i < *limit; i++ {
for {
sample := dataroom_client.GetSample()
if sample.ID == "" {
fmt.Println("No more samples ", i, " samples served")
fmt.Println("No more samples")
break
}
}
Expand All @@ -78,7 +80,7 @@ func main() {

// Calculate the elapsed time
elapsedTime := time.Since(startTime)
fps := float64(*limit) / elapsedTime.Seconds()
fps := float64(config.Limit) / elapsedTime.Seconds()
fmt.Printf("Total execution time: %.2f \n", elapsedTime.Seconds())
fmt.Printf("Average throughput: %.2f samples per second\n", fps)
}
1 change: 1 addition & 0 deletions pkg/architecture.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ type Sample struct {
ID string
Source string
Attributes map[string]interface{}
DuplicateState int
Image ImagePayload
Masks map[string]ImagePayload
AdditionalImages map[string]ImagePayload
Expand Down
68 changes: 49 additions & 19 deletions pkg/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ type DataSourceConfig struct {
PageSize int `json:"page_size"`
Rank int `json:"rank"`
WorldSize int `json:"world_size"`
Limit int `json:"limit"`
}

type ImageTransformConfig struct {
Expand All @@ -41,7 +42,7 @@ type ImageTransformConfig struct {
PreEncodeImages bool `json:"pre_encode_images"`
}

func (c *ImageTransformConfig) SetDefaults() {
func (c *ImageTransformConfig) setDefaults() {
c.DefaultImageSize = 512
c.DownsamplingRatio = 16
c.MinAspectRatio = 0.5
Expand All @@ -57,17 +58,25 @@ type DatagoConfig struct {
PrefetchBufferSize int `json:"prefetch_buffer_size"`
SamplesBufferSize int `json:"samples_buffer_size"`
Concurrency int `json:"concurrency"`
Limit int `json:"limit"`
}

func (c *DatagoConfig) SetDefaults() {
dbConfig := GeneratorDBConfig{}
dbConfig.SetDefaults()
func (c *DatagoConfig) setDefaults() {
dbConfig := SourceDBConfig{}
dbConfig.setDefaults()
c.SourceConfig = dbConfig

c.ImageConfig.SetDefaults()
c.ImageConfig.setDefaults()
c.PrefetchBufferSize = 64
c.SamplesBufferSize = 32
c.Concurrency = 64
c.Limit = 0
}

// GetDatagoConfig returns a DatagoConfig populated with the library defaults.
// Callers can tweak individual fields on the returned value before passing it
// to GetClient.
func GetDatagoConfig() DatagoConfig {
	var cfg DatagoConfig
	cfg.setDefaults()
	return cfg
}

func DatagoConfigFromJSON(jsonString string) DatagoConfig {
Expand All @@ -80,21 +89,30 @@ func DatagoConfigFromJSON(jsonString string) DatagoConfig {

sourceConfig, err := json.Marshal(tempConfig["source_config"])
if err != nil {
fmt.Println("Error marshalling source_config", tempConfig["source_config"], err)
log.Panicf("Error marshalling source_config: %v", err)
}

// Unmarshal the source config based on the source type
// NOTE: The undefined fields will follow the default values
switch tempConfig["source_type"] {
case string(SourceTypeDB):
var dbConfig GeneratorDBConfig
dbConfig := SourceDBConfig{}
dbConfig.setDefaults()

err = json.Unmarshal(sourceConfig, &dbConfig)
if err != nil {
fmt.Println("Error unmarshalling DB config", sourceConfig, err)
log.Panicf("Error unmarshalling DB config: %v", err)
}
config.SourceConfig = dbConfig
case string(SourceTypeFileSystem):
var fsConfig GeneratorFileSystemConfig
fsConfig := SourceFileSystemConfig{}
fsConfig.setDefaults()

err = json.Unmarshal(sourceConfig, &fsConfig)
if err != nil {
fmt.Println("Error unmarshalling Filesystem config", sourceConfig, err)
log.Panicf("Error unmarshalling FileSystem config: %v", err)
}
config.SourceConfig = fsConfig
Expand Down Expand Up @@ -127,7 +145,9 @@ type DatagoClient struct {
waitGroup *sync.WaitGroup
cancel context.CancelFunc

ImageConfig ImageTransformConfig
imageConfig ImageTransformConfig
servedSamples int
limit int

// Flexible generator, backend and dispatch goroutines
generator Generator
Expand Down Expand Up @@ -157,14 +177,14 @@ func GetClient(config DatagoConfig) *DatagoClient {
fmt.Println(reflect.TypeOf(config.SourceConfig))

switch config.SourceConfig.(type) {
case GeneratorDBConfig:
case SourceDBConfig:
fmt.Println("Creating a DB-backed dataloader")
dbConfig := config.SourceConfig.(GeneratorDBConfig)
dbConfig := config.SourceConfig.(SourceDBConfig)
generator = newDatagoGeneratorDB(dbConfig)
backend = BackendHTTP{config: &dbConfig, concurrency: config.Concurrency}
case GeneratorFileSystemConfig:
case SourceFileSystemConfig:
fmt.Println("Creating a FileSystem-backed dataloader")
fsConfig := config.SourceConfig.(GeneratorFileSystemConfig)
fsConfig := config.SourceConfig.(SourceFileSystemConfig)
generator = newDatagoGeneratorFileSystem(fsConfig)
backend = BackendFileSystem{config: &config, concurrency: config.Concurrency}
default:
Expand All @@ -177,7 +197,9 @@ func GetClient(config DatagoConfig) *DatagoClient {
chanPages: make(chan Pages, 2),
chanSampleMetadata: make(chan SampleDataPointers, config.PrefetchBufferSize),
chanSamples: make(chan Sample, config.SamplesBufferSize),
ImageConfig: config.ImageConfig,
imageConfig: config.ImageConfig,
servedSamples: 0,
limit: config.Limit,
context: nil,
cancel: nil,
waitGroup: nil,
Expand Down Expand Up @@ -218,13 +240,13 @@ func (c *DatagoClient) Start() {
// Optionally crop and resize the images and masks on the fly
var arAwareTransform *ARAwareTransform = nil

if c.ImageConfig.CropAndResize {
if c.imageConfig.CropAndResize {
fmt.Println("Cropping and resizing images")
fmt.Println("Base image size | downsampling ratio | min | max:", c.ImageConfig.DefaultImageSize, c.ImageConfig.DownsamplingRatio, c.ImageConfig.MinAspectRatio, c.ImageConfig.MaxAspectRatio)
arAwareTransform = newARAwareTransform(c.ImageConfig)
fmt.Println("Base image size | downsampling ratio | min | max:", c.imageConfig.DefaultImageSize, c.imageConfig.DownsamplingRatio, c.imageConfig.MinAspectRatio, c.imageConfig.MaxAspectRatio)
arAwareTransform = newARAwareTransform(c.imageConfig)
}

if c.ImageConfig.PreEncodeImages {
if c.imageConfig.PreEncodeImages {
fmt.Println("Pre-encoding images, we'll return serialized JPG and PNG bytes")
}

Expand All @@ -247,23 +269,31 @@ func (c *DatagoClient) Start() {
wg.Add(1)
go func() {
defer wg.Done()
c.backend.collectSamples(c.chanSampleMetadata, c.chanSamples, arAwareTransform, c.ImageConfig.PreEncodeImages) // Fetch the payloads and deserialize them
c.backend.collectSamples(c.chanSampleMetadata, c.chanSamples, arAwareTransform, c.imageConfig.PreEncodeImages) // Fetch the payloads and deserialize them
}()

c.waitGroup = &wg
}

// Get a deserialized sample from the client
func (c *DatagoClient) GetSample() Sample {
if c.cancel == nil {
if c.cancel == nil && c.servedSamples == 0 {
fmt.Println("Dataroom client not started. Starting it on the first sample, this adds some initial latency")
fmt.Println("Please consider starting the client in anticipation by calling .Start()")
c.Start()
}

if c.limit > 0 && c.servedSamples == c.limit {
fmt.Println("Reached the limit of samples to serve, stopping the client")
c.Stop()
return Sample{}
}

if sample, ok := <-c.chanSamples; ok {
c.servedSamples++
return sample
}

fmt.Println("chanSamples closed, no more samples to serve")
return Sample{}
}
Expand Down
35 changes: 29 additions & 6 deletions pkg/generator_db.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ type urlLatent struct {
type dbSampleMetadata struct {
Id string `json:"id"`
Attributes map[string]interface{} `json:"attributes"`
DuplicateState int `json:"duplicate_state"`
ImageDirectURL string `json:"image_direct_url"`
Latents []urlLatent `json:"latents"`
Tags []string `json:"tags"`
Expand Down Expand Up @@ -73,7 +74,7 @@ type dbRequest struct {
}

// -- Define the front end goroutine ---------------------------------------------------------------------------------------------------------------------------------------------------------------
type GeneratorDBConfig struct {
type SourceDBConfig struct {
DataSourceConfig
Sources string `json:"sources"`
RequireImages bool `json:"require_images"`
Expand All @@ -86,7 +87,9 @@ type GeneratorDBConfig struct {
LacksMasks string `json:"lacks_masks"`
HasLatents string `json:"has_latents"`
LacksLatents string `json:"lacks_latents"`
ReturnLatents string `json:"return_latents"`

ReturnLatents string `json:"return_latents"`
ReturnDuplicateState bool `json:"return_duplicate_state"`

MinShortEdge int `json:"min_short_edge"`
MaxShortEdge int `json:"max_short_edge"`
Expand All @@ -95,7 +98,7 @@ type GeneratorDBConfig struct {
RandomSampling bool `json:"random_sampling"`
}

func (c *GeneratorDBConfig) SetDefaults() {
func (c *SourceDBConfig) setDefaults() {
c.PageSize = 512
c.Rank = -1
c.WorldSize = -1
Expand All @@ -112,15 +115,17 @@ func (c *GeneratorDBConfig) SetDefaults() {
c.HasLatents = ""
c.LacksLatents = ""
c.ReturnLatents = ""
c.ReturnDuplicateState = false

c.MinShortEdge = -1
c.MaxShortEdge = -1
c.MinPixelCount = -1
c.MaxPixelCount = -1
c.RandomSampling = false

}

func (c *GeneratorDBConfig) getDbRequest() dbRequest {
func (c *SourceDBConfig) getDbRequest() dbRequest {

fields := "attributes,image_direct_url"
if len(c.HasLatents) > 0 || len(c.HasMasks) > 0 {
Expand All @@ -142,6 +147,11 @@ func (c *GeneratorDBConfig) getDbRequest() dbRequest {
fmt.Println("Including embeddings")
}

if c.ReturnDuplicateState {
fields += ",duplicate_state"
fmt.Println("Including duplicate state")
}

// Report some config data
fmt.Println("Rank | World size:", c.Rank, c.WorldSize)
fmt.Println("Sources:", c.Sources, "| Fields:", fields)
Expand All @@ -153,6 +163,13 @@ func (c *GeneratorDBConfig) getDbRequest() dbRequest {
return fmt.Sprintf("%d", val)
}

// Align rank and worldsize with the partitioning
if c.WorldSize < 2 {
// No partitioning
c.WorldSize = -1
c.Rank = -1
}

return dbRequest{
fields: fields,
sources: c.Sources,
Expand All @@ -176,12 +193,18 @@ func (c *GeneratorDBConfig) getDbRequest() dbRequest {
}
}

// GetSourceDBConfig returns a SourceDBConfig populated with the library
// defaults, ready to be customized and assigned to DatagoConfig.SourceConfig.
func GetSourceDBConfig() SourceDBConfig {
	var cfg SourceDBConfig
	cfg.setDefaults()
	return cfg
}

type datagoGeneratorDB struct {
baseRequest http.Request
config GeneratorDBConfig
config SourceDBConfig
}

func newDatagoGeneratorDB(config GeneratorDBConfig) datagoGeneratorDB {
func newDatagoGeneratorDB(config SourceDBConfig) datagoGeneratorDB {
request := config.getDbRequest()

api_key := os.Getenv("DATAROOM_API_KEY")
Expand Down
14 changes: 10 additions & 4 deletions pkg/generator_filesystem.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,25 +18,31 @@ type fsSampleMetadata struct {
}

// -- Define the front end goroutine ---------------------------------------------------------------------------------------------------------------------------------------------------------------
type GeneratorFileSystemConfig struct {
type SourceFileSystemConfig struct {
DataSourceConfig
RootPath string `json:"root_path"`
}

func (c *GeneratorFileSystemConfig) SetDefaults() {
func (c *SourceFileSystemConfig) setDefaults() {
c.PageSize = 512
c.Rank = 0
c.WorldSize = 1

c.RootPath = os.Getenv("DATAROOM_TEST_FILESYSTEM")
}

// GetSourceFileSystemConfig returns a SourceFileSystemConfig populated with
// the library defaults (root path taken from the DATAROOM_TEST_FILESYSTEM
// environment variable via setDefaults).
func GetSourceFileSystemConfig() SourceFileSystemConfig {
	var cfg SourceFileSystemConfig
	cfg.setDefaults()
	return cfg
}

type datagoGeneratorFileSystem struct {
extensions set
config GeneratorFileSystemConfig
config SourceFileSystemConfig
}

func newDatagoGeneratorFileSystem(config GeneratorFileSystemConfig) datagoGeneratorFileSystem {
func newDatagoGeneratorFileSystem(config SourceFileSystemConfig) datagoGeneratorFileSystem {
supported_img_extensions := []string{".jpg", ".jpeg", ".png", ".JPEG", ".JPG", ".PNG"}
var extensionsMap = make(set)
for _, ext := range supported_img_extensions {
Expand Down
2 changes: 1 addition & 1 deletion pkg/serdes.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ func fetchImage(client *http.Client, url string, retries int, transform *ARAware
return nil, -1., err_report
}

func fetchSample(config *GeneratorDBConfig, http_client *http.Client, sample_result dbSampleMetadata, transform *ARAwareTransform, pre_encode_image bool) *Sample {
func fetchSample(config *SourceDBConfig, http_client *http.Client, sample_result dbSampleMetadata, transform *ARAwareTransform, pre_encode_image bool) *Sample {
// Per sample work:
// - fetch the raw payloads
// - deserialize / decode, depending on the types
Expand Down
2 changes: 1 addition & 1 deletion pkg/worker_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
)

type BackendHTTP struct {
config *GeneratorDBConfig
config *SourceDBConfig
concurrency int
}

Expand Down
Loading

0 comments on commit 6531f57

Please sign in to comment.