Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds error handling for the async client #671

Merged
merged 4 commits into from
Feb 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 62 additions & 21 deletions dapr/aio/clients/grpc/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,10 @@
UnaryStreamClientInterceptor,
StreamUnaryClientInterceptor,
StreamStreamClientInterceptor,
AioRpcError,
)

from dapr.clients.exceptions import DaprInternalError
from dapr.clients.exceptions import DaprInternalError, DaprGrpcError
from dapr.clients.grpc._state import StateOptions, StateItem
from dapr.clients.grpc._helpers import getWorkflowRuntimeStatus
from dapr.conf.helpers import GrpcEndpoint
Expand Down Expand Up @@ -446,9 +447,12 @@ async def publish_event(
metadata=publish_metadata,
)

call = self._stub.PublishEvent(req, metadata=metadata)
# response is google.protobuf.Empty
await call
try:
call = self._stub.PublishEvent(req, metadata=metadata)
# response is google.protobuf.Empty
await call
except AioRpcError as err:
raise DaprGrpcError(err) from err

return DaprResponse(await call.initial_metadata())

Expand Down Expand Up @@ -491,9 +495,15 @@ async def get_state(

if not store_name or len(store_name) == 0 or len(store_name.strip()) == 0:
raise ValueError('State store name cannot be empty')

req = api_v1.GetStateRequest(store_name=store_name, key=key, metadata=state_metadata)
call = self._stub.GetState(req, metadata=metadata)
response = await call

try:
call = self._stub.GetState(req, metadata=metadata)
response = await call
except AioRpcError as err:
raise DaprGrpcError(err) from err

return StateResponse(
data=response.data, etag=response.etag, headers=await call.initial_metadata()
)
Expand Down Expand Up @@ -542,8 +552,12 @@ async def get_bulk_state(
req = api_v1.GetBulkStateRequest(
store_name=store_name, keys=keys, parallelism=parallelism, metadata=states_metadata
)
call = self._stub.GetBulkState(req, metadata=metadata)
response = await call

try:
call = self._stub.GetBulkState(req, metadata=metadata)
response = await call
except AioRpcError as err:
raise DaprGrpcError(err) from err

items = []
for item in response.items:
Expand Down Expand Up @@ -601,8 +615,12 @@ async def query_state(
if not store_name or len(store_name) == 0 or len(store_name.strip()) == 0:
raise ValueError('State store name cannot be empty')
req = api_v1.QueryStateRequest(store_name=store_name, query=query, metadata=states_metadata)
call = self._stub.QueryStateAlpha1(req)
response = await call

try:
call = self._stub.QueryStateAlpha1(req)
response = await call
except AioRpcError as err:
raise DaprGrpcError(err) from err

results = []
for item in response.results:
Expand Down Expand Up @@ -691,9 +709,12 @@ async def save_state(
)

req = api_v1.SaveStateRequest(store_name=store_name, states=[state])
call = self._stub.SaveState(req, metadata=metadata)
await call
return DaprResponse(headers=await call.initial_metadata())
try:
call = self._stub.SaveState(req, metadata=metadata)
await call
return DaprResponse(headers=await call.initial_metadata())
except AioRpcError as e:
raise DaprInternalError(e.details()) from e

async def save_bulk_state(
self, store_name: str, states: List[StateItem], metadata: Optional[MetadataTuple] = None
Expand Down Expand Up @@ -749,8 +770,13 @@ async def save_bulk_state(
]

req = api_v1.SaveStateRequest(store_name=store_name, states=req_states)
call = self._stub.SaveState(req, metadata=metadata)
await call

try:
call = self._stub.SaveState(req, metadata=metadata)
await call
except AioRpcError as err:
raise DaprGrpcError(err) from err

return DaprResponse(headers=await call.initial_metadata())

async def execute_state_transaction(
Expand Down Expand Up @@ -815,8 +841,13 @@ async def execute_state_transaction(
req = api_v1.ExecuteStateTransactionRequest(
storeName=store_name, operations=req_ops, metadata=transactional_metadata
)
call = self._stub.ExecuteStateTransaction(req, metadata=metadata)
await call

try:
call = self._stub.ExecuteStateTransaction(req, metadata=metadata)
await call
except AioRpcError as err:
raise DaprGrpcError(err) from err

return DaprResponse(headers=await call.initial_metadata())

async def delete_state(
Expand Down Expand Up @@ -880,8 +911,13 @@ async def delete_state(
options=state_options,
metadata=state_metadata,
)
call = self._stub.DeleteState(req, metadata=metadata)
await call

try:
call = self._stub.DeleteState(req, metadata=metadata)
await call
except AioRpcError as err:
raise DaprGrpcError(err) from err

return DaprResponse(headers=await call.initial_metadata())

async def get_secret(
Expand Down Expand Up @@ -1522,8 +1558,13 @@ async def get_metadata(self) -> GetMetadataResponse:
information about supported features in the form of component
capabilities.
"""
call = self._stub.GetMetadata(GrpcEmpty())
_resp = await call

try:
call = self._stub.GetMetadata(GrpcEmpty())
_resp = await call
except AioRpcError as err:
raise DaprGrpcError(err) from err

response: api_v1.GetMetadataResponse = _resp # type alias
# Convert to more pythonic formats
active_actors_count = {
Expand Down
69 changes: 65 additions & 4 deletions tests/clients/test_dapr_async_grpc_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,11 @@

from unittest.mock import patch

from google.rpc import status_pb2, code_pb2

from dapr.aio.clients.grpc.client import DaprGrpcClientAsync
from dapr.aio.clients import DaprClient
from dapr.clients.exceptions import DaprGrpcError
from dapr.proto import common_v1
from .fake_dapr_server import FakeDaprSidecar
from dapr.conf import settings
Expand Down Expand Up @@ -202,10 +205,18 @@ async def test_invoke_binding_no_create(self):

async def test_publish_event(self):
dapr = DaprGrpcClientAsync(f'{self.scheme}localhost:{self.server_port}')
resp = await dapr.publish_event(pubsub_name='pubsub', topic_name='example', data=b'haha')
resp = await dapr.publish_event(
pubsub_name='pubsub', topic_name='example', data=b'test_data'
)

self.assertEqual(2, len(resp.headers))
self.assertEqual(['haha'], resp.headers['hdata'])
self.assertEqual(['test_data'], resp.headers['hdata'])

self._fake_dapr_server.raise_exception_on_next_call(
status_pb2.Status(code=code_pb2.INVALID_ARGUMENT, message='my invalid argument message')
)
with self.assertRaises(DaprGrpcError):
await dapr.publish_event(pubsub_name='pubsub', topic_name='example', data=b'test_data')

async def test_publish_event_with_content_type(self):
dapr = DaprGrpcClientAsync(f'{self.scheme}localhost:{self.server_port}')
Expand Down Expand Up @@ -292,12 +303,19 @@ async def test_get_save_delete_state(self):
self.assertEqual(resp.data, b'')
self.assertEqual(resp.etag, '')

# Check a DaprGrpcError is raised
self._fake_dapr_server.raise_exception_on_next_call(
status_pb2.Status(code=code_pb2.INVALID_ARGUMENT, message='my invalid argument message')
)
with self.assertRaises(DaprGrpcError) as context:
await dapr.get_state(store_name='my_statestore', key='key||')

await dapr.delete_state(store_name='statestore', key=key)
resp = await dapr.get_state(store_name='statestore', key=key)
self.assertEqual(resp.data, b'')
self.assertEqual(resp.etag, '')

with self.assertRaises(Exception) as context:
with self.assertRaises(DaprGrpcError) as context:
await dapr.delete_state(
store_name='statestore', key=key, state_metadata={'must_delete': '1'}
)
Expand Down Expand Up @@ -359,7 +377,20 @@ async def test_transaction_then_get_states(self):
self.assertEqual(resp.items[1].key, another_key)
self.assertEqual(resp.items[1].data, to_bytes(another_value.upper()))

async def test_save_then_get_states(self):
self._fake_dapr_server.raise_exception_on_next_call(
status_pb2.Status(code=code_pb2.INVALID_ARGUMENT, message='my invalid argument message')
)
with self.assertRaises(DaprGrpcError):
await dapr.execute_state_transaction(
store_name='statestore',
operations=[
TransactionalStateOperation(key=key, data=value, etag='foo'),
TransactionalStateOperation(key=another_key, data=another_value),
],
transactional_metadata={'metakey': 'metavalue'},
)

async def test_bulk_save_then_get_states(self):
dapr = DaprGrpcClientAsync(f'{self.scheme}localhost:{self.server_port}')

key = str(uuid.uuid4())
Expand Down Expand Up @@ -394,6 +425,27 @@ async def test_save_then_get_states(self):
self.assertEqual(resp.items[1].etag, '1')
self.assertEqual(resp.items[1].data, to_bytes(another_value.upper()))

self._fake_dapr_server.raise_exception_on_next_call(
status_pb2.Status(code=code_pb2.INVALID_ARGUMENT, message='my invalid argument message')
)
with self.assertRaises(DaprGrpcError):
await dapr.save_bulk_state(
store_name='statestore',
states=[
StateItem(key=key, value=value, metadata={'capitalize': '1'}),
StateItem(key=another_key, value=another_value, etag='1'),
],
metadata=(('metakey', 'metavalue'),),
)

self._fake_dapr_server.raise_exception_on_next_call(
status_pb2.Status(code=code_pb2.INVALID_ARGUMENT, message='my invalid argument message')
)
with self.assertRaises(DaprGrpcError):
await dapr.get_bulk_state(
store_name='statestore', keys=[key, another_key], states_metadata={'upper': '1'}
)

async def test_get_secret(self):
dapr = DaprGrpcClientAsync(f'{self.scheme}localhost:{self.server_port}')
key1 = 'key_1'
Expand Down Expand Up @@ -512,6 +564,15 @@ async def test_query_state(self):
self.assertEqual(resp.results[0].key, '3')
self.assertEqual(len(resp.results), 3)

self._fake_dapr_server.raise_exception_on_next_call(
status_pb2.Status(code=code_pb2.INVALID_ARGUMENT, message='my invalid argument message')
)
with self.assertRaises(DaprGrpcError):
await dapr.query_state(
store_name='statestore',
query=json.dumps({'filter': {}, 'page': {'limit': 2}}),
)

async def test_shutdown(self):
dapr = DaprGrpcClientAsync(f'{self.scheme}localhost:{self.server_port}')
await dapr.shutdown()
Expand Down
Loading