diff --git a/dapr/aio/clients/grpc/client.py b/dapr/aio/clients/grpc/client.py index 498bed9a..44578c73 100644 --- a/dapr/aio/clients/grpc/client.py +++ b/dapr/aio/clients/grpc/client.py @@ -36,9 +36,10 @@ UnaryStreamClientInterceptor, StreamUnaryClientInterceptor, StreamStreamClientInterceptor, + AioRpcError, ) -from dapr.clients.exceptions import DaprInternalError +from dapr.clients.exceptions import DaprInternalError, DaprGrpcError from dapr.clients.grpc._state import StateOptions, StateItem from dapr.clients.grpc._helpers import getWorkflowRuntimeStatus from dapr.conf.helpers import GrpcEndpoint @@ -446,9 +447,12 @@ async def publish_event( metadata=publish_metadata, ) - call = self._stub.PublishEvent(req, metadata=metadata) - # response is google.protobuf.Empty - await call + try: + call = self._stub.PublishEvent(req, metadata=metadata) + # response is google.protobuf.Empty + await call + except AioRpcError as err: + raise DaprGrpcError(err) from err return DaprResponse(await call.initial_metadata()) @@ -491,9 +495,15 @@ async def get_state( if not store_name or len(store_name) == 0 or len(store_name.strip()) == 0: raise ValueError('State store name cannot be empty') + req = api_v1.GetStateRequest(store_name=store_name, key=key, metadata=state_metadata) - call = self._stub.GetState(req, metadata=metadata) - response = await call + + try: + call = self._stub.GetState(req, metadata=metadata) + response = await call + except AioRpcError as err: + raise DaprGrpcError(err) from err + return StateResponse( data=response.data, etag=response.etag, headers=await call.initial_metadata() ) @@ -542,8 +552,12 @@ async def get_bulk_state( req = api_v1.GetBulkStateRequest( store_name=store_name, keys=keys, parallelism=parallelism, metadata=states_metadata ) - call = self._stub.GetBulkState(req, metadata=metadata) - response = await call + + try: + call = self._stub.GetBulkState(req, metadata=metadata) + response = await call + except AioRpcError as err: + raise DaprGrpcError(err) from err items = [] for item in response.items: @@ -601,8 +615,12 @@ async def query_state( if not store_name or len(store_name) == 0 or len(store_name.strip()) == 0: raise ValueError('State store name cannot be empty') req = api_v1.QueryStateRequest(store_name=store_name, query=query, metadata=states_metadata) - call = self._stub.QueryStateAlpha1(req) - response = await call + + try: + call = self._stub.QueryStateAlpha1(req) + response = await call + except AioRpcError as err: + raise DaprGrpcError(err) from err results = [] for item in response.results: @@ -691,9 +709,12 @@ async def save_state( ) req = api_v1.SaveStateRequest(store_name=store_name, states=[state]) - call = self._stub.SaveState(req, metadata=metadata) - await call - return DaprResponse(headers=await call.initial_metadata()) + try: + call = self._stub.SaveState(req, metadata=metadata) + await call + return DaprResponse(headers=await call.initial_metadata()) + except AioRpcError as e: + raise DaprInternalError(e.details()) from e async def save_bulk_state( self, store_name: str, states: List[StateItem], metadata: Optional[MetadataTuple] = None @@ -749,8 +770,13 @@ async def save_bulk_state( ] req = api_v1.SaveStateRequest(store_name=store_name, states=req_states) - call = self._stub.SaveState(req, metadata=metadata) - await call + + try: + call = self._stub.SaveState(req, metadata=metadata) + await call + except AioRpcError as err: + raise DaprGrpcError(err) from err + return DaprResponse(headers=await call.initial_metadata()) async def execute_state_transaction( @@ -815,8 +841,13 @@ async def execute_state_transaction( req = api_v1.ExecuteStateTransactionRequest( storeName=store_name, operations=req_ops, metadata=transactional_metadata ) - call = self._stub.ExecuteStateTransaction(req, metadata=metadata) - await call + + try: + call = self._stub.ExecuteStateTransaction(req, metadata=metadata) + await call + except AioRpcError as err: + raise DaprGrpcError(err) from err + return DaprResponse(headers=await call.initial_metadata()) async def delete_state( @@ -880,8 +911,13 @@ async def delete_state( options=state_options, metadata=state_metadata, ) - call = self._stub.DeleteState(req, metadata=metadata) - await call + + try: + call = self._stub.DeleteState(req, metadata=metadata) + await call + except AioRpcError as err: + raise DaprGrpcError(err) from err + return DaprResponse(headers=await call.initial_metadata()) async def get_secret( @@ -1522,8 +1558,13 @@ async def get_metadata(self) -> GetMetadataResponse: information about supported features in the form of component capabilities. """ - call = self._stub.GetMetadata(GrpcEmpty()) - _resp = await call + + try: + call = self._stub.GetMetadata(GrpcEmpty()) + _resp = await call + except AioRpcError as err: + raise DaprGrpcError(err) from err + response: api_v1.GetMetadataResponse = _resp # type alias # Convert to more pythonic formats active_actors_count = { diff --git a/tests/clients/test_dapr_async_grpc_client.py b/tests/clients/test_dapr_async_grpc_client.py index f0539f76..dbef2fb8 100644 --- a/tests/clients/test_dapr_async_grpc_client.py +++ b/tests/clients/test_dapr_async_grpc_client.py @@ -20,8 +20,11 @@ from unittest.mock import patch +from google.rpc import status_pb2, code_pb2 + from dapr.aio.clients.grpc.client import DaprGrpcClientAsync from dapr.aio.clients import DaprClient +from dapr.clients.exceptions import DaprGrpcError from dapr.proto import common_v1 from .fake_dapr_server import FakeDaprSidecar from dapr.conf import settings @@ -202,10 +205,18 @@ async def test_invoke_binding_no_create(self): async def test_publish_event(self): dapr = DaprGrpcClientAsync(f'{self.scheme}localhost:{self.server_port}') - resp = await dapr.publish_event(pubsub_name='pubsub', topic_name='example', data=b'haha') + resp = await dapr.publish_event( + pubsub_name='pubsub', topic_name='example', data=b'test_data' + ) self.assertEqual(2, len(resp.headers)) - self.assertEqual(['haha'], resp.headers['hdata']) + self.assertEqual(['test_data'], resp.headers['hdata']) + + self._fake_dapr_server.raise_exception_on_next_call( + status_pb2.Status(code=code_pb2.INVALID_ARGUMENT, message='my invalid argument message') + ) + with self.assertRaises(DaprGrpcError): + await dapr.publish_event(pubsub_name='pubsub', topic_name='example', data=b'test_data') async def test_publish_event_with_content_type(self): dapr = DaprGrpcClientAsync(f'{self.scheme}localhost:{self.server_port}') @@ -292,12 +303,19 @@ async def test_get_save_delete_state(self): self.assertEqual(resp.data, b'') self.assertEqual(resp.etag, '') + # Check a DaprGrpcError is raised + self._fake_dapr_server.raise_exception_on_next_call( + status_pb2.Status(code=code_pb2.INVALID_ARGUMENT, message='my invalid argument message') + ) + with self.assertRaises(DaprGrpcError) as context: + await dapr.get_state(store_name='my_statestore', key='key||') + await dapr.delete_state(store_name='statestore', key=key) resp = await dapr.get_state(store_name='statestore', key=key) self.assertEqual(resp.data, b'') self.assertEqual(resp.etag, '') - with self.assertRaises(Exception) as context: + with self.assertRaises(DaprGrpcError) as context: await dapr.delete_state( store_name='statestore', key=key, state_metadata={'must_delete': '1'} ) @@ -359,7 +377,20 @@ async def test_transaction_then_get_states(self): self.assertEqual(resp.items[1].key, another_key) self.assertEqual(resp.items[1].data, to_bytes(another_value.upper())) - async def test_save_then_get_states(self): + self._fake_dapr_server.raise_exception_on_next_call( + status_pb2.Status(code=code_pb2.INVALID_ARGUMENT, message='my invalid argument message') + ) + with self.assertRaises(DaprGrpcError): + await dapr.execute_state_transaction( + store_name='statestore', + operations=[ + TransactionalStateOperation(key=key, data=value, etag='foo'), + TransactionalStateOperation(key=another_key, data=another_value), + ], + transactional_metadata={'metakey': 'metavalue'}, + ) + + async def test_bulk_save_then_get_states(self): dapr = DaprGrpcClientAsync(f'{self.scheme}localhost:{self.server_port}') key = str(uuid.uuid4()) @@ -394,6 +425,27 @@ async def test_save_then_get_states(self): self.assertEqual(resp.items[1].etag, '1') self.assertEqual(resp.items[1].data, to_bytes(another_value.upper())) + self._fake_dapr_server.raise_exception_on_next_call( + status_pb2.Status(code=code_pb2.INVALID_ARGUMENT, message='my invalid argument message') + ) + with self.assertRaises(DaprGrpcError): + await dapr.save_bulk_state( + store_name='statestore', + states=[ + StateItem(key=key, value=value, metadata={'capitalize': '1'}), + StateItem(key=another_key, value=another_value, etag='1'), + ], + metadata=(('metakey', 'metavalue'),), + ) + + self._fake_dapr_server.raise_exception_on_next_call( + status_pb2.Status(code=code_pb2.INVALID_ARGUMENT, message='my invalid argument message') + ) + with self.assertRaises(DaprGrpcError): + await dapr.get_bulk_state( + store_name='statestore', keys=[key, another_key], states_metadata={'upper': '1'} + ) + async def test_get_secret(self): dapr = DaprGrpcClientAsync(f'{self.scheme}localhost:{self.server_port}') key1 = 'key_1' @@ -512,6 +564,15 @@ async def test_query_state(self): self.assertEqual(resp.results[0].key, '3') self.assertEqual(len(resp.results), 3) + self._fake_dapr_server.raise_exception_on_next_call( + status_pb2.Status(code=code_pb2.INVALID_ARGUMENT, message='my invalid argument message') + ) + with self.assertRaises(DaprGrpcError): + await dapr.query_state( + store_name='statestore', + query=json.dumps({'filter': {}, 'page': {'limit': 2}}), + ) + async def test_shutdown(self): dapr = DaprGrpcClientAsync(f'{self.scheme}localhost:{self.server_port}') await dapr.shutdown()