Skip to content

Commit

Permalink
Fix remaining sample seq len > min len time axis
Browse files Browse the repository at this point in the history
  • Loading branch information
mickvangelderen committed Dec 10, 2024
1 parent 23dde3b commit 082261d
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 8 deletions.
5 changes: 3 additions & 2 deletions flashbax/buffers/flat_buffer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,9 @@ def test_sample(
# Fill buffer to the point that we can sample
fake_batch = get_fake_batch(fake_transition, int(min_length + 10))

add_batch_size = int(min_length + 10)
buffer = flat_buffer.make_flat_buffer(
max_length, min_length, sample_batch_size, False, int(min_length + 10)
max_length, add_batch_size, sample_batch_size, False, add_batch_size
)
state = buffer.init(fake_transition)

Expand Down Expand Up @@ -224,7 +225,7 @@ def test_flat_replay_buffer_does_not_smoke(

add_batch_size = int(min_length + 5)
buffer = flat_buffer.make_flat_buffer(
max_length, min_length, sample_batch_size, False, add_batch_size
max_length, add_batch_size, sample_batch_size, False, add_batch_size
)

# Initialise the buffer's state.
Expand Down
9 changes: 4 additions & 5 deletions flashbax/buffers/prioritised_flat_buffer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def test_sample(
fake_batch = get_fake_batch(fake_transition, add_batch_size)

buffer = prioritised_flat_buffer.make_prioritised_flat_buffer(
max_length, min_length, sample_batch_size, False, add_batch_size
max_length, add_batch_size, sample_batch_size, False, add_batch_size
)
state = buffer.init(fake_transition)

Expand Down Expand Up @@ -137,7 +137,7 @@ def test_adjust_priorities(
fake_batch = get_fake_batch(fake_transition, add_batch_size)
buffer = prioritised_flat_buffer.make_prioritised_flat_buffer(
max_length,
min_length,
add_batch_size,
sample_batch_size,
False,
add_batch_size,
Expand Down Expand Up @@ -179,7 +179,7 @@ def test_prioritised_flat_buffer_does_not_smoke(

buffer = prioritised_flat_buffer.make_prioritised_flat_buffer(
max_length,
min_length,
add_batch_size,
sample_batch_size,
False,
add_batch_size,
Expand Down Expand Up @@ -313,7 +313,6 @@ def test_add_sequences(

def test_add_sequences_and_batches(
fake_transition: chex.ArrayTree,
min_length: int,
max_length: int,
add_batch_size: int,
sample_batch_size: int,
Expand All @@ -329,7 +328,7 @@ def test_add_sequences_and_batches(

buffer = prioritised_flat_buffer.make_prioritised_flat_buffer(
max_length,
min_length,
add_batch_size,
sample_batch_size,
add_sequences=True,
add_batch_size=add_batch_size,
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ include = ["flashbax/*"]
[tool.pytest.ini_options]
filterwarnings = [
"error",
"ignore:`sample_sequence_length` greater than `min_length_time_axis`:UserWarning:flashbax",
"ignore:Setting period greater than sample_sequence_length will result in no overlap betweentrajectories:UserWarning:flashbax",
"ignore:jax.tree_map is deprecated:DeprecationWarning:flashbax",
]
Expand Down

0 comments on commit 082261d

Please sign in to comment.