diff --git a/.github/workflows/ci-cd.yml b/.github/workflows/ci-cd.yml
deleted file mode 100644
index 76c2812..0000000
--- a/.github/workflows/ci-cd.yml
+++ /dev/null
@@ -1,71 +0,0 @@
-name: Push to PyPI
-
-on:
-  release:
-    types: [published]
-
-jobs:
-  pypi-publish:
-    name: Upload release to PyPI
-
-    runs-on: ${{ matrix.os }}
-    strategy:
-      matrix:
-        os: [ubuntu-latest]
-        rust: [stable]
-
-    environment:
-      name: release
-      url: https://pypi.org/p/datago
-    permissions:
-      id-token: write
-    steps:
-      - uses: actions/checkout@v4
-
-      - name: Install Rust
-        uses: actions-rs/toolchain@v1
-        with:
-          profile: minimal
-          toolchain: ${{ matrix.rust }}
-          override: true
-          components: rustfmt, clippy # , cargo-llvm-cov
-
-      - name: Cache dependencies
-        uses: actions/cache@v3
-        with:
-          path: |
-            ~/.cargo/bin/
-            ~/.cargo/registry/index/
-            ~/.cargo/registry/cache/
-            ~/.cargo/git/db/
-            target/
-          key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}
-
-      - name: Check formatting
-        run: cargo fmt --all -- --check
-
-      - name: Build
-        run: cargo build --verbose
-
-      - name: Run tests
-        env:
-          DATAROOM_API_KEY: ${{ secrets.DATAROOM_API_KEY }}
-          DATAROOM_TEST_SOURCE: ${{ secrets.DATAROOM_TEST_SOURCE }}
-          DATAROOM_API_URL: ${{ secrets.DATAROOM_API_URL }}
-
-        run: cargo test --verbose
-
-      - name: Install maturin
-        run: |
-          python3 -m pip install maturin twine
-
-      - name: Build and upload the package
-        run: |
-          maturin build -i python3.11 --release --target "x86_64-unknown-linux-gnu"
-          cd target/wheels
-          python3 -m pip install --user -v *.whl
-          ls
-          twine upload *.whl --repository-url https://pypi.org/p/datago
-
-      - name: Publish package distributions to PyPI
-        uses: pypa/gh-action-pypi-publish@release/v1
diff --git a/.github/workflows/rust-py.yml b/.github/workflows/rust-py.yml
index 45a6934..5d91c40 100644
--- a/.github/workflows/rust-py.yml
+++ b/.github/workflows/rust-py.yml
@@ -1,80 +1,14 @@
-# name: Rust-py
-
-# on:
-#   push:
-#     branches: ["main"]
-#   pull_request:
-#     branches: ["main"]
-
-# jobs:
-#   build:
-#     runs-on: ${{ matrix.os }}
-#     strategy:
-#       matrix:
-#         os: [ubuntu-latest]
-#         rust: [stable]
-
-#     steps:
-#       - uses: actions/checkout@v3
-
-#       - name: Install Rust
-#         uses: actions-rs/toolchain@v1
-#         with:
-#           profile: minimal
-#           toolchain: ${{ matrix.rust }}
-#           override: true
-#           components: rustfmt, clippy # , cargo-llvm-cov
-
-#       - name: Cache dependencies
-#         uses: actions/cache@v3
-#         with:
-#           path: |
-#             ~/.cargo/bin/
-#             ~/.cargo/registry/index/
-#             ~/.cargo/registry/cache/
-#             ~/.cargo/git/db/
-#             target/
-#           key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}
-
-#       - name: Set up Python
-#         uses: actions/setup-python@v5
-#         with:
-#           python-version: "3.11.10"
-
-#       - name: Install maturin
-#         run: |
-#           python3 -m pip install maturin
-
-#       - name: Build and install python module
-#         run: |
-#           cd datago
-#           maturin build -i python3.11 --release --target "x86_64-unknown-linux-gnu"
-#           cd target/wheels
-#           python3 -m pip install --user -v *.whl
-
-#       - name: Run the python unit tests
-#         env:
-#           DATAROOM_API_KEY: ${{ secrets.DATAROOM_API_KEY }}
-#           DATAROOM_TEST_SOURCE: ${{ secrets.DATAROOM_TEST_SOURCE }}
-#           DATAROOM_API_URL: ${{ secrets.DATAROOM_API_URL }}
-
-#         run: |
-#           ls
-#           # python3 -m pip install -r requirements-tests.txt
-#           # pytest -xv python/*
-
-
-name: CI
+name: Rust-py
 
 on:
   push:
     branches:
       - main
     tags:
-      - '*'
+      - "*"
   pull_request:
     branches:
-      - main
+      - "*"
   workflow_dispatch:
 
 permissions:
@@ -91,29 +25,38 @@ jobs:
     steps:
       - uses: actions/checkout@v4
+      - run: git fetch --prune --unshallow
+
       - uses: actions/setup-python@v5
         with:
-          python-version: 3.x
-
+          python-version: 3.11
       - name: Install maturin
        run: |
           python3 -m pip install maturin twine
-      - name: Build and upload the package
+      # Gather the name of the latest tag on the current main branch
+      - name: Get the latest tag
+        id: get_tag
+        run: echo "tag=$(git describe --tags --abbrev=0)" >> $GITHUB_OUTPUT
+
+      - name: Build the package
         run: |
           maturin build -i python3.11 --release --out dist --target "x86_64-unknown-linux-gnu"
+          mv dist/datago-0.0.0-cp311-cp311-linux_x86_64.whl dist/datago-${{ steps.get_tag.outputs.tag }}-cp311-cp311-linux_x86_64.whl
+
+      - name: Test package
+        env:
+          DATAROOM_API_KEY: ${{ secrets.DATAROOM_API_KEY }}
+          DATAROOM_TEST_SOURCE: ${{ secrets.DATAROOM_TEST_SOURCE }}
+          DATAROOM_API_URL: ${{ secrets.DATAROOM_API_URL }}
 
-      # - name: Build wheels
-      #   uses: PyO3/maturin-action@v1
-      #   with:
-      #     target: ${{ matrix.platform.target }}
-      #     args: --release --out dist --find-interpreter
-      #     sccache: 'true'
-      #     # manylinux: auto
-      #     docker-options: "--env CIBW_BEFORE_BUILD_LINUX='${{ env.CIBW_BEFORE_BUILD_LINUX }}'"
-      #   env:
-      #     CIBW_BEFORE_BUILD_LINUX: yum -y install openssl openssl-devel perl-IPC-Cmd
+        # needs to be replaced with the live version of the package / evolving version number
+        run: |
+          python3 -m pip install dist/datago-${{ steps.get_tag.outputs.tag }}-cp311-cp311-linux_x86_64.whl
+          python3 -m pip install -r requirements-tests.txt
+          cd python
+          python3 -m pytest -v .
 
       - name: Upload wheels
         uses: actions/upload-artifact@v4
@@ -153,7 +96,7 @@
       - name: Generate artifact attestation
         uses: actions/attest-build-provenance@v1
         with:
-          subject-path: 'wheels-*/*'
+          subject-path: "wheels-*/*"
       - name: Publish to PyPI
         if: ${{ startsWith(github.ref, 'refs/tags/') }}
         uses: PyO3/maturin-action@v1
diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml
index 00254fa..4c67188 100644
--- a/.github/workflows/rust.yml
+++ b/.github/workflows/rust.yml
@@ -4,7 +4,7 @@ on:
   push:
     branches: ["main"]
   pull_request:
-    branches: ["main"]
+    branches: ["*"]
 
 env:
   CARGO_TERM_COLOR: always
diff --git a/Cargo.lock b/Cargo.lock
index 040daa2..dcfab43 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -376,7 +376,7 @@ dependencies = [
 
 [[package]]
 name = "datago"
-version = "0.1.0"
+version = "0.0.0"
 dependencies = [
  "clap",
  "image",
@@ -391,6 +391,7 @@ dependencies = [
  "threadpool",
  "tokio",
  "url",
+ "walkdir",
 ]
 
 [[package]]
@@ -1787,6 +1788,15 @@ version = "1.0.18"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f"
 
+[[package]]
+name = "same-file"
+version = "1.0.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502"
+dependencies = [
+ "winapi-util",
+]
+
 [[package]]
 name = "schannel"
 version = "0.1.27"
@@ -2304,6 +2314,16 @@ version = "0.2.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "852e951cb7832cb45cb1169900d19760cfa39b82bc0ea9c0e5a14ae88411c98b"
 
+[[package]]
+name = "walkdir"
+version = "2.5.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b"
+dependencies = [
+ "same-file",
+ "winapi-util",
+]
+
 [[package]]
 name = "want"
 version = "0.3.1"
@@ -2422,6 +2442,15 @@ version = "0.4.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
 
+[[package]]
+name = "winapi-util"
+version = "0.1.9"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb"
+dependencies = [
+ "windows-sys 0.59.0",
+]
+
 [[package]]
 name = "winapi-x86_64-pc-windows-gnu"
 version = "0.4.0"
diff --git a/Cargo.toml b/Cargo.toml
index e35c06c..3b12e04 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,5 @@
 [package]
 name = "datago"
-version = "0.1.0"
 edition = "2021"
 
 [lib]
@@ -27,6 +26,7 @@ pyo3 = { version = "0.23.4", features = ["extension-module"] }
 threadpool = "1.8.1"
 num_cpus = "1.16.0"
 openssl = { version = "0.10", features = ["vendored"] }
+walkdir = "2.5.0"
 
 [profile.release]
 opt-level = 3 # Optimize for speed
diff --git a/README.md b/README.md
index a88e55e..711d1ef 100644
--- a/README.md
+++ b/README.md
@@ -34,26 +34,9 @@ import json
 config = {
     "source_config": {
         "sources": os.environ.get("DATAROOM_TEST_SOURCE", ""),
-        "sources_ne": "",
-        "require_images": True,
-        "require_embeddings": True,
-        "tags": "",
-        "tags_ne": "",
-        "has_attributes": "",
-        "lacks_attributes": "",
-        "has_masks": "",
-        "lacks_masks": "",
-        "has_latents": "",
-        "lacks_latents": "",
-        "min_short_edge": 0,
-        "max_short_edge": 0,
-        "min_pixel_count": -1,
-        "max_pixel_count": -1,
-        "duplicate_state": -1,
-        "random_sampling": False,
-        "page_size": 10,
+        "page_size": 500,
     },
-    "limit": 2,
+    "limit": 20,
     "rank": 0,
     "world_size": 1,
     "samples_buffer_size": 1,
diff --git a/pyproject.toml b/pyproject.toml
index 575cbcf..67c2a0d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,5 @@
 [project]
 name = "datago"
-version = "2025.2.1"
 authors = [
     { name="Photoroom", email="team@photoroom.com" },
 ]
diff --git a/python/benchmark_filesystem.py b/python/benchmark_filesystem.py
index 5612eb8..df84200 100644
--- a/python/benchmark_filesystem.py
+++ b/python/benchmark_filesystem.py
@@ -1,4 +1,3 @@
-from datago import datago
 import time
 from tqdm import tqdm
 import os
@@ -18,12 +17,9 @@ def benchmark(
 ):
     print(f"Running benchmark for {root_path} - {limit} samples")
     client_config = {
-        "source_type": datago.SourceTypeFileSystem,
+        "source_type": "file",
         "source_config": {
-            "page_size": 512,
             "root_path": root_path,
-            "rank": 0,
-            "world_size": 1,
         },
         "image_config": {
             "crop_and_resize": crop_and_resize,
@@ -36,6 +32,8 @@ def benchmark(
         "prefetch_buffer_size": 128,
         "samples_buffer_size": 64,
         "limit": limit,
+        "rank": 0,
+        "world_size": 1,
     }
 
     # Make sure in the following that we compare apples to apples, meaning in that case
diff --git a/python/dataset.py b/python/dataset.py
index 7066e21..b13a938 100644
--- a/python/dataset.py
+++ b/python/dataset.py
@@ -1,4 +1,4 @@
-from datago import datago
+from datago import DatagoClient
 import json
 from typing import Dict, Any
 from raw_types import raw_array_to_pil_image, raw_array_to_numpy
@@ -6,8 +6,8 @@
 
 class DatagoIterDataset:
     def __init__(self, datago_config: Dict[str, Any], return_python_types: bool = True):
-        self.client = datago.GetClientFromJSON(json.dumps(datago_config))
-        self.client.Start()
+        self.client = DatagoClient(json.dumps(datago_config))
+        self.client.start()
         self.return_python_types = return_python_types
         self.len = datago_config.get("limit", 1e9)
         print(self.len)
@@ -18,29 +18,25 @@ def __iter__(self):
         return self
 
     def __del__(self):
-        self.client.Stop()
+        self.client.stop()
 
     def __len__(self):
         return self.len
 
     @staticmethod
-    def to_python_types(item):
-        if isinstance(item, datago.ImagePayload):
+    def to_python_types(item, key):
+        if key == "attributes":
+            return json.loads(item)
+
+        if isinstance(item, dict):
+            # recursively convert the dictionary
+            return {k: DatagoIterDataset.to_python_types(v, k) for k, v in item.items()}
+
+        elif "image" in key:
             return raw_array_to_pil_image(item)
-        elif isinstance(item, datago.LatentPayload):
+        elif "latent" in key:
             return raw_array_to_numpy(item)
-        elif isinstance(item, datago.Map_string_interface_):
-            dict_item = dict(item)
-            for key, value in filter(
-                lambda x: isinstance(x[1], str) and x[1].startswith("%!s(float64"),
-                dict_item.items(),
-            ):
-                dict_item[key] = float(value[12:-1])
-            return dict_item
-        elif isinstance(item, datago.go.Slice_string):
-            return list(item)
-
-        # TODO: Make this recursive, would be an elegant way of handling nested structures
+
         return item
 
     def __next__(self):
@@ -48,16 +44,16 @@ def __next__(self):
         if self.count > self.len:
             raise StopIteration
 
-        sample = self.client.GetSample()
-        if sample.ID == "":
+        sample = self.client.get_sample()
+        if sample.id == "":
             raise StopIteration
 
         if self.return_python_types:
-            # Convert the Go types to Python types
+            # Convert the Rust types to Python types
             python_sample = {}
             for attr in filter(lambda x: "__" not in x, dir(sample)):
                 python_sample[attr.lower()] = self.to_python_types(
-                    getattr(sample, attr)
+                    getattr(sample, attr), attr
                 )
 
             return python_sample
@@ -68,14 +64,12 @@ def __next__(self):
 if __name__ == "__main__":
     # Example config, using this for filesystem walkthrough would work just as well
     client_config = client_config = {
-        "source_type": datago.SourceTypeDB,
+        "source_type": "db",
         "source_config": {
             "page_size": 10,
             "sources": "COYO",
             "require_images": True,
             "has_attributes": "caption_moondream",
-            "rank": 0,
-            "world_size": 1,
         },
         "image_config": {
             "crop_and_resize": True,
@@ -88,6 +82,8 @@ def __next__(self):
         "prefetch_buffer_size": 64,
         "samples_buffer_size": 128,
         "limit": 10,
+        "rank": 0,
+        "world_size": 1,
     }
     dataset = DatagoIterDataset(client_config)
     for sample in dataset:
diff --git a/python/test_datago_db.py b/python/test_datago_db.py
index 493e38c..7d97abf 100644
--- a/python/test_datago_db.py
+++ b/python/test_datago_db.py
@@ -1,7 +1,9 @@
-from datago import datago
+from datago import DatagoClient
 import pytest
 import os
-from python.raw_types import go_array_to_pil_image, go_array_to_numpy
+import json
+
+from raw_types import raw_array_to_pil_image, raw_array_to_numpy
 
 from dataset import DatagoIterDataset
 
@@ -13,7 +15,7 @@ def get_test_source() -> str:
 
 def get_json_config():
     client_config = {
-        "source_type": datago.SourceTypeDB,
+        "source_type": "db",
         "source_config": {
             "page_size": 10,
             "sources": get_test_source(),
@@ -22,8 +24,6 @@ def get_json_config():
             "has_latents": "masked_image",
             "has_attributes": "caption_coca,caption_cogvlm,caption_moondream",
             "return_latents": "masked_image",
-            "rank": 0,
-            "world_size": 1,
         },
         "image_config": {
             "crop_and_resize": False,
@@ -34,24 +34,21 @@ def get_json_config():
             "pre_encode_images": False,
         },
         "prefetch_buffer_size": 64,
-        "samples_buffer_size": 128,
+        "samples_buffer_size": 10,
         "limit": 10,
+        "rank": 0,
+        "world_size": 1,
     }
     return client_config
 
 
 def test_get_sample_db():
     # Check that we can instantiate a client and get a sample, nothing more
-    client_config = datago.GetDatagoConfig()
-    client_config.SamplesBufferSize = 10
-
-    source_config = datago.GetDefaultSourceDBConfig()
-    source_config.Sources = get_test_source()
-    client_config.SourceConfig = source_config
+    client_config = get_json_config()
 
-    client = datago.GetClient(client_config)
-    data = client.GetSample()
-    assert data.ID != ""
+    client = DatagoClient(json.dumps(client_config))
+    data = client.get_sample()
+    assert data.id != ""
 
 
 N_SAMPLES = 3
@@ -62,31 +59,33 @@ def test_caption_and_image():
     dataset = DatagoIterDataset(client_config, return_python_types=False)
 
     def check_image(img, channels=3):
-        assert img.Height > 0
-        assert img.Width > 0
+        assert img.height > 0
+        assert img.width > 0
 
-        assert img.Height <= img.OriginalHeight
-        assert img.Width <= img.OriginalWidth
-        assert img.Channels == channels
+        assert img.height <= img.original_height
+        assert img.width <= img.original_width
+        assert img.channels == channels
 
     for i, sample in enumerate(dataset):
-        assert sample.Source != ""
-        assert sample.ID != ""
+        assert sample.source != ""
+        assert sample.id != ""
 
-        assert len(sample.Attributes["caption_coca"]) != len(
-            sample.Attributes["caption_cogvlm"]
-        ), "Caption lengths should not be equal"
+        attributes = json.loads(sample.attributes)
+        assert len(attributes["caption_coca"]) != len(attributes["caption_cogvlm"]), (
+            "Caption lengths should not be equal"
+        )
 
-        check_image(sample.Image, 3)
-        check_image(sample.AdditionalImages["masked_image"], 3)
-        check_image(sample.Masks["segmentation_mask"], 1)
+        check_image(sample.image, 3)
+        check_image(sample.additional_images["masked_image"], 3)
+        check_image(sample.masks["segmentation_mask"], 1)
 
         # Check the image decoding
-        assert go_array_to_pil_image(sample.Image).mode == "RGB", "Image should be RGB"
+        assert raw_array_to_pil_image(sample.image).mode == "RGB", "Image should be RGB"
         assert (
-            go_array_to_pil_image(sample.AdditionalImages["masked_image"]).mode == "RGB"
+            raw_array_to_pil_image(sample.additional_images["masked_image"]).mode
+            == "RGB"
         ), "Image should be RGB"
-        assert go_array_to_pil_image(sample.Masks["segmentation_mask"]).mode == "L", (
+        assert raw_array_to_pil_image(sample.masks["segmentation_mask"]).mode == "L", (
             "Mask should be L"
         )
 
@@ -102,16 +101,16 @@ def test_image_resize():
     for i, sample in enumerate(dataset):
         # Assert that all the images in the sample have the same size
         assert (
-            sample.Image.Height
-            == sample.AdditionalImages["masked_image"].Height
-            == sample.Masks["segmentation_mask"].Height
-            and sample.Image.Height > 0
+            sample.image.height
+            == sample.additional_images["masked_image"].height
+            == sample.masks["segmentation_mask"].height
+            and sample.image.height > 0
         )
         assert (
-            sample.Image.Width
-            == sample.AdditionalImages["masked_image"].Width
-            == sample.Masks["segmentation_mask"].Width
-            and sample.Image.Width > 0
+            sample.image.width
+            == sample.additional_images["masked_image"].width
+            == sample.masks["segmentation_mask"].width
+            and sample.image.width > 0
         )
         if i > N_SAMPLES:
             break
@@ -124,7 +123,7 @@ def test_has_tags():
     dataset = DatagoIterDataset(client_config, return_python_types=False)
     sample = next(iter(dataset))
 
-    assert "v4_trainset_hq" in sample.Tags, "v4_trainset_hq should be in the tags"
+    assert "v4_trainset_hq" in sample.tags, "v4_trainset_hq should be in the tags"
 
 
 def test_empty_image():
@@ -145,16 +144,16 @@ def no_test_jpg_compression():
 
     sample = next(iter(dataset))
 
-    assert go_array_to_pil_image(sample.Image).mode == "RGB", "Image should be RGB"
+    assert raw_array_to_pil_image(sample.image).mode == "RGB", "Image should be RGB"
     assert (
-        go_array_to_pil_image(sample.AdditionalImages["masked_image"]).mode == "RGB"
+        raw_array_to_pil_image(sample.additional_images["masked_image"]).mode == "RGB"
     ), "Image should be RGB"
-    assert go_array_to_pil_image(sample.Masks["segmentation_mask"]).mode == "L", (
+    assert raw_array_to_pil_image(sample.masks["segmentation_mask"]).mode == "L", (
         "Mask should be L"
     )
 
     # Check the embeddings decoding
-    assert go_array_to_numpy(sample.CocaEmbedding) is not None, (
+    assert raw_array_to_numpy(sample.coca_embedding) is not None, (
         "Embedding should be set"
     )
 
@@ -168,11 +167,11 @@ def test_original_image():
 
     sample = next(iter(dataset))
 
-    assert go_array_to_pil_image(sample.Image).mode == "RGB", "Image should be RGB"
+    assert raw_array_to_pil_image(sample.image).mode == "RGB", "Image should be RGB"
     assert (
-        go_array_to_pil_image(sample.AdditionalImages["masked_image"]).mode == "RGB"
+        raw_array_to_pil_image(sample.additional_images["masked_image"]).mode == "RGB"
     ), "Image should be RGB"
-    assert go_array_to_pil_image(sample.Masks["segmentation_mask"]).mode == "L", (
+    assert raw_array_to_pil_image(sample.masks["segmentation_mask"]).mode == "L", (
         "Mask should be L"
     )
 
@@ -183,7 +182,7 @@ def test_duplicate_state():
     dataset = DatagoIterDataset(client_config, return_python_types=False)
     sample = next(iter(dataset))
 
-    assert sample.DuplicateState in [
+    assert sample.duplicate_state in [
         0,
         1,
         2,
diff --git a/python/test_datago_filesystem.py b/python/test_datago_filesystem.py
index 8d7ba6a..9b6696a 100644
--- a/python/test_datago_filesystem.py
+++ b/python/test_datago_filesystem.py
@@ -1,28 +1,47 @@
-import os
 from PIL import Image
-from datago import datago
+from datago import DatagoClient
+import json
+import tempfile
 
 
-# FIXME: Would need to generate more fake data to test this
-def no_test_get_sample_filesystem():
-    cwd = os.getcwd()
+def test_get_sample_filesystem():
+    samples = 10
 
-    try:
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        cwd = tmpdirname
         # Dump a sample image to the filesystem
-        img = Image.new("RGB", (100, 100))
-        img.save(cwd + "/test.png")
+        for i in range(samples):
+            img = Image.new("RGB", (100, 100))
+            img.save(cwd + f"/test_{i}.png")
 
         # Check that we can instantiate a client and get a sample, nothing more
-        client_config = datago.GetDatagoConfig()
-        client_config.SourceType = "filesystem"
-        client_config.SamplesBufferSize = 1
+        client_config = {
+            "source_type": "file",
+            "source_config": {
+                "root_path": cwd,
+            },
+            "image_config": {
+                "crop_and_resize": False,
+                "default_image_size": 512,
+                "downsampling_ratio": 16,
+                "min_aspect_ratio": 0.5,
+                "max_aspect_ratio": 2.0,
+                "pre_encode_images": False,
+            },
+            "limit": samples,
+            "prefetch_buffer_size": 64,
+            "samples_buffer_size": 10,
+            "rank": 0,
+            "world_size": 1,
+        }
 
-        source_config = datago.SourceFileSystemConfig()
-        source_config.RootPath = cwd
-        source_config.PageSize = 1
+        client = DatagoClient(json.dumps(client_config))
+        for i in range(samples):
+            data = client.get_sample()
+            assert data.id != ""
+            assert data.image.width == 100
+            assert data.image.height == 100
 
-        client = datago.GetClient(client_config, source_config)
-        data = client.GetSample()
-        assert data.ID != ""
-    finally:
-        os.remove(cwd + "/test.png")
+
+if __name__ == "__main__":
+    test_get_sample_filesystem()
diff --git a/src/client.rs b/src/client.rs
index f511500..09287d0 100644
--- a/src/client.rs
+++ b/src/client.rs
@@ -1,29 +1,21 @@
+use crate::generator_files;
 use crate::generator_http;
 use crate::image_processing::ARAwareTransform;
-use crate::image_processing::ImageTransformConfig;
+use crate::structs::{DatagoClientConfig, Sample, SourceType};
+use crate::worker_files;
 use crate::worker_http;
-use pyo3::prelude::*;
 use kanal::bounded;
-use serde::Deserialize;
+use pyo3::prelude::*;
 use std::sync::Arc;
 use std::thread;
 use threadpool::ThreadPool;
 
-#[derive(Deserialize)]
-struct DatagoClientConfig {
-    source_config: generator_http::SourceDBConfig,
-    image_config: Option<ImageTransformConfig>,
-    limit: usize,
-    rank: usize,
-    world_size: usize,
-    samples_buffer_size: usize,
-}
-
 #[pyclass]
 pub struct DatagoClient {
     pub is_started: bool,
-    source_config: generator_http::SourceDBConfig,
+    source_type: SourceType,
+    source_config: serde_json::Value,
     limit: usize,
 
     // Perf settings
@@ -37,8 +29,8 @@ pub struct DatagoClient {
     pages_rx: kanal::Receiver<serde_json::Value>,
     samples_meta_tx: kanal::Sender<serde_json::Value>,
     samples_meta_rx: kanal::Receiver<serde_json::Value>,
-    samples_tx: kanal::Sender<worker_http::Sample>,
-    samples_rx: kanal::Receiver<worker_http::Sample>,
+    samples_tx: kanal::Sender<Sample>,
+    samples_rx: kanal::Receiver<Sample>,
     worker_done_count: Arc<std::sync::atomic::AtomicUsize>,
 
     // Sample processing
@@ -77,6 +69,7 @@ impl DatagoClient {
 
         DatagoClient {
             is_started: false,
+            source_type: config.source_type,
            source_config: config.source_config,
             limit: config.limit,
             num_threads,
@@ -105,7 +98,6 @@ impl DatagoClient {
         // Spawn a new thread which will query the DB and send the pages
         let pages_tx = self.pages_tx.clone();
-        let source_config = self.source_config.clone();
         let limit = self.limit;
         let rank = self.rank;
         let world_size = self.world_size;
 
@@ -115,16 +107,42 @@
             "Rank cannot be greater than or equal to world size"
         );
 
-        self.pinger = Some(thread::spawn(move || {
-            generator_http::ping_pages(pages_tx, source_config, rank, world_size, limit);
-        }));
+        match self.source_type {
+            SourceType::Db => {
+                println!("Using DB as source");
+                // convert the source_config to a SourceDBConfig
+                let source_db_config: generator_http::SourceDBConfig =
+                    serde_json::from_value(self.source_config.clone()).unwrap();
+
+                self.pinger = Some(thread::spawn(move || {
+                    generator_http::ping_pages(pages_tx, source_db_config, rank, world_size, limit);
+                }));
+            }
+            SourceType::File => {
+                // convert the source_config to a SourceFileConfig
+                let source_file_config: generator_files::SourceFileConfig =
+                    serde_json::from_value(self.source_config.clone()).unwrap();
+
+                println!("Using file as source {}", source_file_config.root_path);
+
+                self.pinger = Some(thread::spawn(move || {
+                    generator_files::ping_files(
+                        pages_tx,
+                        source_file_config,
+                        rank,
+                        world_size,
+                        limit,
+                    );
+                }));
+            }
+        }
 
         // Spawn a new thread which will pull the pages and send the sample metadata
         let pages_rx = self.pages_rx.clone();
         let samples_meta_tx = self.samples_meta_tx.clone();
         let limit = self.limit;
         self.feeder = Some(thread::spawn(move || {
-            generator_http::pull_pages(pages_rx, samples_meta_tx, limit);
+            generator_http::dispatch_pages(pages_rx, samples_meta_tx, limit);
         }));
 
         // Spawn threads which will receive the pages
@@ -138,27 +156,43 @@
             // FIXME: this is a bit ugly, there must be a better way
             let samples_meta_rx_local = self.samples_meta_rx.clone();
             let samples_tx_local = self.samples_tx.clone();
-            let thread_local_client = http_client.clone();
             let local_image_transform = self.image_transform.clone();
             let encode_images = self.encode_images;
             let worker_done_count = self.worker_done_count.clone();
 
-            self.thread_pool.execute(move || {
-                worker_http::pull_samples(
-                    thread_local_client,
-                    samples_meta_rx_local,
-                    samples_tx_local,
-                    worker_done_count,
-                    &local_image_transform,
-                    encode_images,
-                );
-            });
+            match self.source_type {
+                SourceType::Db => {
+                    let thread_local_client = http_client.clone();
+
+                    self.thread_pool.execute(move || {
+                        worker_http::pull_samples(
+                            thread_local_client,
+                            samples_meta_rx_local,
+                            samples_tx_local,
+                            worker_done_count,
+                            &local_image_transform,
+                            encode_images,
+                        );
+                    });
+                }
+                SourceType::File => {
+                    self.thread_pool.execute(move || {
+                        worker_files::pull_samples(
+                            samples_meta_rx_local,
+                            samples_tx_local,
+                            worker_done_count,
+                            &local_image_transform,
+                            encode_images,
+                        );
+                    });
+                }
+            }
         }
 
         self.is_started = true;
     }
 
-    pub fn get_sample(&mut self) -> Option<worker_http::Sample> {
+    pub fn get_sample(&mut self) -> Option<Sample> {
         if !self.is_started {
             self.start();
         }
diff --git a/src/generator_files.rs b/src/generator_files.rs
new file mode 100644
index 0000000..a0bf39c
--- /dev/null
+++ b/src/generator_files.rs
@@ -0,0 +1,106 @@
+use serde::{Deserialize, Serialize};
+use std::hash::Hash;
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct SourceFileConfig {
+    pub root_path: String,
+}
+
+// Hash function to be able to dispatch the samples to the correct rank
+fn hash<T: Hash>(t: &T) -> u64 {
+    use std::hash::Hasher;
+    let mut s = std::collections::hash_map::DefaultHasher::new();
+    t.hash(&mut s);
+    s.finish()
+}
+
+pub fn ping_files(
+    pages_tx: kanal::Sender<serde_json::Value>,
+    source_config: SourceFileConfig,
+    rank: usize,
+    world_size: usize,
+    num_samples: usize,
+) {
+    // Get an iterator over the files in the root path
+    let supported_extensions = [
+        "jpg", "jpeg", "png", "bmp", "gif", "JPG", "JPEG", "PNG", "BMP", "GIF",
+    ];
+
+    let files = walkdir::WalkDir::new(&source_config.root_path)
+        .follow_links(false)
+        .into_iter()
+        .filter_map(|e| e.ok());
+
+    let files = files.filter_map(|entry| {
+        let path = entry.path();
+        let file_name = path.to_str().unwrap().to_string();
+        if supported_extensions
+            .iter()
+            .any(|&ext| file_name.ends_with(ext))
+        {
+            Some(entry)
+        } else {
+            None
+        }
+    });
+
+    let page_size = 500;
+
+    // While we have files left, batch them into pages and send them to the channel
+    let mut count = 0;
+    let mut page = Vec::new();
+
+    // Build a page from the files iterator
+    for entry in files {
+        let path = entry.path();
+
+        let file_name = path.to_str().unwrap().to_string();
+
+        // If world_size is greater than 1, we need to dispatch the samples to the correct rank
+        if world_size > 1 {
+            let hash = hash(&file_name);
+            let target_rank = (hash % world_size as u64) as usize;
+            if target_rank != rank {
+                continue;
+            }
+        }
+
+        page.push(file_name);
+        count += 1;
+
+        if page.len() >= page_size || count >= num_samples {
+            // Convert the page to a JSON payload
+            let page = serde_json::json!({
+                "results": page,
+                "rank": rank,
+                "world_size": world_size,
+            });
+
+            // Push the page to the channel
+            if pages_tx.send(page.clone()).is_err() {
+                println!("ping_files: stream already closed, wrapping up");
+                break;
+            }
+        }
+
+        if count >= num_samples {
+            // NOTE: This doesn't count the samples which have actually been processed
+            println!("ping_files: reached the limit of samples requested. Shutting down");
+            break;
+        }
+    }
+
+    // Either we don't have any more samples or we have reached the limit
+    println!(
+        "ping_files: total samples requested: {}. page samples served {}",
+        num_samples, count
+    );
+
+    // Send an empty value to signal the end of the stream
+    match pages_tx.send(serde_json::Value::Null) {
+        Ok(_) => {}
+        Err(_) => {
+            println!("ping_files: stream already closed, all good");
+        }
+    };
+}
diff --git a/src/generator_http.rs b/src/generator_http.rs
index 70c52ed..fb4fe29 100644
--- a/src/generator_http.rs
+++ b/src/generator_http.rs
@@ -63,7 +63,7 @@ pub struct SourceDBConfig {
 
 // TODO: Derive from the above
 #[derive(Debug, Serialize, Deserialize)]
-pub struct DbRequest {
+struct DbRequest {
     pub fields: String,
     pub sources: String,
     pub sources_ne: String,
@@ -97,7 +97,7 @@ pub struct DbRequest {
 
 // implement a helper to get the http request which corresponds to the db request structure above
 impl DbRequest {
-    pub fn get_http_request(&self, api_url: &str, api_key: &str) -> reqwest::blocking::Request {
+    fn get_http_request(&self, api_url: &str, api_key: &str) -> reqwest::blocking::Request {
         let mut url = if self.random_sampling {
             Url::parse(&format!("{}images/random/", api_url))
         } else {
@@ -344,7 +344,7 @@ pub fn ping_pages(
         };
 }
 
-pub fn pull_pages(
+pub fn dispatch_pages(
     pages_rx: kanal::Receiver<serde_json::Value>,
     samples_meta_tx: kanal::Sender<serde_json::Value>,
     num_samples: usize,
@@ -354,7 +354,7 @@
     while count < num_samples {
         match pages_rx.recv() {
             Ok(serde_json::Value::Null) => {
-                println!("pull_pages: end of stream received, stopping there");
+                println!("dispatch_pages: end of stream received, stopping there");
                 break;
             }
             Ok(response_json) => {
@@ -365,7 +365,7 @@
 
                     // Push the sample to the channel
                     if samples_meta_tx.send(sample_json).is_err() {
-                        println!("pull_pages: stream already closed, wrapping up");
+                        println!("dispatch_pages: stream already closed, wrapping up");
                        pages_rx.close();
                         break;
                     }
@@ -375,7 +375,7 @@
                     if count >= num_samples {
                         // NOTE: This doesn´t count the samples which have actually been processed
                         println!(
-                            "pull_pages: reached the limit of samples requested. Shutting down"
+                            "dispatch_pages: reached the limit of samples requested. Shutting down"
                         );
                         break;
                     }
@@ -387,7 +387,7 @@
                 }
             }
             Err(_) => {
-                println!("pull_pages: stream already closed, wrapping up");
+                println!("dispatch_pages: stream already closed, wrapping up");
                 break;
             }
         }
@@ -397,7 +397,7 @@
 
     // Either we don't have any more samples or we have reached the limit
     println!(
-        "pull_pages: total samples requested: {}. served {}",
+        "dispatch_pages: total samples requested: {}. served {}",
         num_samples, count
     );
diff --git a/src/image_processing.rs b/src/image_processing.rs
index 24f24eb..143825e 100644
--- a/src/image_processing.rs
+++ b/src/image_processing.rs
@@ -108,12 +108,11 @@
     pub fn crop_and_resize(
         &self,
         image: &image::DynamicImage,
-        aspect_ratio_input: &String,
+        aspect_ratio_input: Option<&String>,
     ) -> image::DynamicImage {
-        let aspect_ratio = if aspect_ratio_input.is_empty() {
-            self.get_closest_aspect_ratio(image.width() as i32, image.height() as i32)
-        } else {
-            aspect_ratio_input.to_string()
+        let aspect_ratio = match aspect_ratio_input {
+            Some(ar) => ar.to_string(),
+            None => self.get_closest_aspect_ratio(image.width() as i32, image.height() as i32),
         };
 
         if let Some(target_size) = self.aspect_ratio_to_size.get(&aspect_ratio) {
@@ -156,15 +155,15 @@ mod tests {
 
         // Test image resizing
         let img = DynamicImage::new_rgb8(300, 200);
-        let resized = transform.crop_and_resize(&img, &"1.000".to_string());
+        let resized = transform.crop_and_resize(&img, Some(&"1.000".to_string()));
         assert_eq!(resized.dimensions(), (224, 224));
 
-        let resized = transform.crop_and_resize(&img, &"1.900".to_string());
+        let resized = transform.crop_and_resize(&img, Some(&"1.900".to_string()));
         assert_eq!(resized.dimensions(), (304, 160));
 
         // Test empty aspect ratio input (should use closest)
         let img = DynamicImage::new_rgb8(400, 200);
-        let resized = transform.crop_and_resize(&img, &"".to_string());
+        let resized = transform.crop_and_resize(&img, None);
         assert_eq!(resized.dimensions(), (304, 160));
     }
 
@@ -176,7 +175,7 @@ mod tests {
 
         // Check all sizes respect min/max aspect ratio
         for (w, h) in sizes {
             let ar = w as f64 / h as f64;
-            assert!(ar >= 0.5 && ar <= 2.0);
+            assert!((0.5..=2.0).contains(&ar));
 
             // Check dimensions are multiples of downsampling ratio
             assert_eq!(w % 16, 0);
diff --git a/src/lib.rs b/src/lib.rs
index 81d35e8..67ac89f 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,12 +1,16 @@
 pub mod client;
+pub mod generator_files;
 pub mod generator_http;
 pub mod image_processing;
+pub mod structs;
+pub mod worker_files;
 pub mod worker_http;
 
 pub use client::DatagoClient;
+pub use generator_files::SourceFileConfig;
 pub use generator_http::SourceDBConfig;
 pub use image_processing::ImageTransformConfig;
-pub use worker_http::{ImagePayload, LatentPayload, Sample};
+pub use structs::{DatagoClientConfig, ImagePayload, LatentPayload, Sample};
 
 use pyo3::prelude::*;
diff --git a/src/main.rs b/src/main.rs
index 5645096..e40e027 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -3,8 +3,11 @@ use prettytable::{row, Table};
 use serde_json::json;
 
 mod client;
+mod generator_files;
 mod generator_http;
 mod image_processing;
+mod structs;
+mod worker_files;
 mod worker_http;
 
 fn main() {
diff --git a/src/structs.rs b/src/structs.rs
new file mode 100644
index 0000000..bcf804a
--- /dev/null
+++ b/src/structs.rs
@@ -0,0 +1,109 @@
+use crate::image_processing::ImageTransformConfig;
+use pyo3::prelude::*;
+use serde::{Deserialize, Serialize};
+
+#[derive(Deserialize)]
+#[serde(rename_all = "lowercase")]
+pub enum SourceType {
+    Db,
+    File,
+}
+
+fn default_source_type() -> SourceType {
+    SourceType::Db
+}
+
+#[derive(Deserialize)]
+pub struct DatagoClientConfig {
+    #[serde(default = "default_source_type")]
+    pub source_type: SourceType,
+
+    pub source_config: serde_json::Value,
+    pub image_config: Option<ImageTransformConfig>,
+    pub limit: usize,
+    pub rank: usize,
+    pub world_size: usize,
+    pub samples_buffer_size: usize,
+}
+
+#[pyclass]
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct LatentPayload {
+    #[pyo3(get, set)]
+    pub data: Vec<u8>,
+    #[pyo3(get, set)]
+    pub len: usize,
+}
+
+#[pyclass]
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct ImagePayload {
+    #[pyo3(get, set)]
+    pub data: Vec<u8>,
+    #[pyo3(get, set)]
+    pub original_height: usize, // Good indicator of the image frequency response at the current resolution
+    #[pyo3(get, set)]
+    pub original_width: usize,
+    #[pyo3(get, set)]
+    pub height: usize, // Useful to decode the current payload
+    #[pyo3(get, set)]
+    pub width: usize,
+    #[pyo3(get, set)]
+    pub channels: i8,
+    #[pyo3(get, set)]
+    pub bit_depth: usize,
+}
+
+#[pyclass]
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct Sample {
+    #[pyo3(get, set)]
+    pub id: String,
+
+    #[pyo3(get, set)]
+    pub source: String,
+
+    #[doc(hidden)]
+    pub attributes: std::collections::HashMap<String, serde_json::Value>,
+
+    #[pyo3(get, set)]
+    pub duplicate_state: i32,
+
+    #[pyo3(get, set)]
+    pub image: ImagePayload,
+
+    #[pyo3(get, set)]
+    pub masks: std::collections::HashMap<String, ImagePayload>,
+
+    #[pyo3(get, set)]
+    pub additional_images: std::collections::HashMap<String, ImagePayload>,
+
+    #[pyo3(get, set)]
+    pub latents: std::collections::HashMap<String, LatentPayload>,
+
+    #[pyo3(get, set)]
+    pub coca_embedding: Vec<f32>,
+
+    #[pyo3(get, set)]
+    pub tags: Vec<String>,
+}
+
+#[pymethods]
+impl Sample {
+    #[getter]
+    pub fn attributes(&self) -> String {
+        serde_json::to_string(&self.attributes).unwrap_or("".to_string())
+    }
+}
+
+#[derive(Debug, Serialize, Deserialize, Default)]
+pub struct CocaEmbedding {
+    pub vector: Vec<f32>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct UrlLatent {
+    pub file_direct_url: String,
+    pub latent_type: String,
+    pub is_mask: bool,
+}
diff --git a/src/worker_files.rs b/src/worker_files.rs
new file mode 100644
index 0000000..63923b1
--- /dev/null
+++ b/src/worker_files.rs
@@ -0,0 +1,121 @@
+use crate::image_processing;
+use crate::structs::{ImagePayload, Sample};
+use std::collections::HashMap;
+use std::io::Cursor;
+use std::sync::Arc;
+
+fn image_from_path(path: &str) -> Result<image::DynamicImage, image::ImageError> {
+    // Load bytes from the file
+    let bytes = std::fs::read(path).map_err(|e| {
+        image::ImageError::IoError(std::io::Error::new(std::io::ErrorKind::Other, e))
+    })?;
+
+    // Decode the image
+    image::load_from_memory(&bytes)
+}
+
+fn image_payload_from_path(
+    path: &str,
+    img_tfm: &Option<image_processing::ARAwareTransform>,
+    encode_images: bool,
+) -> Result<ImagePayload, image::ImageError> {
+    match image_from_path(path) {
+        Ok(mut new_image) => {
+            let original_height = new_image.height() as usize;
+            let original_width = new_image.width() as usize;
+            let mut channels = new_image.color().channel_count() as i8;
+            let bit_depth = new_image.color().bits_per_pixel() as usize;
+
+            // Optionally apply the aspect-ratio-aware transform to the image
+            if let Some(img_tfm) = img_tfm {
+                new_image = img_tfm.crop_and_resize(&new_image, None);
+            }
+
+            let height = new_image.height() as usize;
+            let width = new_image.width() as usize;
+
+            // Encode the image if needed
+            let mut image_bytes: Vec<u8> = Vec::new();
+            if encode_images {
+                if new_image
+                    .write_to(&mut Cursor::new(&mut image_bytes), image::ImageFormat::Png)
+                    .is_err()
+                {
+                    return Err(image::ImageError::IoError(std::io::Error::new(
+                        std::io::ErrorKind::Other,
+                        "Failed to encode image",
+                    )));
+                }
+
+                channels = -1; // Signal the fact that the image is encoded
+            } else {
+                image_bytes = new_image.into_bytes();
+            }
+
+            Ok(ImagePayload {
+                data: image_bytes,
+                original_height,
+                original_width,
+                height,
+                width,
+                channels,
+                bit_depth,
+            })
+        }
+        Err(e) => Err(e),
+    }
+}
+
+fn pull_sample(
+    sample_json: &serde_json::Value,
+    img_tfm: &Option<image_processing::ARAwareTransform>,
+    encode_images: bool,
+) -> Option<Sample> {
+    let image_payload =
+        image_payload_from_path(sample_json.as_str().unwrap(), img_tfm, encode_images);
+
+    if let Ok(image) = image_payload {
+        Some(Sample {
+            id: sample_json.to_string(),
+            source: "filesystem".to_string(),
+            image,
+            attributes: HashMap::new(),
+            coca_embedding: vec![],
+            tags: vec![],
+            masks: HashMap::new(),
+            latents: HashMap::new(),
+            additional_images: HashMap::new(),
+            duplicate_state: 0,
+        })
+    } else {
+        println!("Failed to load image from path {}", sample_json);
+        None
+    }
+}
+
+pub fn pull_samples(
+    samples_meta_rx: kanal::Receiver<serde_json::Value>,
+    samples_tx: kanal::Sender<Sample>,
+    worker_done_count: Arc<std::sync::atomic::AtomicUsize>,
+    image_transform: &Option<image_processing::ARAwareTransform>,
+    encode_images: bool,
+) {
+    while let Ok(received) = samples_meta_rx.recv() {
+        if received == serde_json::Value::Null {
+            println!("file_worker: end of stream received, stopping there");
+            samples_meta_rx.close();
+            break;
+        }
+
+        if let Some(sample) = pull_sample(&received, image_transform, encode_images) {
+            if samples_tx.send(sample).is_err() {
+                println!("file_worker: failed to send a sample");
+                break;
+            }
+        } else {
+            break;
+        }
+    }
+
+    worker_done_count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
+}
diff --git a/src/worker_http.rs b/src/worker_http.rs
index e7756b1..9a99a9f 100644
--- a/src/worker_http.rs
+++ b/src/worker_http.rs
@@ -1,8 +1,6 @@
 use crate::image_processing;
-use pyo3::pyclass;
-use pyo3::pymethods;
-use serde::Deserialize;
-use serde::Serialize;
+use crate::structs::{CocaEmbedding, ImagePayload, LatentPayload, Sample, UrlLatent};
+use serde::{Deserialize, Serialize};
 use std::io::Cursor;
 use std::sync::Arc;
 
@@ -14,88 +12,6 @@ pub struct SharedClient {
 }
 
 // ------------------------------------------------------------------
-#[pyclass]
-#[derive(Clone, Debug, Serialize, Deserialize)]
-pub struct LatentPayload {
-    #[pyo3(get, set)]
-    data: Vec<u8>,
-    #[pyo3(get, set)]
-    len: usize,
-}
-
-#[pyclass]
-#[derive(Clone, Debug, Serialize, Deserialize)]
-pub struct ImagePayload {
-    #[pyo3(get, set)]
-    pub data: Vec<u8>,
-    #[pyo3(get, set)]
-    pub original_height: usize, // Good indicator of the image frequency dbResponse at the current resolution
-    #[pyo3(get, set)]
-    pub original_width: usize,
-    #[pyo3(get, set)]
-    pub height: usize, // Useful to decode the current payload
-    #[pyo3(get, set)]
-    pub width: usize,
-    #[pyo3(get, set)]
-    pub channels: i8,
-    #[pyo3(get, set)]
-    pub bit_depth: usize,
-}
-
-#[pyclass]
-#[derive(Clone, Debug, Serialize, Deserialize)]
-pub struct Sample {
-    #[pyo3(get, set)]
-    pub id: String,
-
-    #[pyo3(get, set)]
-    pub source: String,
-
-    #[doc(hidden)]
-    pub attributes: std::collections::HashMap<String, serde_json::Value>,
-
-    #[pyo3(get, set)]
-    pub duplicate_state: i32,
-
-    #[pyo3(get, set)]
-    pub image: ImagePayload,
-
-    #[pyo3(get, set)]
-    pub masks: std::collections::HashMap<String, ImagePayload>,
-
-    #[pyo3(get, set)]
-    pub additional_images: std::collections::HashMap<String, ImagePayload>,
-
-    #[pyo3(get, set)]
-    pub latents: std::collections::HashMap<String, LatentPayload>,
-
-    #[pyo3(get, set)]
-    pub coca_embedding: Vec<f32>,
-
-    #[pyo3(get, set)]
-    pub tags: Vec<String>,
-}
-
-#[pymethods]
-impl Sample {
-    #[getter]
-    pub fn attributes(&self) -> String {
-        serde_json::to_string(&self.attributes).unwrap_or("".to_string())
-    }
-}
-
-#[derive(Debug, Serialize, Deserialize, Default)]
-struct CocaEmbedding {
-    vector: Vec<f32>,
-}
-
-#[derive(Debug, Serialize, Deserialize)]
-struct UrlLatent {
-    file_direct_url: String,
-    latent_type: String,
-    is_mask: bool,
-}
-
 #[derive(Debug, Serialize, Deserialize)]
 struct SampleMetadata {
     id: String,
@@ -158,7 +74,13 @@ fn image_payload_from_url(
 
     // Optionally transform the additional image in the same way the main image was
     if let Some(img_tfm) = img_tfm {
-        new_image = img_tfm.crop_and_resize(&new_image, aspect_ratio);
+        let aspect_ratio_input = if aspect_ratio.is_empty() {
+            None
+        } else {
+            Some(aspect_ratio)
+        };
+
+        new_image = img_tfm.crop_and_resize(&new_image, aspect_ratio_input);
     }
 
     let height = new_image.height() as usize;
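For reference, a minimal usage sketch of the JSON-configured Python API this diff introduces, following `python/dataset.py` and the filesystem test above. The image directory is a placeholder, and the `None` check is an assumption derived from the Rust `get_sample` returning `Option<Sample>`:

```python
import json
from datago import DatagoClient

# Placeholder path: any directory containing jpg/png files works
config = {
    "source_type": "file",
    "source_config": {"root_path": "/tmp/my_images"},
    "limit": 5,
    "prefetch_buffer_size": 8,
    "samples_buffer_size": 8,
    "rank": 0,
    "world_size": 1,
}

client = DatagoClient(json.dumps(config))
client.start()

for _ in range(5):
    sample = client.get_sample()
    if sample is None or sample.id == "":
        break  # stream exhausted
    print(sample.id, sample.image.width, sample.image.height)

client.stop()
```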