ajdub508 commented on code in PR #28091:
URL: https://github.com/apache/beam/pull/28091#discussion_r1327060993


##########
sdks/python/apache_beam/io/gcp/bigquery_test.py:
##########
@@ -931,255 +929,583 @@ def test_copy_load_job_exception(self, exception_type, error_message):
     'GCP dependencies are not installed')
 class BigQueryStreamingInsertsErrorHandling(unittest.TestCase):
 
-  # Using https://cloud.google.com/bigquery/docs/error-messages and
-  # https://googleapis.dev/python/google-api-core/latest/_modules/google
-  #    /api_core/exceptions.html
-  # to determine error types and messages to try for retriables.
+  # Running tests with a variety of exceptions from https://googleapis.dev
+  #     /python/google-api-core/latest/_modules/google/api_core/exceptions.html.
+  # Choosing some exceptions that produce reasons included in
+  # bigquery_tools._NON_TRANSIENT_ERRORS and some that are not
   @parameterized.expand([
+      # reason not in _NON_TRANSIENT_ERRORS for row 1 on first attempt
+      # transient error retried and succeeds on second attempt, 0 rows sent to
+      # failed rows
       param(
-          exception_type=exceptions.Forbidden if exceptions else None,
-          error_reason='rateLimitExceeded'),
-      param(
-          exception_type=exceptions.DeadlineExceeded if exceptions else None,
-          error_reason='somereason'),
+          insert_response=[
+            exceptions.TooManyRequests if exceptions else None, None],
+          error_reason='Too Many Requests', # not in _NON_TRANSIENT_ERRORS
+          failed_rows=[]),

Review Comment:
   > I may be missing something here. I see `max_retries=len(insert_response) - 1`, which means 0 retries for this case right? why is there a second attempt
   
   There are actually two responses there, with the second one being `None`, so `max_retries` is 1 and the insert is retried once. The formatting is inconsistent and makes that hard to read; I'll put the second response on its own line to make it easier to read.
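   
   For reference, a minimal sketch of how the two-element list plays out (it assumes `google-api-core` is installed and mirrors the test's `store_callback`, not production code):
   
   ```python
   from google.api_core import exceptions
   
   insert_response = [exceptions.TooManyRequests, None]
   max_retries = len(insert_response) - 1  # == 1, so one retry is allowed
   
   # attempt 1: insert_response[0] is an exception class, so the mocked
   #            insert_rows_json raises TooManyRequests and the rows are retried
   # attempt 2: insert_response[1] is None, so the mock returns [] (success)
   #            and failed_rows stays empty
   ```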



##########
sdks/python/apache_beam/io/gcp/bigquery_test.py:
##########
@@ -931,255 +929,583 @@ def test_copy_load_job_exception(self, exception_type, error_message):
     'GCP dependencies are not installed')
 class BigQueryStreamingInsertsErrorHandling(unittest.TestCase):
 
-  # Using https://cloud.google.com/bigquery/docs/error-messages and
-  # https://googleapis.dev/python/google-api-core/latest/_modules/google
-  #    /api_core/exceptions.html
-  # to determine error types and messages to try for retriables.
+  # Running tests with a variety of exceptions from https://googleapis.dev
+  #     /python/google-api-core/latest/_modules/google/api_core/exceptions.html.
+  # Choosing some exceptions that produce reasons included in
+  # bigquery_tools._NON_TRANSIENT_ERRORS and some that are not
   @parameterized.expand([
+      # reason not in _NON_TRANSIENT_ERRORS for row 1 on first attempt
+      # transient error retried and succeeds on second attempt, 0 rows sent to
+      # failed rows
       param(
-          exception_type=exceptions.Forbidden if exceptions else None,
-          error_reason='rateLimitExceeded'),
-      param(
-          exception_type=exceptions.DeadlineExceeded if exceptions else None,
-          error_reason='somereason'),
+          insert_response=[
+            exceptions.TooManyRequests if exceptions else None, None],
+          error_reason='Too Many Requests', # not in _NON_TRANSIENT_ERRORS
+          failed_rows=[]),
+      # reason not in _NON_TRANSIENT_ERRORS for row 1 on both attempts, sent to
+      # failed rows after hitting max_retries
       param(
-          exception_type=exceptions.ServiceUnavailable if exceptions else None,
-          error_reason='backendError'),
-      param(
-          exception_type=exceptions.InternalServerError if exceptions else None,
-          error_reason='internalError'),
+          insert_response=[
+            exceptions.TooManyRequests if exceptions else None,
+            exceptions.TooManyRequests if exceptions else None],
+          error_reason='Too Many Requests', # not in _NON_TRANSIENT_ERRORS
+          failed_rows=['value1', 'value3', 'value5']),
+      # reason in _NON_TRANSIENT_ERRORS for row 1 on both attempts, sent to
+      # failed_rows after hitting max_retries
       param(
-          exception_type=exceptions.InternalServerError if exceptions else None,
-          error_reason='backendError'),

Review Comment:
   Thanks, will add back.
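   
   Something like this in the new format (a sketch reusing the surrounding test module's imports; the `'Internal Server Error'` reason string, and the expectation that it is not in `bigquery_tools._NON_TRANSIENT_ERRORS`, are my assumptions):
   
   ```python
   # hypothetical param entry restoring InternalServerError coverage;
   # a transient reason retried past max_retries lands in failed_rows
   param(
       insert_response=[
           exceptions.InternalServerError if exceptions else None,
           exceptions.InternalServerError if exceptions else None],
       error_reason='Internal Server Error',
       failed_rows=['value1', 'value3', 'value5']),
   ```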



##########
sdks/python/apache_beam/io/gcp/bigquery_test.py:
##########
@@ -931,255 +929,583 @@ def test_copy_load_job_exception(self, exception_type, error_message):
     'GCP dependencies are not installed')
 class BigQueryStreamingInsertsErrorHandling(unittest.TestCase):
 
-  # Using https://cloud.google.com/bigquery/docs/error-messages and
-  # https://googleapis.dev/python/google-api-core/latest/_modules/google
-  #    /api_core/exceptions.html
-  # to determine error types and messages to try for retriables.
+  # Running tests with a variety of exceptions from https://googleapis.dev
+  #     /python/google-api-core/latest/_modules/google/api_core/exceptions.html.
+  # Choosing some exceptions that produce reasons included in
+  # bigquery_tools._NON_TRANSIENT_ERRORS and some that are not
   @parameterized.expand([
+      # reason not in _NON_TRANSIENT_ERRORS for row 1 on first attempt
+      # transient error retried and succeeds on second attempt, 0 rows sent to
+      # failed rows
       param(
-          exception_type=exceptions.Forbidden if exceptions else None,
-          error_reason='rateLimitExceeded'),
-      param(
-          exception_type=exceptions.DeadlineExceeded if exceptions else None,
-          error_reason='somereason'),
+          insert_response=[
+            exceptions.TooManyRequests if exceptions else None, None],
+          error_reason='Too Many Requests', # not in _NON_TRANSIENT_ERRORS
+          failed_rows=[]),
+      # reason not in _NON_TRANSIENT_ERRORS for row 1 on both attempts, sent to
+      # failed rows after hitting max_retries
       param(
-          exception_type=exceptions.ServiceUnavailable if exceptions else None,
-          error_reason='backendError'),
-      param(
-          exception_type=exceptions.InternalServerError if exceptions else None,
-          error_reason='internalError'),
+          insert_response=[
+            exceptions.TooManyRequests if exceptions else None,
+            exceptions.TooManyRequests if exceptions else None],
+          error_reason='Too Many Requests', # not in _NON_TRANSIENT_ERRORS
+          failed_rows=['value1', 'value3', 'value5']),
+      # reason in _NON_TRANSIENT_ERRORS for row 1 on both attempts, sent to
+      # failed_rows after hitting max_retries
       param(
-          exception_type=exceptions.InternalServerError if exceptions else None,
-          error_reason='backendError'),
+          insert_response=[
+            exceptions.Forbidden if exceptions else None,
+            exceptions.Forbidden if exceptions else None],
+          error_reason='Forbidden', # in _NON_TRANSIENT_ERRORS
+          failed_rows=['value1', 'value3', 'value5']),
   ])
-  @mock.patch('time.sleep')
-  @mock.patch('google.cloud.bigquery.Client.insert_rows_json')
-  def test_insert_all_retries_if_structured_retriable(
-      self,
-      mock_send,
-      unused_mock_sleep,
-      exception_type=None,
-      error_reason=None):
-    # In this test, a BATCH pipeline will retry the known RETRIABLE errors.
-    mock_send.side_effect = [
-        exception_type(
-            'some retriable exception', errors=[{
-                'reason': error_reason
-            }]),
-        exception_type(
-            'some retriable exception', errors=[{
-                'reason': error_reason
-            }]),
-        exception_type(
-            'some retriable exception', errors=[{
-                'reason': error_reason
-            }]),
-        exception_type(
-            'some retriable exception', errors=[{
-                'reason': error_reason
-            }]),
-    ]
+  def test_insert_rows_json_exception_retry_always(
+      self, insert_response, error_reason, failed_rows):
+    # In this test, a pipeline will always retry all caught exception types
+    # since RetryStrategy is not set and defaults to RETRY_ALWAYS
+    with mock.patch('time.sleep'):
+      call_counter = 0
+      mock_response = mock.Mock()
+      mock_response.reason = error_reason
 
-    with self.assertRaises(Exception) as exc:
-      with beam.Pipeline() as p:
-        _ = (
+      def store_callback(table, **kwargs):
+        nonlocal call_counter
+        # raise exception if insert_response element is an exception
+        if insert_response[call_counter]:
+          exception_type = insert_response[call_counter]
+          call_counter += 1
+          raise exception_type('some exception', response=mock_response)
+        # return empty list if not insert_response element, indicating
+        # successful call to insert_rows_json
+        else:
+          call_counter += 1
+          return []
+
+      client = mock.Mock()
+      client.insert_rows_json.side_effect = store_callback
+
+      # Using the bundle based direct runner to avoid pickling problems
+      # with mocks.
+      with beam.Pipeline(runner='BundleBasedDirectRunner') as p:
+        bq_write_out = (
             p
             | beam.Create([{
-                'columnA': 'value1'
+                'columnA': 'value1', 'columnB': 'value2'
+            }, {
+                'columnA': 'value3', 'columnB': 'value4'
+            }, {
+                'columnA': 'value5', 'columnB': 'value6'
             }])
-            | WriteToBigQuery(
-                table='project:dataset.table',
-                schema={
-                    'fields': [{
-                        'name': 'columnA', 'type': 'STRING', 'mode': 'NULLABLE'
-                    }]
-                },
+            # Using _StreamToBigQuery in order to be able to pass max_retries
+            # in order to limit run time of test with RETRY_ALWAYS
+            | _StreamToBigQuery(
+                table_reference='project:dataset.table',
+                table_side_inputs=[],
+                schema_side_inputs=[],
+                schema='anyschema',
+                batch_size=None,
+                triggering_frequency=None,
                 create_disposition='CREATE_NEVER',
-                method='STREAMING_INSERTS'))
-    self.assertEqual(4, mock_send.call_count)
-    self.assertIn('some retriable exception', exc.exception.args[0])
+                write_disposition=None,
+                kms_key=None,
+                retry_strategy=RetryStrategy.RETRY_ALWAYS,
+                additional_bq_parameters=[],
+                ignore_insert_ids=False,
+                ignore_unknown_columns=False,
+                with_auto_sharding=False,
+                test_client=client,
+                max_retries=len(insert_response) - 1,
+                num_streaming_keys=500))
+
+        failed_values = (
+            bq_write_out[beam_bq.BigQueryWriteFn.FAILED_ROWS]
+            | beam.Map(lambda x: x[1]['columnA']))
+
+        assert_that(failed_values, equal_to(failed_rows))
 
-  # Using https://googleapis.dev/python/google-api-core/latest/_modules/google
-  #   /api_core/exceptions.html
-  # to determine error types and messages to try for retriables.
+  # Running tests with a variety of exceptions from https://googleapis.dev
+  #     /python/google-api-core/latest/_modules/google/api_core/exceptions.html.
+  # Choosing some exceptions that produce reasons that are included in
+  # bigquery_tools._NON_TRANSIENT_ERRORS and some that are not
   @parameterized.expand([
       param(
-          exception_type=requests.exceptions.ConnectionError,
-          error_message='some connection error'),
+          exception_type=exceptions.TooManyRequests if exceptions else None,
+          error_reason='Too Many Requests', # not in _NON_TRANSIENT_ERRORS
+          streaming=False),
       param(
-          exception_type=requests.exceptions.Timeout,
-          error_message='some timeout error'),
+          exception_type=exceptions.Forbidden if exceptions else None,
+          error_reason='Forbidden', # in _NON_TRANSIENT_ERRORS
+          streaming=False),
       param(
-          exception_type=ConnectionError,
-          error_message='some py connection error'),
+          exception_type=exceptions.TooManyRequests if exceptions else None,
+          error_reason='Too Many Requests', # not in _NON_TRANSIENT_ERRORS
+          streaming=True),
       param(
-          exception_type=exceptions.BadGateway if exceptions else None,
-          error_message='some badgateway error'),

Review Comment:
   Thanks, will add back.
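   
   For example (a sketch in the new format; the `'Bad Gateway'` reason string is my assumption):
   
   ```python
   # hypothetical param entry restoring BadGateway coverage in the
   # retry-never test; with RETRY_NEVER there is a single attempt and the
   # rows go to failed rows regardless of whether the reason is in
   # _NON_TRANSIENT_ERRORS
   param(
       exception_type=exceptions.BadGateway if exceptions else None,
       error_reason='Bad Gateway',
       streaming=False),
   ```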



##########
sdks/python/apache_beam/io/gcp/bigquery_test.py:
##########
@@ -931,255 +929,583 @@ def test_copy_load_job_exception(self, exception_type, error_message):
     'GCP dependencies are not installed')
 class BigQueryStreamingInsertsErrorHandling(unittest.TestCase):
 
-  # Using https://cloud.google.com/bigquery/docs/error-messages and
-  # https://googleapis.dev/python/google-api-core/latest/_modules/google
-  #    /api_core/exceptions.html
-  # to determine error types and messages to try for retriables.
+  # Running tests with a variety of exceptions from https://googleapis.dev
+  #     /python/google-api-core/latest/_modules/google/api_core/exceptions.html.
+  # Choosing some exceptions that produce reasons included in
+  # bigquery_tools._NON_TRANSIENT_ERRORS and some that are not
   @parameterized.expand([
+      # reason not in _NON_TRANSIENT_ERRORS for row 1 on first attempt
+      # transient error retried and succeeds on second attempt, 0 rows sent to
+      # failed rows
       param(
-          exception_type=exceptions.Forbidden if exceptions else None,
-          error_reason='rateLimitExceeded'),
-      param(
-          exception_type=exceptions.DeadlineExceeded if exceptions else None,
-          error_reason='somereason'),
+          insert_response=[
+            exceptions.TooManyRequests if exceptions else None, None],
+          error_reason='Too Many Requests', # not in _NON_TRANSIENT_ERRORS
+          failed_rows=[]),
+      # reason not in _NON_TRANSIENT_ERRORS for row 1 on both attempts, sent to
+      # failed rows after hitting max_retries
       param(
-          exception_type=exceptions.ServiceUnavailable if exceptions else None,
-          error_reason='backendError'),
-      param(
-          exception_type=exceptions.InternalServerError if exceptions else None,
-          error_reason='internalError'),
+          insert_response=[
+            exceptions.TooManyRequests if exceptions else None,
+            exceptions.TooManyRequests if exceptions else None],
+          error_reason='Too Many Requests', # not in _NON_TRANSIENT_ERRORS
+          failed_rows=['value1', 'value3', 'value5']),
+      # reason in _NON_TRANSIENT_ERRORS for row 1 on both attempts, sent to
+      # failed_rows after hitting max_retries
       param(
-          exception_type=exceptions.InternalServerError if exceptions else None,
-          error_reason='backendError'),
+          insert_response=[
+            exceptions.Forbidden if exceptions else None,
+            exceptions.Forbidden if exceptions else None],
+          error_reason='Forbidden', # in _NON_TRANSIENT_ERRORS
+          failed_rows=['value1', 'value3', 'value5']),
   ])
-  @mock.patch('time.sleep')
-  @mock.patch('google.cloud.bigquery.Client.insert_rows_json')
-  def test_insert_all_retries_if_structured_retriable(
-      self,
-      mock_send,
-      unused_mock_sleep,
-      exception_type=None,
-      error_reason=None):
-    # In this test, a BATCH pipeline will retry the known RETRIABLE errors.
-    mock_send.side_effect = [
-        exception_type(
-            'some retriable exception', errors=[{
-                'reason': error_reason
-            }]),
-        exception_type(
-            'some retriable exception', errors=[{
-                'reason': error_reason
-            }]),
-        exception_type(
-            'some retriable exception', errors=[{
-                'reason': error_reason
-            }]),
-        exception_type(
-            'some retriable exception', errors=[{
-                'reason': error_reason
-            }]),
-    ]
+  def test_insert_rows_json_exception_retry_always(
+      self, insert_response, error_reason, failed_rows):
+    # In this test, a pipeline will always retry all caught exception types
+    # since RetryStrategy is not set and defaults to RETRY_ALWAYS
+    with mock.patch('time.sleep'):
+      call_counter = 0
+      mock_response = mock.Mock()
+      mock_response.reason = error_reason
 
-    with self.assertRaises(Exception) as exc:
-      with beam.Pipeline() as p:
-        _ = (
+      def store_callback(table, **kwargs):
+        nonlocal call_counter
+        # raise exception if insert_response element is an exception
+        if insert_response[call_counter]:
+          exception_type = insert_response[call_counter]
+          call_counter += 1
+          raise exception_type('some exception', response=mock_response)
+        # return empty list if not insert_response element, indicating
+        # successful call to insert_rows_json
+        else:
+          call_counter += 1
+          return []
+
+      client = mock.Mock()
+      client.insert_rows_json.side_effect = store_callback
+
+      # Using the bundle based direct runner to avoid pickling problems
+      # with mocks.
+      with beam.Pipeline(runner='BundleBasedDirectRunner') as p:
+        bq_write_out = (
             p
             | beam.Create([{
-                'columnA': 'value1'
+                'columnA': 'value1', 'columnB': 'value2'
+            }, {
+                'columnA': 'value3', 'columnB': 'value4'
+            }, {
+                'columnA': 'value5', 'columnB': 'value6'
             }])
-            | WriteToBigQuery(
-                table='project:dataset.table',
-                schema={
-                    'fields': [{
-                        'name': 'columnA', 'type': 'STRING', 'mode': 'NULLABLE'
-                    }]
-                },
+            # Using _StreamToBigQuery in order to be able to pass max_retries
+            # in order to limit run time of test with RETRY_ALWAYS
+            | _StreamToBigQuery(
+                table_reference='project:dataset.table',
+                table_side_inputs=[],
+                schema_side_inputs=[],
+                schema='anyschema',
+                batch_size=None,
+                triggering_frequency=None,
                 create_disposition='CREATE_NEVER',
-                method='STREAMING_INSERTS'))
-    self.assertEqual(4, mock_send.call_count)
-    self.assertIn('some retriable exception', exc.exception.args[0])
+                write_disposition=None,
+                kms_key=None,
+                retry_strategy=RetryStrategy.RETRY_ALWAYS,
+                additional_bq_parameters=[],
+                ignore_insert_ids=False,
+                ignore_unknown_columns=False,
+                with_auto_sharding=False,
+                test_client=client,
+                max_retries=len(insert_response) - 1,
+                num_streaming_keys=500))
+
+        failed_values = (
+            bq_write_out[beam_bq.BigQueryWriteFn.FAILED_ROWS]
+            | beam.Map(lambda x: x[1]['columnA']))
+
+        assert_that(failed_values, equal_to(failed_rows))
 
-  # Using https://googleapis.dev/python/google-api-core/latest/_modules/google
-  #   /api_core/exceptions.html
-  # to determine error types and messages to try for retriables.
+  # Running tests with a variety of exceptions from https://googleapis.dev
+  #     /python/google-api-core/latest/_modules/google/api_core/exceptions.html.
+  # Choosing some exceptions that produce reasons that are included in
+  # bigquery_tools._NON_TRANSIENT_ERRORS and some that are not
   @parameterized.expand([
       param(
-          exception_type=requests.exceptions.ConnectionError,
-          error_message='some connection error'),
+          exception_type=exceptions.TooManyRequests if exceptions else None,
+          error_reason='Too Many Requests', # not in _NON_TRANSIENT_ERRORS
+          streaming=False),
       param(
-          exception_type=requests.exceptions.Timeout,
-          error_message='some timeout error'),
+          exception_type=exceptions.Forbidden if exceptions else None,
+          error_reason='Forbidden', # in _NON_TRANSIENT_ERRORS
+          streaming=False),
       param(
-          exception_type=ConnectionError,
-          error_message='some py connection error'),
+          exception_type=exceptions.TooManyRequests if exceptions else None,
+          error_reason='Too Many Requests', # not in _NON_TRANSIENT_ERRORS
+          streaming=True),
       param(
-          exception_type=exceptions.BadGateway if exceptions else None,
-          error_message='some badgateway error'),
+          exception_type=exceptions.Forbidden if exceptions else None,
+          error_reason='Forbidden', # in _NON_TRANSIENT_ERRORS
+          streaming=True),
   ])
   @mock.patch('time.sleep')
   @mock.patch('google.cloud.bigquery.Client.insert_rows_json')
-  def test_insert_all_retries_if_unstructured_retriable(
+  def test_insert_rows_json_exception_retry_never(
       self,
       mock_send,
       unused_mock_sleep,
-      exception_type=None,
-      error_message=None):
-    # In this test, a BATCH pipeline will retry the unknown RETRIABLE errors.
+      exception_type,
+      error_reason,
+      streaming=False):
+    # In this test, a pipeline will never retry caught exception types
+    # since RetryStrategy is set to RETRY_NEVER
+    mock_response = mock.Mock()
+    mock_response.reason = error_reason
     mock_send.side_effect = [
-        exception_type(error_message),
-        exception_type(error_message),
-        exception_type(error_message),
-        exception_type(error_message),
+        exception_type('some exception', response=mock_response)
     ]
+    opt = StandardOptions()
+    opt.streaming = streaming
+    with beam.Pipeline(runner='BundleBasedDirectRunner', options=opt) as p:
+      bq_write_out = (
+          p
+          | beam.Create([{
+              'columnA': 'value1'
+          }, {
+              'columnA': 'value2'
+          }])
+          | WriteToBigQuery(
+              table='project:dataset.table',
+              schema={
+                  'fields': [{
+                      'name': 'columnA', 'type': 'STRING', 'mode': 'NULLABLE'
+                  }]
+              },
+              create_disposition='CREATE_NEVER',
+              method='STREAMING_INSERTS',
+              insert_retry_strategy=RetryStrategy.RETRY_NEVER))
+      failed_values = (
+          bq_write_out[beam_bq.BigQueryWriteFn.FAILED_ROWS_WITH_ERRORS]
+          | beam.Map(lambda x: x[1]['columnA']))
 
-    with self.assertRaises(Exception) as exc:
-      with beam.Pipeline() as p:
-        _ = (
-            p
-            | beam.Create([{
-                'columnA': 'value1'
-            }])
-            | WriteToBigQuery(
-                table='project:dataset.table',
-                schema={
-                    'fields': [{
-                        'name': 'columnA', 'type': 'STRING', 'mode': 'NULLABLE'
-                    }]
-                },
-                create_disposition='CREATE_NEVER',
-                method='STREAMING_INSERTS'))
-    self.assertEqual(4, mock_send.call_count)
-    self.assertIn(error_message, exc.exception.args[0])
+      assert_that(failed_values, equal_to(['value1', 'value2']))
+
+    self.assertEqual(1, mock_send.call_count)
 
-  # Using https://googleapis.dev/python/google-api-core/latest/_modules/google
-  #   /api_core/exceptions.html
-  # to determine error types and messages to try for retriables.
+  # Running tests with a variety of exceptions from https://googleapis.dev
+  #     /python/google-api-core/latest/_modules/google/api_core/exceptions.html.
+  # Choosing some exceptions that produce reasons that are included in
+  # bigquery_tools._NON_TRANSIENT_ERRORS and some that are not
   @parameterized.expand([
       param(
-          exception_type=retry.PermanentException,
-          error_args=('nonretriable', )),
-      param(
-          exception_type=exceptions.BadRequest if exceptions else None,
-          error_args=(
-              'forbidden morbidden', [{
-                  'reason': 'nonretriablereason'
-              }])),
-      param(
-          exception_type=exceptions.BadRequest if exceptions else None,
-          error_args=('BAD REQUEST!', [{
-              'reason': 'nonretriablereason'
-          }])),
+          exception_type=exceptions.TooManyRequests if exceptions else None,
+          error_reason='Too Many Requests', # not in _NON_TRANSIENT_ERRORS
+          failed_values=[],
+          expected_call_count=2,
+          streaming=False),
       param(
-          exception_type=exceptions.MethodNotAllowed if exceptions else None,
-          error_args=(
-              'method not allowed!', [{
-                  'reason': 'nonretriablereason'
-              }])),
-      param(
-          exception_type=exceptions.MethodNotAllowed if exceptions else None,
-          error_args=('method not allowed!', 'args')),
+          exception_type=exceptions.Forbidden if exceptions else None,
+          error_reason='Forbidden', # in _NON_TRANSIENT_ERRORS
+          failed_values=['value1', 'value2'],
+          expected_call_count=1,
+          streaming=False),
       param(
-          exception_type=exceptions.Unknown if exceptions else None,
-          error_args=('unknown!', 'args')),
+          exception_type=exceptions.TooManyRequests if exceptions else None,
+          error_reason='Too Many Requests', # not in _NON_TRANSIENT_ERRORS
+          failed_values=[],
+          expected_call_count=2,
+          streaming=True),
       param(
-          exception_type=exceptions.Aborted if exceptions else None,
-          error_args=('abortet!', 'abort')),
+          exception_type=exceptions.Forbidden if exceptions else None,
+          error_reason='Forbidden', # in _NON_TRANSIENT_ERRORS
+          failed_values=['value1', 'value2'],
+          expected_call_count=1,
+          streaming=True),
   ])
   @mock.patch('time.sleep')
   @mock.patch('google.cloud.bigquery.Client.insert_rows_json')
-  def test_insert_all_unretriable_errors(
-      self, mock_send, unused_mock_sleep, exception_type=None, error_args=None):
-    # In this test, a BATCH pipeline will retry the unknown RETRIABLE errors.
+  def test_insert_rows_json_exception_retry_on_transient_error(
+      self,
+      mock_send,
+      unused_mock_sleep,
+      exception_type,
+      error_reason,
+      failed_values,
+      expected_call_count,
+      streaming=False):
+    # In this test, a pipeline will only retry caught exception types
+    # with reasons that are not in _NON_TRANSIENT_ERRORS since RetryStrategy is
+    # set to RETRY_ON_TRANSIENT_ERROR
+    mock_response = mock.Mock()
+    mock_response.reason = error_reason
     mock_send.side_effect = [
-        exception_type(*error_args),
-        exception_type(*error_args),
-        exception_type(*error_args),
-        exception_type(*error_args),
+        exception_type('some exception', response=mock_response),
+        # Return no exception and no errors on 2nd call, if there is a 2nd call
+        []
     ]
+    opt = StandardOptions()
+    opt.streaming = streaming
 
-    with self.assertRaises(Exception):
-      with beam.Pipeline() as p:
-        _ = (
+    with beam.Pipeline(runner='BundleBasedDirectRunner', options=opt) as p:
+      bq_write_out = (
+          p
+          | beam.Create([{
+              'columnA': 'value1'
+          }, {
+              'columnA': 'value2'
+          }])
+          | WriteToBigQuery(
+              table='project:dataset.table',
+              schema={
+                  'fields': [{
+                      'name': 'columnA', 'type': 'STRING', 'mode': 'NULLABLE'
+                  }]
+              },
+              create_disposition='CREATE_NEVER',
+              method='STREAMING_INSERTS',
+              insert_retry_strategy=RetryStrategy.RETRY_ON_TRANSIENT_ERROR))
+      failed_values_out = (
+          bq_write_out[beam_bq.BigQueryWriteFn.FAILED_ROWS]
+          | beam.Map(lambda x: x[1]['columnA']))
+
+      assert_that(failed_values_out, equal_to(failed_values))
+    self.assertEqual(expected_call_count, mock_send.call_count)
+
+  # Running tests with a variety of error reasons from
+  # https://cloud.google.com/bigquery/docs/error-messages
+  # This covers the scenario when
+  # the google.cloud.bigquery.Client.insert_rows_json call returns an error list
+  # rather than raising an exception.
+  # Choosing some error reasons that are included in
+  # bigquery_tools._NON_TRANSIENT_ERRORS and some that are not
+  @parameterized.expand([
+      # reason in _NON_TRANSIENT_ERRORS for row 1, sent to failed_rows
+      param(
+          insert_response=[
+              [{
+                  'index': 0, 'errors': [{
+                      'reason': 'invalid'
+                  }]
+              }],
+              [{
+                  'index': 0, 'errors': [{
+                      'reason': 'invalid'
+                  }]
+              }],
+          ],
+          failed_rows=['value1']),
+      # reason in _NON_TRANSIENT_ERRORS for row 1
+      # reason not in _NON_TRANSIENT_ERRORS for row 2 on 1st run
+      # row 1 sent to failed_rows
+      param(
+          insert_response=[
+              [{
+                  'index': 0, 'errors': [{
+                      'reason': 'invalid'
+                  }]
+              }, {
+                  'index': 1, 'errors': [{
+                      'reason': 'internalError'
+                  }]
+              }],
+              [{
+                  'index': 0, 'errors': [{
+                      'reason': 'invalid'
+                  }]
+              }],
+          ],
+          failed_rows=['value1']),
+      # reason not in _NON_TRANSIENT_ERRORS for row 1 on first attempt
+      # transient error succeeds on second attempt, 0 rows sent to failed rows
+      param(
+          insert_response=[
+              [{
+                  'index': 0, 'errors': [{
+                      'reason': 'internalError'
+                  }]
+              }],
+              [],
+          ],
+          failed_rows=[]),
+  ])
+  def test_insert_rows_json_errors_retry_always(
+      self, insert_response, failed_rows, unused_sleep_mock=None):
+    # In this test, a pipeline will always retry all errors
+    # since RetryStrategy is not set and defaults to RETRY_ALWAYS
+    with mock.patch('time.sleep'):
+      call_counter = 0
+
+      def store_callback(table, **kwargs):
+        nonlocal call_counter
+        response = insert_response[call_counter]
+        call_counter += 1
+        return response
+
+      client = mock.Mock()
+      client.insert_rows_json = mock.Mock(side_effect=store_callback)
+
+      # Using the bundle based direct runner to avoid pickling problems
+      # with mocks.
+      with beam.Pipeline(runner='BundleBasedDirectRunner') as p:
+        bq_write_out = (
             p
             | beam.Create([{
-                'columnA': 'value1'
+                'columnA': 'value1', 'columnB': 'value2'
+            }, {
+                'columnA': 'value3', 'columnB': 'value4'
+            }, {
+                'columnA': 'value5', 'columnB': 'value6'
             }])
-            | WriteToBigQuery(
-                table='project:dataset.table',
-                schema={
-                    'fields': [{
-                        'name': 'columnA', 'type': 'STRING', 'mode': 'NULLABLE'
-                    }]
-                },
+            # Using _StreamToBigQuery in order to be able to pass max_retries
+            # in order to limit run time of test with RETRY_ALWAYS
+            | _StreamToBigQuery(
+                table_reference='project:dataset.table',
+                table_side_inputs=[],
+                schema_side_inputs=[],
+                schema='anyschema',
+                batch_size=None,
+                triggering_frequency=None,
                 create_disposition='CREATE_NEVER',
-                method='STREAMING_INSERTS'))
-    self.assertEqual(1, mock_send.call_count)
+                write_disposition=None,
+                kms_key=None,
+                retry_strategy=RetryStrategy.RETRY_ALWAYS,
+                additional_bq_parameters=[],
+                ignore_insert_ids=False,
+                ignore_unknown_columns=False,
+                with_auto_sharding=False,
+                test_client=client,
+                max_retries=len(insert_response) - 1,
+                num_streaming_keys=500))
+
+        failed_values = (
+            bq_write_out[beam_bq.BigQueryWriteFn.FAILED_ROWS]
+            | beam.Map(lambda x: x[1]['columnA']))
+
+        assert_that(failed_values, equal_to(failed_rows))
 
-  # Using https://googleapis.dev/python/google-api-core/latest/_modules/google
-  #    /api_core/exceptions.html
-  # to determine error types and messages to try for retriables.
+  # Running tests with a variety of error reasons from
+  # https://cloud.google.com/bigquery/docs/error-messages
+  # This covers the scenario when
+  # the google.cloud.bigquery.Client.insert_rows_json call returns an error list
+  # rather than raising an exception.
+  # Choosing some error reasons that are included in
+  # bigquery_tools._NON_TRANSIENT_ERRORS and some that are not
   @parameterized.expand([
+      # reason in _NON_TRANSIENT_ERRORS for row 1, sent to failed_rows
       param(
-          exception_type=retry.PermanentException,
-          error_args=('nonretriable', )),
-      param(
-          exception_type=exceptions.BadRequest if exceptions else None,
-          error_args=(
-              'forbidden morbidden', [{
-                  'reason': 'nonretriablereason'
-              }])),
+          insert_response=[
+              [{
+                  'index': 0, 'errors': [{
+                      'reason': 'invalid'
+                  }]
+              }],
+          ],
+          streaming=False),
+      # reason not in _NON_TRANSIENT_ERRORS for row 1, sent to failed_rows
       param(
-          exception_type=exceptions.BadRequest if exceptions else None,
-          error_args=('BAD REQUEST!', [{
-              'reason': 'nonretriablereason'
-          }])),
+          insert_response=[
+              [{
+                  'index': 0, 'errors': [{
+                      'reason': 'internalError'
+                  }]
+              }],
+          ],
+          streaming=False),
       param(
-          exception_type=exceptions.MethodNotAllowed if exceptions else None,
-          error_args=(
-              'method not allowed!', [{
-                  'reason': 'nonretriablereason'
-              }])),
+          insert_response=[
+              [{
+                  'index': 0, 'errors': [{
+                      'reason': 'invalid'
+                  }]
+              }],
+          ],
+          streaming=True),
+      # reason not in _NON_TRANSIENT_ERRORS for row 1, sent to failed_rows
       param(
-          exception_type=exceptions.MethodNotAllowed if exceptions else None,
-          error_args=('method not allowed!', 'args')),
+          insert_response=[
+              [{
+                  'index': 0, 'errors': [{
+                      'reason': 'internalError'
+                  }]
+              }],
+          ],
+          streaming=True),
+  ])
+  @mock.patch('time.sleep')
+  @mock.patch('google.cloud.bigquery.Client.insert_rows_json')
+  def test_insert_rows_json_errors_retry_never(
+      self, mock_send, unused_mock_sleep, insert_response, streaming):
+    # In this test, a pipeline will never retry errors since RetryStrategy is
+    # set to RETRY_NEVER
+    mock_send.side_effect = insert_response
+    opt = StandardOptions()
+    opt.streaming = streaming
+    with beam.Pipeline(runner='BundleBasedDirectRunner', options=opt) as p:
+      bq_write_out = (
+          p
+          | beam.Create([{
+              'columnA': 'value1'
+          }, {
+              'columnA': 'value2'
+          }])
+          | WriteToBigQuery(
+              table='project:dataset.table',
+              schema={
+                  'fields': [{
+                      'name': 'columnA', 'type': 'STRING', 'mode': 'NULLABLE'
+                  }]
+              },
+              create_disposition='CREATE_NEVER',
+              method='STREAMING_INSERTS',
+              insert_retry_strategy=RetryStrategy.RETRY_NEVER))
+      failed_values = (
+          bq_write_out[beam_bq.BigQueryWriteFn.FAILED_ROWS_WITH_ERRORS]
+          | beam.Map(lambda x: x[1]['columnA']))
+
+      assert_that(failed_values, equal_to(['value1']))
+
+    self.assertEqual(1, mock_send.call_count)
+
+  # Running tests with a variety of error reasons from
+  # https://cloud.google.com/bigquery/docs/error-messages
+  # This covers the scenario when
+  # the google.cloud.bigquery.Client.insert_rows_json call returns an error list
+  # rather than raising an exception.
+  # Choosing some error reasons that are included in
+  # bigquery_tools._NON_TRANSIENT_ERRORS and some that are not
+  @parameterized.expand([
+      # reason in _NON_TRANSIENT_ERRORS for row 1, sent to failed_rows
       param(
-          exception_type=exceptions.Unknown if exceptions else None,
-          error_args=('unknown!', 'args')),
+          insert_response=[
+              [{
+                  'index': 0, 'errors': [{
+                      'reason': 'invalid'
+                  }]
+              }],
+          ],
+          failed_rows=['value1'],
+          streaming=False),
+      # reason not in _NON_TRANSIENT_ERRORS for row 1 on 1st attempt
+      # transient error succeeds on 2nd attempt, 0 rows sent to failed rows
       param(
-          exception_type=exceptions.Aborted if exceptions else None,
-          error_args=('abortet!', 'abort')),
+          insert_response=[
+              [{
+                  'index': 0, 'errors': [{
+                      'reason': 'internalError'
+                  }]
+              }],
+              [],
+          ],
+          failed_rows=[],
+          streaming=False),
+      # reason in _NON_TRANSIENT_ERRORS for row 1
+      # reason not in _NON_TRANSIENT_ERRORS for row 2 on 1st and 2nd attempt
+      # all rows with errors are retried when any row has a retriable error
+      # row 1 sent to failed_rows after final attempt
       param(
-          exception_type=requests.exceptions.ConnectionError,
-          error_args=('some connection error', )),
+          insert_response=[
+              [{
+                  'index': 0, 'errors': [{
+                      'reason': 'invalid'
+                  }]
+              }, {
+                  'index': 1, 'errors': [{
+                      'reason': 'internalError'
+                  }]
+              }],
+              [
+                  {
+                      'index': 0, 'errors': [{
+                          'reason': 'invalid'
+                      }]
+                  },
+              ],
+          ],
+          failed_rows=['value1'],
+          streaming=False),
+      # reason in _NON_TRANSIENT_ERRORS for row 1, sent to failed_rows
       param(
-          exception_type=requests.exceptions.Timeout,
-          error_args=('some timeout error', )),
+          insert_response=[
+              [{
+                  'index': 0, 'errors': [{
+                      'reason': 'invalid'
+                  }]
+              }],
+          ],
+          failed_rows=['value1'],
+          streaming=True),

Review Comment:
   I noticed that there were separate tests, some with streaming true/false, since the behavior with pipeline errors differs between the two modes. I added both variants to demonstrate and check the behavior for streaming true and false. Let me know if you think that's unnecessary and I can remove one or the other.
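   
   If we keep both, one way to avoid duplicating each entry would be to cross the cases with both streaming values, e.g. (a sketch using `itertools.product`, not what this PR currently does):
   
   ```python
   import itertools
   from parameterized import param
   
   # one response list per error reason; 'invalid' is in
   # _NON_TRANSIENT_ERRORS, 'internalError' is not
   error_responses = [
       [[{'index': 0, 'errors': [{'reason': 'invalid'}]}]],
       [[{'index': 0, 'errors': [{'reason': 'internalError'}]}]],
   ]
   # generate the four cases fed to @parameterized.expand
   RETRY_NEVER_PARAMS = [
       param(insert_response=resp, streaming=s)
       for resp, s in itertools.product(error_responses, [False, True])
   ]
   ```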



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
