diff --git a/dapr/actor/runtime/_state_provider.py b/dapr/actor/runtime/_state_provider.py index 7fc78ce7..54f6b583 100644 --- a/dapr/actor/runtime/_state_provider.py +++ b/dapr/actor/runtime/_state_provider.py @@ -66,7 +66,10 @@ async def save_state( "operation": "upsert", "request": { "key": "key1", - "value": "myData" + "value": "myData", + "metadata": { + "ttlInSeconds": "3600" + } } }, { @@ -94,6 +97,11 @@ async def save_state( serialized = self._state_serializer.serialize(state.value) json_output.write(b',"value":') json_output.write(serialized) + if state.ttl_in_seconds is not None and state.ttl_in_seconds >= 0: + serialized = self._state_serializer.serialize(state.ttl_in_seconds) + json_output.write(b',"metadata":{"ttlInSeconds":"') + json_output.write(serialized) + json_output.write(b'"}') json_output.write(b'}}') first_state = False json_output.write(b']') diff --git a/dapr/actor/runtime/state_change.py b/dapr/actor/runtime/state_change.py index 42381eae..dba21e2c 100644 --- a/dapr/actor/runtime/state_change.py +++ b/dapr/actor/runtime/state_change.py @@ -14,7 +14,7 @@ """ from enum import Enum -from typing import TypeVar, Generic +from typing import TypeVar, Generic, Optional T = TypeVar('T') @@ -35,10 +35,17 @@ class StateChangeKind(Enum): class ActorStateChange(Generic[T]): - def __init__(self, state_name: str, value: T, change_kind: StateChangeKind): + def __init__( + self, + state_name: str, + value: T, + change_kind: StateChangeKind, + ttl_in_seconds: Optional[int] = None, + ): self._state_name = state_name self._value = value self._change_kind = change_kind + self._ttl_in_seconds = ttl_in_seconds @property def state_name(self) -> str: @@ -51,3 +58,7 @@ def value(self) -> T: @property def change_kind(self) -> StateChangeKind: return self._change_kind + + @property + def ttl_in_seconds(self) -> Optional[int]: + return self._ttl_in_seconds diff --git a/dapr/actor/runtime/state_manager.py b/dapr/actor/runtime/state_manager.py index 52313f77..7132175b 100644 --- a/dapr/actor/runtime/state_manager.py +++ b/dapr/actor/runtime/state_manager.py @@ -29,9 +29,12 @@ class StateMetadata(Generic[T]): - def __init__(self, value: T, change_kind: StateChangeKind): + def __init__( + self, value: T, change_kind: StateChangeKind, ttl_in_seconds: Optional[int] = None + ): self._value = value self._change_kind = change_kind + self._ttl_in_seconds = ttl_in_seconds @property def value(self) -> T: @@ -49,6 +52,14 @@ def change_kind(self) -> StateChangeKind: def change_kind(self, new_kind: StateChangeKind) -> None: self._change_kind = new_kind + @property + def ttl_in_seconds(self) -> Optional[int]: + return self._ttl_in_seconds + + @ttl_in_seconds.setter + def ttl_in_seconds(self, new_ttl_in_seconds: int) -> None: + self._ttl_in_seconds = new_ttl_in_seconds + class ActorStateManager(Generic[T]): def __init__(self, actor: 'Actor'): @@ -103,10 +114,17 @@ async def try_get_state(self, state_name: str) -> Tuple[bool, Optional[T]]: return has_value, val async def set_state(self, state_name: str, value: T) -> None: + await self.set_state_ttl(state_name, value, None) + + async def set_state_ttl(self, state_name: str, value: T, ttl_in_seconds: Optional[int]) -> None: + if ttl_in_seconds is not None and ttl_in_seconds < 0: + return + state_change_tracker = self._get_contextual_state_tracker() if state_name in state_change_tracker: state_metadata = state_change_tracker[state_name] state_metadata.value = value + state_metadata.ttl_in_seconds = ttl_in_seconds if ( state_metadata.change_kind == StateChangeKind.none @@ -120,9 +138,13 @@ async def set_state(self, state_name: str, value: T) -> None: self._type_name, self._actor.id.id, state_name ) if existed: - state_change_tracker[state_name] = StateMetadata(value, StateChangeKind.update) + state_change_tracker[state_name] = StateMetadata( + value, StateChangeKind.update, ttl_in_seconds + ) else: - state_change_tracker[state_name] = StateMetadata(value, StateChangeKind.add) + state_change_tracker[state_name] = StateMetadata( + value, StateChangeKind.add, ttl_in_seconds + ) async def remove_state(self, state_name: str) -> None: if not await self.try_remove_state(state_name): @@ -231,7 +253,12 @@ async def save_state(self) -> None: if state_metadata.change_kind == StateChangeKind.none: continue state_changes.append( - ActorStateChange(state_name, state_metadata.value, state_metadata.change_kind) + ActorStateChange( + state_name, + state_metadata.value, + state_metadata.change_kind, + state_metadata.ttl_in_seconds, + ) ) if state_metadata.change_kind == StateChangeKind.remove: states_to_remove.append(state_name) diff --git a/dev-requirements.txt b/dev-requirements.txt index 9bcf6fe9..15866725 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -13,6 +13,6 @@ pyOpenSSL>=23.2.0 # needed for type checking Flask>=1.1 # needed for auto fix -ruff===0.4.2 +ruff===0.2.2 # needed for dapr-ext-workflow durabletask>=0.1.1a1 diff --git a/tests/actor/test_state_manager.py b/tests/actor/test_state_manager.py index dfaf46bb..98db0228 100644 --- a/tests/actor/test_state_manager.py +++ b/tests/actor/test_state_manager.py @@ -116,6 +116,7 @@ def test_set_state_for_new_state(self): state = state_change_tracker['state1'] self.assertEqual(StateChangeKind.add, state.change_kind) self.assertEqual('value1', state.value) + self.assertEqual(None, state.ttl_in_seconds) @mock.patch('tests.actor.fake_client.FakeDaprActorClient.get_state', new=_async_mock()) def test_set_state_for_existing_state_only_in_mem(self): @@ -131,6 +132,7 @@ def test_set_state_for_existing_state_only_in_mem(self): state = state_change_tracker['state1'] self.assertEqual(StateChangeKind.add, state.change_kind) self.assertEqual('value2', state.value) + self.assertEqual(None, state.ttl_in_seconds) @mock.patch( 'tests.actor.fake_client.FakeDaprActorClient.get_state', @@ -143,6 +145,73 @@ def test_set_state_for_existing_state(self): state = state_change_tracker['state1'] self.assertEqual(StateChangeKind.update, state.change_kind) self.assertEqual('value2', state.value) + self.assertEqual(None, state.ttl_in_seconds) + + @mock.patch('tests.actor.fake_client.FakeDaprActorClient.get_state', new=_async_mock()) + def test_set_state_ttl_for_new_state(self): + state_manager = ActorStateManager(self._fake_actor) + state_change_tracker = state_manager._get_contextual_state_tracker() + _run(state_manager.set_state_ttl('state1', 'value1', 3600)) + + state = state_change_tracker['state1'] + self.assertEqual(StateChangeKind.add, state.change_kind) + self.assertEqual('value1', state.value) + self.assertEqual(3600, state.ttl_in_seconds) + + @mock.patch('tests.actor.fake_client.FakeDaprActorClient.get_state', new=_async_mock()) + def test_set_state_ttl_for_existing_state_only_in_mem(self): + state_manager = ActorStateManager(self._fake_actor) + state_change_tracker = state_manager._get_contextual_state_tracker() + _run(state_manager.set_state_ttl('state1', 'value1', 3600)) + + state = state_change_tracker['state1'] + self.assertEqual(StateChangeKind.add, state.change_kind) + self.assertEqual('value1', state.value) + self.assertEqual(3600, state.ttl_in_seconds) + + _run(state_manager.set_state_ttl('state1', 'value2', 7200)) + state = state_change_tracker['state1'] + self.assertEqual(StateChangeKind.add, state.change_kind) + self.assertEqual('value2', state.value) + self.assertEqual(7200, state.ttl_in_seconds) + + @mock.patch( + 'tests.actor.fake_client.FakeDaprActorClient.get_state', + new=_async_mock(return_value=b'"value1"'), + ) + def test_set_state_ttl_for_existing_state(self): + state_manager = ActorStateManager(self._fake_actor) + state_change_tracker = state_manager._get_contextual_state_tracker() + _run(state_manager.set_state_ttl('state1', 'value2', 3600)) + + state = state_change_tracker['state1'] + self.assertEqual(StateChangeKind.update, state.change_kind) + self.assertEqual('value2', state.value) + self.assertEqual(3600, state.ttl_in_seconds) + + @mock.patch('tests.actor.fake_client.FakeDaprActorClient.get_state', new=_async_mock()) + def test_set_state_ttl_lt_0_for_new_state(self): + state_manager = ActorStateManager(self._fake_actor) + state_change_tracker = state_manager._get_contextual_state_tracker() + _run(state_manager.set_state_ttl('state1', 'value1', -3600)) + self.assertNotIn('state1', state_change_tracker) + + @mock.patch('tests.actor.fake_client.FakeDaprActorClient.get_state', new=_async_mock()) + def test_set_state_ttl_lt_0_for_existing_state_only_in_mem(self): + state_manager = ActorStateManager(self._fake_actor) + state_change_tracker = state_manager._get_contextual_state_tracker() + _run(state_manager.set_state_ttl('state1', 'value1', 3600)) + + state = state_change_tracker['state1'] + self.assertEqual(StateChangeKind.add, state.change_kind) + self.assertEqual('value1', state.value) + self.assertEqual(3600, state.ttl_in_seconds) + + _run(state_manager.set_state_ttl('state1', 'value2', -3600)) + state = state_change_tracker['state1'] + self.assertEqual(StateChangeKind.add, state.change_kind) + self.assertEqual('value1', state.value) + self.assertEqual(3600, state.ttl_in_seconds) @mock.patch('tests.actor.fake_client.FakeDaprActorClient.get_state', new=_async_mock()) def test_remove_state_for_non_existing_state(self): @@ -360,13 +429,20 @@ def test_save_state(self): _run(state_manager.remove_state('state4')) # set state which is StateChangeKind.update _run(state_manager.set_state('state5', 'value5')) - expected = b'[{"operation":"upsert","request":{"key":"state1","value":"value1"}},{"operation":"upsert","request":{"key":"state2","value":"value2"}},{"operation":"delete","request":{"key":"state4"}},{"operation":"upsert","request":{"key":"state5","value":"value5"}}]' # noqa: E501 + _run(state_manager.set_state('state5', 'new_value5')) + # set state with ttl >= 0 + _run(state_manager.set_state_ttl('state6', 'value6', 3600)) + _run(state_manager.set_state_ttl('state7', 'value7', 0)) + # set state with ttl < 0 + _run(state_manager.set_state_ttl('state8', 'value8', -3600)) + + expected = b'[{"operation":"upsert","request":{"key":"state1","value":"value1"}},{"operation":"upsert","request":{"key":"state2","value":"value2"}},{"operation":"delete","request":{"key":"state4"}},{"operation":"upsert","request":{"key":"state5","value":"new_value5"}},{"operation":"upsert","request":{"key":"state6","value":"value6","metadata":{"ttlInSeconds":"3600"}}},{"operation":"upsert","request":{"key":"state7","value":"value7","metadata":{"ttlInSeconds":"0"}}}]' # noqa: E501 # Save the state - def mock_save_state(actor_type, actor_id, data): + async def mock_save_state(actor_type, actor_id, data): self.assertEqual(expected, data) - self._fake_client.save_state_transactionally.mock = mock_save_state + self._fake_client.save_state_transactionally = mock_save_state _run(state_manager.save_state())