Skip to content

Commit

Permalink
WIP, would need more work but short on time
Browse files Browse the repository at this point in the history
[x] initial refactor
[x] adding a barebones filesystem dataloader
[x] barebones unit test -> broken
[ ] benchmark on IN1k
  • Loading branch information
blefaudeux committed Oct 20, 2024
1 parent dd80544 commit 538aba8
Show file tree
Hide file tree
Showing 11 changed files with 316 additions and 109 deletions.
21 changes: 20 additions & 1 deletion python_tests/datago_test.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from datago import datago
import pytest
import os
from PIL import Image


def get_test_source():
    """Return the value of the DATAROOM_TEST_SOURCE environment variable, or None if it is unset."""
    source = os.environ.get("DATAROOM_TEST_SOURCE")
    return source


def test_get_sample():
def test_get_sample_db():
# Check that we can instantiate a client and get a sample, nothing more
config = datago.GetDefaultConfig()
config.source = get_test_source()
Expand All @@ -17,6 +18,24 @@ def test_get_sample():
assert data.ID != ""


def test_get_sample_filesystem():
    """Check that a filesystem-backed client can be created and serves one sample.

    A single blank PNG is written to a fresh temporary directory, which is then
    used as the only source. This keeps the test independent from whatever else
    lives in the current working directory and leaves no file behind.
    """
    import tempfile

    with tempfile.TemporaryDirectory() as tmp_dir:
        # Dump a sample image to the filesystem
        img = Image.new("RGB", (100, 100))
        img.save(os.path.join(tmp_dir, "test.png"))

        # Check that we can instantiate a client and get a sample, nothing more
        config = datago.GetDefaultConfig()
        config.SourceType = "filesystem"
        config.Sources = tmp_dir
        # NOTE(review): the other config fields set here are capitalized; confirm
        # that `sample` is the intended attribute name.
        config.sample = 1

        client = datago.GetClient(config)
        data = client.GetSample()
        assert data.ID != ""


# TODO: Backport all the image correctness tests

