From e5defddd33f704b663eb94b2eeefd2f7ccac955d Mon Sep 17 00:00:00 2001
From: Dmytro Sadovnychyi
Date: Sat, 11 Jan 2025 01:09:49 +0900
Subject: [PATCH] Add retry logic to each batch method of the GCS IO (#33539)

* Add retry logic to each batch method of the GCS IO

A transient error might occur when writing a large number of shards to
GCS, and right now the GCS IO has no retry logic in place:
https://github.com/apache/beam/blob/a06454a2/sdks/python/apache_beam/io/gcp/gcsio.py#L269

This means that in such cases the entire bundle of elements fails; Beam
then retries the whole bundle itself, and fails the job if the number of
retries is exceeded. This change adds logic to retry only the failed
requests, using the typical exponential backoff strategy. (A simplified,
standalone sketch of this pattern appears after the patch.)

Note that this change accesses a private method (`_predicate`) of the
retry object, which we could only avoid by copying that logic over here.
But the existing code already accesses the private `_responses` property,
so this seems acceptable:
https://github.com/apache/beam/blob/b4c3a4ff/sdks/python/apache_beam/io/gcp/gcsio.py#L297

Existing (unresolved) issue in the GCS client library:
https://github.com/googleapis/python-storage/issues/1277

* Catch correct exception type in `_batch_with_retry`

The `RetryError` would always be raised, since the retry decorator
catches all HTTP-related exceptions.

* Update changelog with GCSIO retry logic fix
---
 CHANGES.md                                   |  1 +
 sdks/python/apache_beam/io/gcp/gcsio.py      | 75 ++++++++++++++------
 sdks/python/apache_beam/io/gcp/gcsio_test.py | 68 ++++++++++++++++++
 3 files changed, 122 insertions(+), 22 deletions(-)

diff --git a/CHANGES.md b/CHANGES.md
index 1f8f13305c83..44f5fe88c4dc 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -68,6 +68,7 @@
 * X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)).
 * Upgraded to protobuf 4 (Java) ([#33192](https://github.com/apache/beam/issues/33192)).
+* [GCSIO] Added retry logic to each batch method of the GCS IO (Python) ([#33539](https://github.com/apache/beam/pull/33539)) ## Breaking Changes diff --git a/sdks/python/apache_beam/io/gcp/gcsio.py b/sdks/python/apache_beam/io/gcp/gcsio.py index e0dcffa86dff..3e2f5d4cf635 100644 --- a/sdks/python/apache_beam/io/gcp/gcsio.py +++ b/sdks/python/apache_beam/io/gcp/gcsio.py @@ -35,8 +35,10 @@ from typing import Optional from typing import Union +from google.api_core.exceptions import RetryError from google.cloud import storage from google.cloud.exceptions import NotFound +from google.cloud.exceptions import from_http_response from google.cloud.storage.fileio import BlobReader from google.cloud.storage.fileio import BlobWriter from google.cloud.storage.retry import DEFAULT_RETRY @@ -264,9 +266,45 @@ def delete(self, path): except NotFound: return + def _batch_with_retry(self, requests, fn): + current_requests = [*enumerate(requests)] + responses = [None for _ in current_requests] + + @self._storage_client_retry + def run_with_retry(): + current_batch = self.client.batch(raise_exception=False) + with current_batch: + for _, request in current_requests: + fn(request) + last_retryable_exception = None + for (i, current_pair), response in zip( + [*current_requests], current_batch._responses + ): + responses[i] = response + should_retry = ( + response.status_code >= 400 and + self._storage_client_retry._predicate(from_http_response(response))) + if should_retry: + last_retryable_exception = from_http_response(response) + else: + current_requests.remove((i, current_pair)) + if last_retryable_exception: + raise last_retryable_exception + + try: + run_with_retry() + except RetryError: + pass + + return responses + + def _delete_batch_request(self, path): + bucket_name, blob_name = parse_gcs_path(path) + bucket = self.client.bucket(bucket_name) + bucket.delete_blob(blob_name) + def delete_batch(self, paths): """Deletes the objects at the given GCS paths. - Warning: any exception during batch delete will NOT be retried. Args: paths: List of GCS file path patterns or Dict with GCS file path patterns @@ -285,16 +323,11 @@ def delete_batch(self, paths): current_paths = paths[s:s + MAX_BATCH_OPERATION_SIZE] else: current_paths = paths[s:] - current_batch = self.client.batch(raise_exception=False) - with current_batch: - for path in current_paths: - bucket_name, blob_name = parse_gcs_path(path) - bucket = self.client.bucket(bucket_name) - bucket.delete_blob(blob_name) - + responses = self._batch_with_retry( + current_paths, self._delete_batch_request) for i, path in enumerate(current_paths): error_code = None - resp = current_batch._responses[i] + resp = responses[i] if resp.status_code >= 400 and resp.status_code != 404: error_code = resp.status_code final_results.append((path, error_code)) @@ -334,9 +367,16 @@ def copy(self, src, dest): source_generation=src_generation, retry=self._storage_client_retry) + def _copy_batch_request(self, pair): + src_bucket_name, src_blob_name = parse_gcs_path(pair[0]) + dest_bucket_name, dest_blob_name = parse_gcs_path(pair[1]) + src_bucket = self.client.bucket(src_bucket_name) + src_blob = src_bucket.blob(src_blob_name) + dest_bucket = self.client.bucket(dest_bucket_name) + src_bucket.copy_blob(src_blob, dest_bucket, dest_blob_name) + def copy_batch(self, src_dest_pairs): """Copies the given GCS objects from src to dest. - Warning: any exception during batch copy will NOT be retried. 
Args: src_dest_pairs: list of (src, dest) tuples of gs:/// files @@ -354,20 +394,11 @@ def copy_batch(self, src_dest_pairs): current_pairs = src_dest_pairs[s:s + MAX_BATCH_OPERATION_SIZE] else: current_pairs = src_dest_pairs[s:] - current_batch = self.client.batch(raise_exception=False) - with current_batch: - for pair in current_pairs: - src_bucket_name, src_blob_name = parse_gcs_path(pair[0]) - dest_bucket_name, dest_blob_name = parse_gcs_path(pair[1]) - src_bucket = self.client.bucket(src_bucket_name) - src_blob = src_bucket.blob(src_blob_name) - dest_bucket = self.client.bucket(dest_bucket_name) - - src_bucket.copy_blob(src_blob, dest_bucket, dest_blob_name) - + responses = self._batch_with_retry( + current_pairs, self._copy_batch_request) for i, pair in enumerate(current_pairs): error_code = None - resp = current_batch._responses[i] + resp = responses[i] if resp.status_code >= 400: error_code = resp.status_code final_results.append((pair[0], pair[1], error_code)) diff --git a/sdks/python/apache_beam/io/gcp/gcsio_test.py b/sdks/python/apache_beam/io/gcp/gcsio_test.py index 7b79030b4b71..1faae2b2a8f1 100644 --- a/sdks/python/apache_beam/io/gcp/gcsio_test.py +++ b/sdks/python/apache_beam/io/gcp/gcsio_test.py @@ -482,6 +482,74 @@ def test_copy(self): 'gs://gcsio-test/non-existent', 'gs://gcsio-test/non-existent-destination') + @staticmethod + def _fake_batch_responses(status_codes): + return mock.Mock( + __enter__=mock.Mock(), + __exit__=mock.Mock(), + _responses=[ + mock.Mock( + **{ + 'json.return_value': { + 'error': { + 'message': 'error' + } + }, + 'request.method': 'BATCH', + 'request.url': 'contentid://None', + }, + status_code=code, + ) for code in status_codes + ], + ) + + @mock.patch('apache_beam.io.gcp.gcsio.MAX_BATCH_OPERATION_SIZE', 3) + @mock.patch('time.sleep', mock.Mock()) + def test_copy_batch(self): + src_dest_pairs = [ + (f'gs://source_bucket/file{i}.txt', f'gs://dest_bucket/file{i}.txt') + for i in range(7) + ] + gcs_io = gcsio.GcsIO( + storage_client=mock.Mock( + batch=mock.Mock( + side_effect=[ + self._fake_batch_responses([200, 404, 429]), + self._fake_batch_responses([429]), + self._fake_batch_responses([429]), + self._fake_batch_responses([200]), + self._fake_batch_responses([200, 429, 200]), + self._fake_batch_responses([200]), + self._fake_batch_responses([200]), + ]), + )) + results = gcs_io.copy_batch(src_dest_pairs) + expected = [ + ('gs://source_bucket/file0.txt', 'gs://dest_bucket/file0.txt', None), + ('gs://source_bucket/file1.txt', 'gs://dest_bucket/file1.txt', 404), + ('gs://source_bucket/file2.txt', 'gs://dest_bucket/file2.txt', None), + ('gs://source_bucket/file3.txt', 'gs://dest_bucket/file3.txt', None), + ('gs://source_bucket/file4.txt', 'gs://dest_bucket/file4.txt', None), + ('gs://source_bucket/file5.txt', 'gs://dest_bucket/file5.txt', None), + ('gs://source_bucket/file6.txt', 'gs://dest_bucket/file6.txt', None), + ] + self.assertEqual(results, expected) + + @mock.patch('time.sleep', mock.Mock()) + @mock.patch('time.monotonic', mock.Mock(side_effect=[0, 120])) + def test_copy_batch_timeout_exceeded(self): + src_dest_pairs = [ + ('gs://source_bucket/file0.txt', 'gs://dest_bucket/file0.txt') + ] + gcs_io = gcsio.GcsIO( + storage_client=mock.Mock( + batch=mock.Mock(side_effect=[self._fake_batch_responses([429])]))) + results = gcs_io.copy_batch(src_dest_pairs) + expected = [ + ('gs://source_bucket/file0.txt', 'gs://dest_bucket/file0.txt', 429), + ] + self.assertEqual(results, expected) + def test_copytree(self): src_dir_name = 'gs://gcsio-test/source/' 
dest_dir_name = 'gs://gcsio-test/dest/'
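
---

For reference, here is a minimal, standalone sketch of the retry pattern this
patch introduces, distilled from `GcsIO._batch_with_retry` above. It is
illustrative only: `delete_batch_with_retry` and its arguments are hypothetical
names, the sketch assumes the `google-cloud-storage` package is installed, and
it relies on the same private members (`_predicate`, `_responses`) that the
patch itself uses.

# Illustrative sketch only -- not part of the patch. Function and argument
# names are hypothetical; the real implementation is GcsIO._batch_with_retry.
from google.api_core.exceptions import RetryError
from google.cloud import storage
from google.cloud.exceptions import from_http_response
from google.cloud.storage.retry import DEFAULT_RETRY


def delete_batch_with_retry(client, bucket_name, blob_names):
  """Deletes blobs in one batch, retrying only the failed sub-requests."""
  pending = [*enumerate(blob_names)]  # (original index, blob name) pairs
  responses = [None] * len(pending)

  @DEFAULT_RETRY  # exponential backoff with the library's default predicate
  def run():
    batch = client.batch(raise_exception=False)
    with batch:
      for _, name in pending:
        client.bucket(bucket_name).delete_blob(name)
    last_retryable = None
    # Iterate over a copy so finished requests can be removed in place.
    for (i, name), response in zip([*pending], batch._responses):
      responses[i] = response
      if response.status_code >= 400:
        exc = from_http_response(response)
        if DEFAULT_RETRY._predicate(exc):  # private, as noted in the message
          last_retryable = exc  # keep this request pending for the next pass
          continue
      pending.remove((i, name))  # success or permanent failure: done
    if last_retryable:
      raise last_retryable  # makes the decorator back off and rerun run()

  try:
    run()
  except RetryError:
    pass  # retry budget exhausted; responses still hold the last failures
  return responses


if __name__ == '__main__':
  # Hypothetical usage; assumes default application credentials.
  client = storage.Client()
  for resp in delete_batch_with_retry(client, 'my-bucket', ['a.txt', 'b.txt']):
    print(resp.status_code)

As in the patch, each pass re-batches only the sub-requests whose responses had
a retryable status, so a single 429 no longer fails (or re-runs) the entire
bundle; catching `RetryError` once the retry budget is exhausted keeps the
per-request status codes available so callers such as `delete_batch` and
`copy_batch` can report them.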