Skip to content

Commit

Permalink
Set min_length_time_axis correctly in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mickvangelderen committed Dec 10, 2024
1 parent 1352bfa commit 23dde3b
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
8 changes: 4 additions & 4 deletions flashbax/buffers/mixer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def test_mixed_trajectory_sample(
for i in range(3):
buffer = trajectory_buffer.make_trajectory_buffer(
max_length_time_axis=200 * (i + 1),
min_length_time_axis=0,
min_length_time_axis=sample_sequence_length,
sample_batch_size=sample_batch_size,
add_batch_size=add_batch_size,
sample_sequence_length=sample_sequence_length,
Expand Down Expand Up @@ -143,7 +143,7 @@ def test_mixed_prioritised_trajectory_sample(
for i in range(3):
buffer = prioritised_trajectory_buffer.make_prioritised_trajectory_buffer(
max_length_time_axis=200 * (i + 1),
min_length_time_axis=0,
min_length_time_axis=sample_sequence_length,
sample_batch_size=sample_batch_size,
add_batch_size=add_batch_size,
sample_sequence_length=sample_sequence_length,
Expand Down Expand Up @@ -192,7 +192,7 @@ def test_mixed_flat_buffer_sample(
for i in range(3):
buffer = flat_buffer.make_flat_buffer(
max_length=200 * (i + 1),
min_length=0,
min_length=add_batch_size,
sample_batch_size=sample_batch_size,
add_batch_size=add_batch_size,
add_sequences=True,
Expand Down Expand Up @@ -241,7 +241,7 @@ def test_mixed_buffer_does_not_smoke(
for i in range(3):
buffer = trajectory_buffer.make_trajectory_buffer(
max_length_time_axis=2000 * (i + 1),
min_length_time_axis=0,
min_length_time_axis=sample_sequence_length,
sample_batch_size=sample_batch_size,
add_batch_size=add_batch_size,
sample_sequence_length=sample_sequence_length,
Expand Down
4 changes: 2 additions & 2 deletions flashbax/buffers/trajectory_buffer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def test_add_sample_max_capacity(
sample_sequence_length = add_sequence_length
buffer = trajectory_buffer.make_trajectory_buffer(
max_length_time_axis=add_sequence_length,
min_length_time_axis=0,
min_length_time_axis=sample_sequence_length,
sample_batch_size=sample_batch_size,
add_batch_size=add_batch_size,
sample_sequence_length=sample_sequence_length,
Expand Down Expand Up @@ -342,7 +342,7 @@ def test_uniform_index_cal(

buffer = trajectory_buffer.make_trajectory_buffer(
max_length_time_axis=max_length,
min_length_time_axis=0,
min_length_time_axis=sample_sequence_length,
sample_batch_size=sample_batch_size,
add_batch_size=add_batch_size,
sample_sequence_length=sample_sequence_length,
Expand Down

0 comments on commit 23dde3b

Please sign in to comment.