From 23dde3be17b70c613825754cf06a9e4462d6e324 Mon Sep 17 00:00:00 2001 From: Mick van Gelderen Date: Fri, 8 Nov 2024 14:09:42 -0800 Subject: [PATCH] Set min_length_time_axis correctly in tests --- flashbax/buffers/mixer_test.py | 8 ++++---- flashbax/buffers/trajectory_buffer_test.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/flashbax/buffers/mixer_test.py b/flashbax/buffers/mixer_test.py index bfcf6f2..d0a1f63 100644 --- a/flashbax/buffers/mixer_test.py +++ b/flashbax/buffers/mixer_test.py @@ -92,7 +92,7 @@ def test_mixed_trajectory_sample( for i in range(3): buffer = trajectory_buffer.make_trajectory_buffer( max_length_time_axis=200 * (i + 1), - min_length_time_axis=0, + min_length_time_axis=sample_sequence_length, sample_batch_size=sample_batch_size, add_batch_size=add_batch_size, sample_sequence_length=sample_sequence_length, @@ -143,7 +143,7 @@ def test_mixed_prioritised_trajectory_sample( for i in range(3): buffer = prioritised_trajectory_buffer.make_prioritised_trajectory_buffer( max_length_time_axis=200 * (i + 1), - min_length_time_axis=0, + min_length_time_axis=sample_sequence_length, sample_batch_size=sample_batch_size, add_batch_size=add_batch_size, sample_sequence_length=sample_sequence_length, @@ -192,7 +192,7 @@ def test_mixed_flat_buffer_sample( for i in range(3): buffer = flat_buffer.make_flat_buffer( max_length=200 * (i + 1), - min_length=0, + min_length=add_batch_size, sample_batch_size=sample_batch_size, add_batch_size=add_batch_size, add_sequences=True, @@ -241,7 +241,7 @@ def test_mixed_buffer_does_not_smoke( for i in range(3): buffer = trajectory_buffer.make_trajectory_buffer( max_length_time_axis=2000 * (i + 1), - min_length_time_axis=0, + min_length_time_axis=sample_sequence_length, sample_batch_size=sample_batch_size, add_batch_size=add_batch_size, sample_sequence_length=sample_sequence_length, diff --git a/flashbax/buffers/trajectory_buffer_test.py b/flashbax/buffers/trajectory_buffer_test.py index 908bb33..f4a9c58 100644 --- a/flashbax/buffers/trajectory_buffer_test.py +++ b/flashbax/buffers/trajectory_buffer_test.py @@ -165,7 +165,7 @@ def test_add_sample_max_capacity( sample_sequence_length = add_sequence_length buffer = trajectory_buffer.make_trajectory_buffer( max_length_time_axis=add_sequence_length, - min_length_time_axis=0, + min_length_time_axis=sample_sequence_length, sample_batch_size=sample_batch_size, add_batch_size=add_batch_size, sample_sequence_length=sample_sequence_length, @@ -342,7 +342,7 @@ def test_uniform_index_cal( buffer = trajectory_buffer.make_trajectory_buffer( max_length_time_axis=max_length, - min_length_time_axis=0, + min_length_time_axis=sample_sequence_length, sample_batch_size=sample_batch_size, add_batch_size=add_batch_size, sample_sequence_length=sample_sequence_length,