if __name__ == "__main__":
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
pytest
pytest
pillow
3 changes: 2 additions & 1 deletion src/cmd/main/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ import (
func main() {
// Define flags
client_config := datago.GetDefaultConfig()
client_config.SourceType = datago.SourceTypeFileSystem
client_config.Sources = os.Getenv("DATAROOM_TEST_FILESYSTEM")
client_config.DefaultImageSize = 1024
client_config.DownsamplingRatio = 32

Expand All @@ -21,7 +23,6 @@ func main() {
client_config.PrefetchBufferSize = *flag.Int("item_fetch_buffer", 256, "The number of items to pre-load")
client_config.SamplesBufferSize = *flag.Int("item_ready_buffer", 128, "The number of items ready to be served")

client_config.Sources = *flag.String("source", "GETTY", "The source for the items")
client_config.RequireImages = *flag.Bool("require_images", true, "Whether the items require images")
client_config.RequireEmbeddings = *flag.Bool("require_embeddings", false, "Whether the items require the DB embeddings")

Expand Down
52 changes: 52 additions & 0 deletions src/pkg/client/architecture.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package datago

import "context"

// --- Sample data structures - these will be exposed to the Python world ---------------------------------------------------------------------------------------------------------------------------------------------------------------
// LatentPayload holds the bytes of one latent attached to a Sample.
// It is one of the sample data structures exposed to the Python side.
type LatentPayload struct {
// Bytes of the latent.
Data []byte
// NOTE(review): presumably the number of bytes in Data - confirm where it is filled in.
Len int
// NOTE(review): presumably the address of the Data buffer, for direct access from Python - confirm.
DataPtr uintptr
}

// ImagePayload holds the bytes of one image together with the dimensions needed to interpret them.
// It is one of the sample data structures exposed to the Python side.
type ImagePayload struct {
// Bytes of the image.
// NOTE(review): may be raw pixels or an encoded image depending on the PreEncodeImages setting - confirm.
Data []byte
OriginalHeight int // Good indicator of the image frequency response at the current resolution
OriginalWidth int
Height int // Useful to decode the current payload
Width int
Channels int
// NOTE(review): presumably the address of the Data buffer, for direct access from Python - confirm.
DataPtr uintptr
}

// Sample is one item served to the consumer, and is exposed to the Python side.
// Which fields are populated depends on the backend: the filesystem backend
// only sets ID (the file name) and Image, leaving everything else at its zero value.
type Sample struct {
ID string
Source string
Attributes map[string]interface{}
// Main image of the sample.
Image ImagePayload
// Extra image payloads, keyed by name.
Masks map[string]ImagePayload
AdditionalImages map[string]ImagePayload
// Latent payloads, keyed by name.
Latents map[string]LatentPayload
CocaEmbedding []float32
Tags []string
}

// --- Generator and Backend interfaces ---------------------------------------------------------------------------------------------------------------------------------------------------------------

// The generator is responsible for producing pages of metadata, which are handed over
// to the dispatch goroutine. That metadata is then used to fetch the actual payloads.

// SampleDataPointers is an opaque, backend-specific description of one sample to fetch.
// Each backend type-asserts it back to its own concrete metadata type before using it.
type SampleDataPointers interface{}

// Pages is a batch of sample descriptors, the unit a Generator sends over its pages channel.
type Pages struct {
// One backend-specific descriptor per sample in the batch.
samplesDataPointers []SampleDataPointers
}

// Generator produces the metadata describing which samples to fetch.
type Generator interface {
// generatePages sends batches of sample descriptors on chanPages.
// NOTE(review): presumably stops when ctx is cancelled - confirm against the implementations.
generatePages(ctx context.Context, chanPages chan Pages)
}

// Backend is responsible for fetching the payloads and deserializing them.
type Backend interface {
// collectSamples receives sample descriptors on chanSampleMetadata, turns each one
// into a Sample (applying transform to the images) and sends the result on chanSamples.
collectSamples(chanSampleMetadata chan SampleDataPointers, chanSamples chan Sample, transform *ARAwareTransform)
}
67 changes: 67 additions & 0 deletions src/pkg/client/backend_filesystem.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package datago

import (
"fmt"
"os"
)

// BackendFileSystem is the backend which builds samples from files read on the local filesystem.
type BackendFileSystem struct {
// Client configuration; read for ConcurrentDownloads and PreEncodeImages.
config *DatagoConfig
}

// loadSample reads one file from disk and turns it into a Sample.
//
// The file at filesystem_sample.filePath is read in full and handed to
// imageFromBuffer together with transform and config.PreEncodeImages. On
// success the returned Sample has ID set to the file name and Image set to the
// resulting payload; all the other fields keep their zero value.
//
// On failure the problem is printed, including the underlying error, and nil
// is returned so that the caller can skip this file and carry on.
func loadSample(config *DatagoConfig, filesystem_sample fsSampleMetadata, transform *ARAwareTransform) *Sample {
	// Load the whole file in memory
	bytes_buffer, err := os.ReadFile(filesystem_sample.filePath)
	if err != nil {
		fmt.Println("Error reading file:", filesystem_sample.filePath, err)
		return nil
	}

	// The second value returned by imageFromBuffer is not needed here.
	img_payload, _, err := imageFromBuffer(bytes_buffer, transform, -1., config.PreEncodeImages, false)
	if err != nil {
		fmt.Println("Error loading image:", filesystem_sample.fileName, err)
		return nil
	}

	return &Sample{
		ID:    filesystem_sample.fileName,
		Image: *img_payload,
	}
}

// collectSamples loads the files described on chanSampleMetadata and sends the
// resulting samples on chanSamples.
//
// It starts b.config.ConcurrentDownloads worker goroutines which all read from
// chanSampleMetadata. Every item must be a fsSampleMetadata; anything else
// means that the generator and the backend are mismatched, and panics. Files
// which cannot be loaded are skipped.
//
// The call blocks until chanSampleMetadata has been closed and every worker
// has returned, and then closes chanSamples.
func (b BackendFileSystem) collectSamples(chanSampleMetadata chan SampleDataPointers, chanSamples chan Sample, transform *ARAwareTransform) {
	// Each worker reports exactly once on this channel, right before returning
	ack_channel := make(chan bool)

	sampleWorker := func() {
		// The loop ends once the metadata channel is closed and drained
		for item_to_fetch := range chanSampleMetadata {
			// Cast the item to fetch to the correct type
			filesystem_sample, ok := item_to_fetch.(fsSampleMetadata)
			if !ok {
				panic("Failed to cast the item to fetch to fsSampleMetadata. This worker is probably misconfigured")
			}

			if sample := loadSample(b.config, filesystem_sample, transform); sample != nil {
				chanSamples <- *sample
			}
		}
		ack_channel <- true
	}

	// Start the workers and work on the metadata channel
	for i := 0; i < b.config.ConcurrentDownloads; i++ {
		go sampleWorker()
	}

	// Wait for all the workers to be done
	for i := 0; i < b.config.ConcurrentDownloads; i++ {
		<-ack_channel
	}

	// No worker is left to send, so it is safe to close the output channel
	close(chanSamples)
	fmt.Println("No more items to serve, wrapping up")
}
10 changes: 8 additions & 2 deletions src/pkg/client/backend_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ type BackendHTTP struct {
config *DatagoConfig
}

func (b BackendHTTP) collectSamples(chanSampleMetadata chan dbSampleMetadata, chanSamples chan Sample, transform *ARAwareTransform) {
func (b BackendHTTP) collectSamples(chanSampleMetadata chan SampleDataPointers, chanSamples chan Sample, transform *ARAwareTransform) {

ack_channel := make(chan bool)

Expand All @@ -25,7 +25,13 @@ func (b BackendHTTP) collectSamples(chanSampleMetadata chan dbSampleMetadata, ch
return
}

sample := fetchSample(b.config, &http_client, item_to_fetch, transform)
// Cast the item to fetch to the correct type
http_sample, ok := item_to_fetch.(dbSampleMetadata)
if !ok {
panic("Failed to cast the item to fetch to dbSampleMetadata. This worker is probably misconfigured")
}

sample := fetchSample(b.config, &http_client, http_sample, transform)
if sample != nil {
chanSamples <- *sample
}
Expand Down
Loading

0 comments on commit 538aba8

Please sign in to comment.