Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Address sample_sequence_length greater than min_length_time_axis #45

Open
wants to merge 4 commits into
base: main
Choose a base branch
from

Conversation

mickvangelderen
Copy link
Contributor

@mickvangelderen mickvangelderen commented Nov 9, 2024

Based on #43, please review that first and if it wasn't merged yet, please review this PR per commit.

The goal of this PR is to make the tests not generate flashbax warnings and to prevent code from being added that would trigger a flashbax warning.

I am not sure that the suggested changes in the last commit are right.

I am unsure why min_length_time_axis=min_length // add_batch_size + 1 is used in the trajectory buffer source code, particularly the + 1. I am also unsure why the min_length fixture is defined as:

@pytest.fixture
def min_length(sample_batch_size: int) -> int:
    return int(sample_batch_size + 1)

I could understand if it was defined in terms of add_batch_size, but it is not.

@SimonDuToit
Copy link
Contributor

SimonDuToit commented Dec 10, 2024

I think the + 1 in min_length_time_axis=min_length // add_batch_size + 1 is just to be conservative. This matches how max_length_time_axis is calculated.
The min_length fixture is defined to be viable, since it has to be at least as big as sample_size.

Copy link
Contributor

@SimonDuToit SimonDuToit left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure why its necessary to replace min_length with add_batch_size. Otherwise it makes sense, please just see my comment on the merge conflict.

flashbax/buffers/flat_buffer_test.py Outdated Show resolved Hide resolved
pyproject.toml Show resolved Hide resolved
flashbax/buffers/mixer_test.py Show resolved Hide resolved
@mickvangelderen mickvangelderen force-pushed the mick/fix-sample-seq-len-gt-min-len-time branch from c485ecf to 082261d Compare December 10, 2024 17:51
Copy link
Contributor

@SimonDuToit SimonDuToit left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just two tests still failing. You forgot to change a line on one, and another one you changed unnecessarily.

@@ -79,7 +79,7 @@ def test_sample(
fake_batch = get_fake_batch(fake_transition, int(min_length + 10))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
fake_batch = get_fake_batch(fake_transition, int(min_length + 10))
fake_batch = get_fake_batch(fake_transition, int(min_length))


for i in range(n_sequences_to_fill):
assert not state.is_full
state = buffer.add(state, fake_batch)
assert state.current_index == (
((i + 1) * add_sequence_size) % (max_length // add_batch_size)
((i + 1) * add_sequence_size) % (max_length // min_length)
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please undo the changes to this test, test_add_sequences_and_batches. The problem you found only comes up when add_batch_size is manually set to something bigger than min_length, but here it just uses the argument which is fine.

@mickvangelderen mickvangelderen force-pushed the mick/fix-sample-seq-len-gt-min-len-time branch from 6abfd3d to 9f3ecec Compare January 8, 2025 23:29
@mickvangelderen mickvangelderen force-pushed the mick/fix-sample-seq-len-gt-min-len-time branch from 9f3ecec to 7ba5b38 Compare January 9, 2025 17:14
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants