diff --git a/flashbax/buffers/flat_buffer_test.py b/flashbax/buffers/flat_buffer_test.py index 3692c9b..e541a78 100644 --- a/flashbax/buffers/flat_buffer_test.py +++ b/flashbax/buffers/flat_buffer_test.py @@ -78,8 +78,9 @@ def test_sample( # Fill buffer to the point that we can sample fake_batch = get_fake_batch(fake_transition, int(min_length + 10)) + add_batch_size = int(min_length + 10) buffer = flat_buffer.make_flat_buffer( - max_length, min_length, sample_batch_size, False, int(min_length + 10) + max_length, add_batch_size, sample_batch_size, False, add_batch_size ) state = buffer.init(fake_transition) @@ -224,7 +225,7 @@ def test_flat_replay_buffer_does_not_smoke( add_batch_size = int(min_length + 5) buffer = flat_buffer.make_flat_buffer( - max_length, min_length, sample_batch_size, False, add_batch_size + max_length, add_batch_size, sample_batch_size, False, add_batch_size ) # Initialise the buffer's state. diff --git a/flashbax/buffers/prioritised_flat_buffer_test.py b/flashbax/buffers/prioritised_flat_buffer_test.py index e2ec1de..a48f0f7 100644 --- a/flashbax/buffers/prioritised_flat_buffer_test.py +++ b/flashbax/buffers/prioritised_flat_buffer_test.py @@ -90,7 +90,7 @@ def test_sample( fake_batch = get_fake_batch(fake_transition, add_batch_size) buffer = prioritised_flat_buffer.make_prioritised_flat_buffer( - max_length, min_length, sample_batch_size, False, add_batch_size + max_length, add_batch_size, sample_batch_size, False, add_batch_size ) state = buffer.init(fake_transition) @@ -137,7 +137,7 @@ def test_adjust_priorities( fake_batch = get_fake_batch(fake_transition, add_batch_size) buffer = prioritised_flat_buffer.make_prioritised_flat_buffer( max_length, - min_length, + add_batch_size, sample_batch_size, False, add_batch_size, @@ -179,7 +179,7 @@ def test_prioritised_flat_buffer_does_not_smoke( buffer = prioritised_flat_buffer.make_prioritised_flat_buffer( max_length, - min_length, + add_batch_size, sample_batch_size, False, add_batch_size, @@ -313,7 +313,6 @@ def test_add_sequences( def test_add_sequences_and_batches( fake_transition: chex.ArrayTree, - min_length: int, max_length: int, add_batch_size: int, sample_batch_size: int, @@ -329,7 +328,7 @@ def test_add_sequences_and_batches( buffer = prioritised_flat_buffer.make_prioritised_flat_buffer( max_length, - min_length, + add_batch_size, sample_batch_size, add_sequences=True, add_batch_size=add_batch_size, diff --git a/pyproject.toml b/pyproject.toml index 5a98cb2..580ae0f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,6 @@ include = ["flashbax/*"] [tool.pytest.ini_options] filterwarnings = [ "error", - "ignore:`sample_sequence_length` greater than `min_length_time_axis`:UserWarning:flashbax", "ignore:Setting period greater than sample_sequence_length will result in no overlap betweentrajectories:UserWarning:flashbax", "ignore:jax.tree_map is deprecated:DeprecationWarning:flashbax", ]