ahmedabu98 commented on code in PR #38015:
URL: https://github.com/apache/beam/pull/38015#discussion_r3017698577


##########
sdks/python/apache_beam/io/gcp/bigquery_change_history.py:
##########
@@ -808,19 +873,27 @@ def process(
           '[Read] try_claim(%d) succeeded: reading stream %s', i, stream_name)
 
       stream_rows = 0
-      for row in self._read_stream(stream_name):
-        ts = row.get(self._change_timestamp_column)
-        if ts is None:
-          raise ValueError(
-              'Row missing %r column. Row keys: %s' %
-              (self._change_timestamp_column, list(row.keys())))
-        if isinstance(ts, datetime.datetime):
-          ts = Timestamp.from_utc_datetime(ts)
-
-        yield TimestampedValue(row, ts)
-        stream_rows += 1
-        total_rows += 1
-      Metrics.counter('BigQueryChangeHistory', 'rows_emitted').inc(total_rows)
+      if self._emit_raw_batches:
+        stream_batches = 0
+        for raw_batch in self._read_stream_raw(stream_name):
+          yield TimestampedValue(raw_batch, element.range_start)
+          stream_batches += 1
+        Metrics.counter('BigQueryChangeHistory',
+                        'batches_emitted').inc(stream_batches)
+      else:
+        for row in self._read_stream(stream_name):
+          ts = row.get(self._change_timestamp_column)
+          if ts is None:
+            raise ValueError(
+                'Row missing %r column. Row keys: %s' %
+                (self._change_timestamp_column, list(row.keys())))
+          if isinstance(ts, datetime.datetime):
+            ts = Timestamp.from_utc_datetime(ts)
+
+          yield TimestampedValue(row, ts)
+          stream_rows += 1
+        Metrics.counter('BigQueryChangeHistory',
+                        'rows_emitted').inc(stream_rows)

Review Comment:
   Should we use different output tags for each case (raw Arrow batches vs. per-row dicts)?



##########
sdks/python/apache_beam/io/gcp/bigquery_change_history_it_test.py:
##########
@@ -567,6 +567,56 @@ def check_rows(actual):
 
       assert_that(rows, check_rows)
 
+  def test_public_api_reads_inserted_row_with_fanout(self):
+    """ReadBigQueryChangeHistory PTransform with polling SDF."""
+    table_str = f'{self.project}:{self.dataset}.{self.test_table_id}'
+    start_time = self.insert_time - 120  # 2 min before insert
+    stop_time = time.time() + 15
+
+    with beam.Pipeline(argv=self.args) as p:
+      rows = (
+          p
+          | ReadBigQueryChangeHistory(
+              table=table_str,
+              poll_interval_sec=15,
+              start_time=start_time,
+              stop_time=stop_time,
+              change_function='APPENDS',
+              buffer_sec=10,
+              project=self.project,
+              temp_dataset=self.temp_dataset,
+              location=self.location,
+              decompress_shards=3))
+
+      def check_rows(actual):
+        assert len(actual) == 3, f'Expected 3 rows, got {len(actual)}'
+        got = sorted([{
+            k: v
+            for k, v in row.items() if k != 'change_timestamp'
+        } for row in actual],
+                     key=lambda r: r['id'])
+        expected = [
+            {
+                'id': 1,
+                'name': 'alice',
+                'value': 10.0,
+                'change_type': 'INSERT'
+            },
+            {
+                'id': 2, 'name': 'bob', 'value': 20.0, 'change_type': 'INSERT'
+            },
+            {
+                'id': 3,
+                'name': 'charlie',
+                'value': 30.0,
+                'change_type': 'INSERT'
+            },
+        ]

Review Comment:
   nit: This is duplicated from the previous test. Maybe make it a class variable for readability.



##########
sdks/python/apache_beam/io/gcp/bigquery_change_history.py:
##########
@@ -730,16 +738,75 @@ def _ensure_client(self) -> None:
   def setup(self) -> None:
     self._ensure_client()
 
+  def _split_all_streams(
+      self, stream_names: Tuple[str, ...],
+      max_split_rounds: int) -> Tuple[str, ...]:
+    """Split each stream at fraction=0.5 for up to max_split_rounds rounds.
+
+    Each round attempts to split every stream in the current list. A
+    successful split replaces the original stream with primary + remainder.
+    A refused split (both fields empty) keeps the original stream intact.
+    Stops when max_split_rounds is reached or a full round produces zero
+    new splits.
+
+    BQ's server-side granularity controls how many splits are possible.
+    Small tables may not split at all; large tables may allow multiple
+    rounds of doubling.
+    """
+    result = list(stream_names)
+    for round_num in range(1, max_split_rounds + 1):
+      new_result = []
+      made_progress = False
+      for name in result:
+        response = self._storage_client.split_read_stream(
+            request=bq_storage.types.SplitReadStreamRequest(
+                name=name, fraction=0.5))
+        primary = response.primary_stream.name
+        remainder = response.remainder_stream.name
+        if primary and remainder:
+          new_result.extend([primary, remainder])
+          made_progress = True
+        else:
+          new_result.append(name)

Review Comment:
   Should we keep a separate "no further splits" set for streams that have maxed out their splits? Then we can skip requesting splits for those streams in later rounds.



##########
sdks/python/apache_beam/io/gcp/bigquery_change_history.py:
##########
@@ -1170,16 +1315,46 @@ def expand(self, pbegin: beam.pvalue.PBegin) -> beam.PCollection:
                 row_filter=self._row_filter))
         | 'CommitQueryResults' >> beam.Reshuffle())
 
+    emit_raw = self._decompress_shards is not None
+
+    read_sdf = beam.ParDo(
+        _ReadStorageStreamsSDF(
+            batch_arrow_read=self._batch_arrow_read,
+            change_timestamp_column=self._change_timestamp_column,
+            max_split_rounds=self._max_split_rounds,
+            emit_raw_batches=emit_raw))
+    if emit_raw:
+      read_sdf = read_sdf.with_output_types(Tuple[bytes, bytes])
+    else:
+      read_sdf = read_sdf.with_output_types(Dict[str, Any])
+
     read_outputs = (
         query_results
-        | 'ReadStorageStreams' >> beam.ParDo(
-            _ReadStorageStreamsSDF(
-                batch_arrow_read=self._batch_arrow_read,
-                change_timestamp_column=self._change_timestamp_column)).
-        with_outputs(_CLEANUP_TAG, main='rows'))
+        | 'ReadStorageStreams' >> read_sdf.with_outputs(
+            _CLEANUP_TAG, main='rows'))
 
     _ = (
         read_outputs[_CLEANUP_TAG]
         | 'CleanupTempTables' >> beam.ParDo(_CleanupTempTablesFn()))
 
-    return read_outputs['rows']
+    if emit_raw:
+      # Fan out raw Arrow batches across decompress_shards workers
+      # via GBK, then decompress and convert to timestamped row dicts.
+      # Uses a discarding trigger so GBK fires per-element without
+      # waiting for the GlobalWindow to close.
+      num_shards = self._decompress_shards
+      rows = (
+          read_outputs['rows']
+          | 'ShardBatches' >>
+          beam.WithKeys(lambda _, n=num_shards: random.randint(0, n - 1))
+          | 'WindowForGBK' >> beam.WindowInto(
+              GlobalWindows(),
+              trigger=beam_trigger.Repeatedly(beam_trigger.AfterCount(1)),
+              accumulation_mode=(beam_trigger.AccumulationMode.DISCARDING))
+          | 'GroupByShardKey' >> beam.GroupByKey()

Review Comment:
   Should we use GroupIntoBatches.WithShardedKey instead?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to