Skip to content

Commit

Permalink
chore: update dependencies and enhance Redis checkpointing tests
Browse files Browse the repository at this point in the history
- Refactor and expand unit tests for Redis checkpointing functions to cover various edge cases and improve test coverage.
- Add detailed docstrings and parameterized tests for better clarity and maintainability.
  • Loading branch information
muralov committed Jan 22, 2025
1 parent a45f545 commit 7588a1d
Show file tree
Hide file tree
Showing 4 changed files with 190 additions and 34 deletions.
26 changes: 13 additions & 13 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ scrubadub = {extras = ["all"], version = "^2.0.1"}
tiktoken = "^0.7.0"

[tool.poetry.group.test.dependencies]
deepeval = "^2.1.2"
deepeval = "^2.2.0"
fakeredis = "^2.23.3"
prettytable = "^3.10.2"
pytest = "^8.2.2"
Expand Down
3 changes: 2 additions & 1 deletion src/agents/memory/async_redis_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ def _make_redis_checkpoint_key(
thread_id: str, checkpoint_ns: str, checkpoint_id: str
) -> str:
"""Create a Redis key for storing checkpoint data.
Returns a Redis key string in the format "checkpoint$thread_id$namespace$checkpoint_id".
"""
return REDIS_KEY_SEPARATOR.join(
Expand Down Expand Up @@ -271,6 +270,8 @@ async def aput_writes(
"""Store intermediate writes linked to a checkpoint asynchronously.
This method saves intermediate writes associated with a checkpoint to the database.
Critical for fault tolerance: stores successful node outputs even if other nodes
in the same superstep fail, preventing unnecessary re-execution on retry.
Args:
config (RunnableConfig): Configuration of the related checkpoint.
Expand Down
193 changes: 174 additions & 19 deletions tests/unit/agents/memory/test_async_redis_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,30 +297,185 @@ async def test_aload_pending_writes(


class TestUtilityFunctions:
def test_make_redis_checkpoint_key(self):
key = _make_redis_checkpoint_key("thread1", "ns1", "chk1")
assert key == "checkpoint$thread1$ns1$chk1"
@pytest.mark.parametrize(
"thread_id, checkpoint_ns, checkpoint_id, expected_key",
[
# Basic case
("thread1", "ns1", "chk1", "checkpoint$thread1$ns1$chk1"),
# Special characters
("thread-1", "ns/1", "chk.1", "checkpoint$thread-1$ns/1$chk.1"),
# Long identifiers
(
"thread_very_long_123",
"namespace_very_long_456",
"checkpoint_very_long_789",
"checkpoint$thread_very_long_123$namespace_very_long_456$checkpoint_very_long_789",
),
# Empty namespace (should be filtered out)
("thread1", "", "chk1", "checkpoint$thread1$$chk1"),
# Empty thread_id
("", "ns1", "chk1", "checkpoint$$ns1$chk1"),
# Empty checkpoint_id
("thread1", "ns1", "", "checkpoint$thread1$ns1$"),
],
)
def test_make_redis_checkpoint_key(
self,
thread_id: str,
checkpoint_ns: str,
checkpoint_id: str,
expected_key: str,
) -> None:
"""Test the _make_redis_checkpoint_key function with various inputs."""

def test_make_redis_checkpoint_writes_key(self):
key = _make_redis_checkpoint_writes_key("thread1", "ns1", "chk1", "task1", 0)
assert key == "writes$thread1$ns1$chk1$task1$0"
key = _make_redis_checkpoint_key(thread_id, checkpoint_ns, checkpoint_id)
assert key == expected_key

key_no_idx = _make_redis_checkpoint_writes_key(
"thread1", "ns1", "chk1", "task1", None
@pytest.mark.parametrize(
"thread_id, checkpoint_ns, checkpoint_id, task_id, idx, expected_key",
[
# Basic case with index
("thread1", "ns1", "chk1", "task1", 0, "writes$thread1$ns1$chk1$task1$0"),
# Case without index (None)
("thread1", "ns1", "chk1", "task1", None, "writes$thread1$ns1$chk1$task1"),
# Special characters in identifiers
(
"thread-1",
"ns/1",
"chk.1",
"task:1",
0,
"writes$thread-1$ns/1$chk.1$task:1$0",
),
# Long identifiers
(
"thread_very_long_123",
"namespace_very_long_456",
"checkpoint_very_long_789",
"task_very_long_012",
1,
"writes$thread_very_long_123$namespace_very_long_456$checkpoint_very_long_789$task_very_long_012$1",
),
# Empty namespace
("thread1", "", "chk1", "task1", 0, "writes$thread1$$chk1$task1$0"),
# Empty thread_id
("", "ns1", "chk1", "task1", 0, "writes$$ns1$chk1$task1$0"),
# Negative index
("thread1", "ns1", "chk1", "task1", -1, "writes$thread1$ns1$chk1$task1$-1"),
# Large index
(
"thread1",
"ns1",
"chk1",
"task1",
999999,
"writes$thread1$ns1$chk1$task1$999999",
),
],
)
def test_make_redis_checkpoint_writes_key(
self,
thread_id: str,
checkpoint_ns: str,
checkpoint_id: str,
task_id: str,
idx: int | None,
expected_key: str,
) -> None:
"""Test the _make_redis_checkpoint_writes_key function with various inputs."""
key = _make_redis_checkpoint_writes_key(
thread_id, checkpoint_ns, checkpoint_id, task_id, idx
)
assert key_no_idx == "writes$thread1$ns1$chk1$task1"
assert key == expected_key

def test_parse_redis_checkpoint_key(self):
key = "checkpoint$thread1$ns1$chk1"
result = _parse_redis_checkpoint_key(key)
assert result == {
"thread_id": "thread1",
"checkpoint_ns": "ns1",
"checkpoint_id": "chk1",
}
@pytest.mark.parametrize(
"key, expected_result, should_raise",
[
# Valid cases
(
"checkpoint$thread1$ns1$chk1",
{
"thread_id": "thread1",
"checkpoint_ns": "ns1",
"checkpoint_id": "chk1",
},
False,
),
# Special characters in identifiers
(
"checkpoint$thread-1$ns/1$chk.1",
{
"thread_id": "thread-1",
"checkpoint_ns": "ns/1",
"checkpoint_id": "chk.1",
},
False,
),
# Empty namespace
(
"checkpoint$thread1$$chk1",
{
"thread_id": "thread1",
"checkpoint_ns": "",
"checkpoint_id": "chk1",
},
False,
),
# Long identifiers
(
"checkpoint$thread_very_long_123$namespace_very_long_456$checkpoint_very_long_789",
{
"thread_id": "thread_very_long_123",
"checkpoint_ns": "namespace_very_long_456",
"checkpoint_id": "checkpoint_very_long_789",
},
False,
),
# UUID-like identifiers
(
"checkpoint$550e8400-e29b-41d4-a716-446655440000$ns1$chk1",
{
"thread_id": "550e8400-e29b-41d4-a716-446655440000",
"checkpoint_ns": "ns1",
"checkpoint_id": "chk1",
},
False,
),
# Invalid cases - wrong prefix
("invalid$thread1$ns1$chk1", None, True),
# Invalid cases - wrong number of segments
("checkpoint$thread1$ns1", None, True),
("checkpoint$thread1$ns1$chk1$extra", None, True),
# Invalid cases - empty segments
(
"checkpoint$$$",
{
"thread_id": "",
"checkpoint_ns": "",
"checkpoint_id": "",
},
False,
),
# Invalid cases - completely wrong format
("invalid_key", None, True),
# Invalid cases - missing separator
("checkpointthread1ns1chk1", None, True),
],
)
def test_parse_redis_checkpoint_key(self, key, expected_result, should_raise):
"""Test _parse_redis_checkpoint_key with various inputs using table-driven tests.
with pytest.raises(ValueError):
_parse_redis_checkpoint_key("invalid$key")
Args:
key: Input Redis key to parse
expected_result: Expected dictionary output for valid keys
should_raise: Whether ValueError should be raised
"""
if should_raise:
with pytest.raises(ValueError):
_parse_redis_checkpoint_key(key)
else:
result = _parse_redis_checkpoint_key(key)
assert result == expected_result

def test_parse_redis_checkpoint_writes_key(self):
key = "writes$thread1$ns1$chk1$task1$0"
Expand Down

0 comments on commit 7588a1d

Please sign in to comment.