This is an automated email from the ASF dual-hosted git repository.
ningk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 3b7cb10 [BEAM-10708] Added beam_sql magics
new 5a627eb Merge pull request #15368 from KevinGG/sql
3b7cb10 is described below
commit 3b7cb10789ab13e2b1ec458b5e0f5cb6e4f867b5
Author: KevinGG <[email protected]>
AuthorDate: Fri Aug 20 14:16:23 2021 -0700
[BEAM-10708] Added beam_sql magics
Added a beam_sql cell magic that applies a SqlTransform based on a
given Beam SQL query in the notebook.
---
.../runners/interactive/sql/__init__.py | 16 ++
.../runners/interactive/sql/beam_sql_magics.py | 293 +++++++++++++++++++++
.../interactive/sql/beam_sql_magics_test.py | 121 +++++++++
.../apache_beam/runners/interactive/sql/utils.py | 125 +++++++++
.../runners/interactive/sql/utils_test.py | 90 +++++++
.../apache_beam/runners/interactive/utils.py | 10 +
sdks/python/scripts/generate_pydoc.sh | 3 +
sdks/python/setup.py | 2 +-
8 files changed, 659 insertions(+), 1 deletion(-)
diff --git a/sdks/python/apache_beam/runners/interactive/sql/__init__.py b/sdks/python/apache_beam/runners/interactive/sql/__init__.py
new file mode 100644
index 0000000..cce3aca
--- /dev/null
+++ b/sdks/python/apache_beam/runners/interactive/sql/__init__.py
@@ -0,0 +1,16 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/sdks/python/apache_beam/runners/interactive/sql/beam_sql_magics.py b/sdks/python/apache_beam/runners/interactive/sql/beam_sql_magics.py
new file mode 100644
index 0000000..cee3d34
--- /dev/null
+++ b/sdks/python/apache_beam/runners/interactive/sql/beam_sql_magics.py
@@ -0,0 +1,293 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Module of beam_sql cell magic that executes a Beam SQL.
+
+Only works within an IPython kernel.
+"""
+
+import importlib
+import keyword
+import logging
+from typing import Dict
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+import apache_beam as beam
+from apache_beam.pvalue import PValue
+from apache_beam.runners.interactive import cache_manager as cache
+from apache_beam.runners.interactive import interactive_beam as ib
+from apache_beam.runners.interactive import interactive_environment as ie
+from apache_beam.runners.interactive import pipeline_instrument as inst
+from apache_beam.runners.interactive.cache_manager import FileBasedCacheManager
+from apache_beam.runners.interactive.caching.streaming_cache import StreamingCache
+from apache_beam.runners.interactive.sql.utils import find_pcolls
+from apache_beam.runners.interactive.sql.utils import is_namedtuple
+from apache_beam.runners.interactive.sql.utils import pcolls_by_name
+from apache_beam.runners.interactive.sql.utils import register_coder_for_schema
+from apache_beam.runners.interactive.sql.utils import replace_single_pcoll_token
+from apache_beam.runners.interactive.utils import obfuscate
+from apache_beam.runners.interactive.utils import progress_indicated
+from apache_beam.testing import test_stream
+from apache_beam.testing.test_stream_service import TestStreamServiceController
+from apache_beam.transforms.sql import SqlTransform
+from IPython.core.magic import Magics
+from IPython.core.magic import cell_magic
+from IPython.core.magic import magics_class
+
+_LOGGER = logging.getLogger(__name__)
+
+_EXAMPLE_USAGE = """Usage:
+ %%%%beam_sql [output_name]
+ Calcite SQL statement
+ Syntax: https://beam.apache.org/documentation/dsls/sql/calcite/query-syntax/
+ Please make sure that there is no conflicts between your variable names and
+ the SQL keywords, such as "SELECT", "FROM", "WHERE" and etc.
+
+ output_name is optional. If not supplied, a variable name is automatically
+ assigned to the output of the magic.
+
+ The output of the magic is usually a PCollection or similar PValue,
+ depending on the SQL statement executed.
+"""
+
+
+def on_error(error_msg, *args):
+ """Logs the error and the usage example."""
+ _LOGGER.error(error_msg, *args)
+ _LOGGER.info(_EXAMPLE_USAGE)
+
+
+@magics_class
+class BeamSqlMagics(Magics):
+ @cell_magic
+ def beam_sql(self, line: str, cell: str) -> Union[None, PValue]:
+ """The beam_sql cell magic that executes a Beam SQL.
+
+ Args:
+ line: (optional) the string on the same line after the beam_sql magic.
+ Used as the output variable name in the __main__ module.
+ cell: everything else in the same notebook cell as a string. Used as a
+ Beam SQL query.
+
+ Returns None if running into an error, otherwise a PValue as if a
+ SqlTransform is applied.
+ """
+ if line and not line.strip().isidentifier() or keyword.iskeyword(
+ line.strip()):
+ on_error(
+ 'The output_name "%s" is not a valid identifier. Please supply a '
+ 'valid identifier that is not a Python keyword.',
+ line)
+ return
+ if not cell or cell.isspace():
+ on_error('Please supply the sql to be executed.')
+ return
+ found = find_pcolls(cell, pcolls_by_name())
+ for _, pcoll in found.items():
+ if not is_namedtuple(pcoll.element_type):
+ on_error(
+ 'PCollection %s of type %s is not a NamedTuple. See '
+ 'https://beam.apache.org/documentation/programming-guide/#schemas '
+ 'for more details.',
+ pcoll,
+ pcoll.element_type)
+ return
+ register_coder_for_schema(pcoll.element_type)
+
+ # TODO(BEAM-10708): implicitly execute the pipeline and write output into
+ # cache.
+ return apply_sql(cell, line, found)
+
+
+@progress_indicated
+def apply_sql(
+ query: str, output_name: Optional[str],
+ found: Dict[str, beam.PCollection]) -> PValue:
+ """Applies a SqlTransform with the given sql and queried PCollections.
+
+ Args:
+ query: The SQL query executed in the magic.
+ output_name: (optional) The output variable name in __main__ module.
+ found: The PCollections with variable names found to be used in the query.
+
+ Returns:
+ A PValue, mostly a PCollection, depending on the query.
+ """
+ output_name = _generate_output_name(output_name, query, found)
+ query, sql_source = _build_query_components(query, found)
+ try:
+ output = sql_source | SqlTransform(query)
+ # Declare a variable with the output_name and output value in the
+ # __main__ module so that the user can use the output smoothly.
+ setattr(importlib.import_module('__main__'), output_name, output)
+ ib.watch({output_name: output})
+ _LOGGER.info(
+ "The output PCollection variable is %s: %s", output_name, output)
+ return output
+ except (KeyboardInterrupt, SystemExit):
+ raise
+ except Exception as e:
+ on_error('Error when applying the Beam SQL: %s', e)
+
+
+def pcoll_from_file_cache(
+ query_pipeline: beam.Pipeline,
+ pcoll: beam.PCollection,
+ cache_manager: FileBasedCacheManager,
+ key: str) -> beam.PCollection:
+ """Reads PCollection cache from files.
+
+ Args:
+ query_pipeline: The beam.Pipeline object built by the magic to execute the
+ SQL query.
+ pcoll: The PCollection to read cache for.
+ cache_manager: The file based cache manager that holds the PCollection
+ cache.
+ key: The key of the PCollection cache.
+
+ Returns:
+ A PCollection read from the cache.
+ """
+ schema = pcoll.element_type
+
+ class Unreify(beam.DoFn):
+ def process(self, e):
+ if isinstance(e, beam.Row) and hasattr(e, 'windowed_value'):
+ yield e.windowed_value
+
+ return (
+ query_pipeline
+ |
+ '{}{}'.format('QuerySource', key) >> cache.ReadCache(cache_manager, key)
+ | '{}{}'.format('Unreify', key) >> beam.ParDo(
+ Unreify()).with_output_types(schema))
+
+
+def pcolls_from_streaming_cache(
+ user_pipeline: beam.Pipeline,
+ query_pipeline: beam.Pipeline,
+ name_to_pcoll: Dict[str, beam.PCollection],
+ instrumentation: inst.PipelineInstrument,
+ cache_manager: StreamingCache) -> Dict[str, beam.PCollection]:
+ """Reads PCollection cache through the TestStream.
+
+ Args:
+ user_pipeline: The beam.Pipeline object defined by the user in the
+ notebook.
+ query_pipeline: The beam.Pipeline object built by the magic to execute the
+ SQL query.
+ name_to_pcoll: PCollections with variable names used in the SQL query.
+ instrumentation: A pipeline_instrument.PipelineInstrument that helps
+ calculate the cache key of a given PCollection.
+ cache_manager: The streaming cache manager that holds the PCollection
+ cache.
+
+ Returns:
+ A Dict[str, beam.PCollection], where each PCollection is tagged with
+ their PCollection variable names, read from the cache.
+
+ When the user_pipeline has unbounded sources, we force all cache reads to go
+ through the TestStream even if they are bounded sources.
+ """
+ def exception_handler(e):
+ _LOGGER.error(str(e))
+ return True
+
+ test_stream_service = ie.current_env().get_test_stream_service_controller(
+ user_pipeline)
+ if not test_stream_service:
+ test_stream_service = TestStreamServiceController(
+ cache_manager, exception_handler=exception_handler)
+ test_stream_service.start()
+ ie.current_env().set_test_stream_service_controller(
+ user_pipeline, test_stream_service)
+
+ tag_to_name = {}
+ for name, pcoll in name_to_pcoll.items():
+ key = instrumentation.cache_key(pcoll)
+ tag_to_name[key] = name
+ output_pcolls = query_pipeline | test_stream.TestStream(
+ output_tags=set(tag_to_name.keys()),
+ coder=cache_manager._default_pcoder,
+ endpoint=test_stream_service.endpoint)
+ sql_source = {}
+ for tag, output in output_pcolls.items():
+ sql_source[tag_to_name[tag]] = output
+ return sql_source
+
+
+def _generate_output_name(
+ output_name: Optional[str], query: str,
+ found: Dict[str, beam.PCollection]) -> str:
+ """Generates a unique output name if None is provided.
+
+ Otherwise, returns the given output name directly.
+ The generated output name is sql_output_{uuid} where uuid is an obfuscated
+ value from the query and PCollections found to be used in the query.
+ """
+ if not output_name:
+ execution_id = obfuscate(query, found)[:12]
+ output_name = 'sql_output_' + execution_id
+ return output_name
+
+
+def _build_query_components(
+ query: str, found: Dict[str, beam.PCollection]
+) -> Tuple[str,
+ Union[Dict[str, beam.PCollection], beam.PCollection,
+ beam.Pipeline]]:
+ """Builds necessary components needed to apply the SqlTransform.
+
+ Args:
+ query: The SQL query to be executed by the magic.
+ found: The PCollections with variable names found to be used by the query.
+
+ Returns:
+ The processed query to be executed by the magic and a source to apply the
+ SqlTransform to: a dictionary of tagged PCollections, or a single
+ PCollection, or the pipeline to execute the query.
+ """
+ if found:
+ user_pipeline = next(iter(found.values())).pipeline
+ cache_manager = ie.current_env().get_cache_manager(user_pipeline)
+ instrumentation = inst.build_pipeline_instrument(user_pipeline)
+ sql_pipeline = beam.Pipeline(options=user_pipeline._options)
+ ie.current_env().add_derived_pipeline(user_pipeline, sql_pipeline)
+ sql_source = {}
+ if instrumentation.has_unbounded_sources:
+ sql_source = pcolls_from_streaming_cache(
+ user_pipeline, sql_pipeline, found, instrumentation, cache_manager)
+ else:
+ for pcoll_name, pcoll in found.items():
+ cache_key = instrumentation.cache_key(pcoll)
+ sql_source[pcoll_name] = pcoll_from_file_cache(
+ sql_pipeline, pcoll, cache_manager, cache_key)
+ if len(sql_source) == 1:
+ query = replace_single_pcoll_token(query, next(iter(sql_source.keys())))
+ sql_source = next(iter(sql_source.values()))
+ else:
+ sql_source = beam.Pipeline()
+ return query, sql_source
+
+
+def load_ipython_extension(ipython):
+ """Marks this module as an IPython extension.
+
+ To load this magic in an IPython environment, execute:
+ %load_ext apache_beam.runners.interactive.sql.beam_sql_magics.
+ """
+ ipython.register_magics(BeamSqlMagics)
diff --git a/sdks/python/apache_beam/runners/interactive/sql/beam_sql_magics_test.py b/sdks/python/apache_beam/runners/interactive/sql/beam_sql_magics_test.py
new file mode 100644
index 0000000..d35bd46
--- /dev/null
+++ b/sdks/python/apache_beam/runners/interactive/sql/beam_sql_magics_test.py
@@ -0,0 +1,121 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Tests for beam_sql_magics module."""
+
+# pytype: skip-file
+
+import unittest
+from unittest.mock import patch
+
+import pytest
+
+import apache_beam as beam
+from apache_beam.runners.interactive import interactive_beam as ib
+from apache_beam.runners.interactive import interactive_environment as ie
+
+try:
+ from apache_beam.runners.interactive.sql.beam_sql_magics import _build_query_components
+ from apache_beam.runners.interactive.sql.beam_sql_magics import _generate_output_name
+except (ImportError, NameError):
+ pass # The test is to be skipped because [interactive] dep not installed.
+
+
[email protected](
+ not ie.current_env().is_interactive_ready,
+ '[interactive] dependency is not installed.')
[email protected](
+ not ie.current_env().is_interactive_ready,
+ reason='[interactive] dependency is not installed.')
+class BeamSqlMagicsTest(unittest.TestCase):
+ def test_generate_output_name_when_not_provided(self):
+ output_name = None
+ self.assertTrue(
+ _generate_output_name(output_name, '', {}).startswith('sql_output_'))
+
+ def test_use_given_output_name_when_provided(self):
+ output_name = 'output'
+ self.assertEqual(_generate_output_name(output_name, '', {}), output_name)
+
+ def test_build_query_components_when_no_pcoll_queried(self):
+ query = """SELECT CAST(1 AS INT) AS `id`,
+ CAST('foo' AS VARCHAR) AS `str`,
+ CAST(3.14 AS DOUBLE) AS `flt`"""
+ processed_query, sql_source = _build_query_components(query, {})
+ self.assertEqual(processed_query, query)
+ self.assertIsInstance(sql_source, beam.Pipeline)
+
+ def test_build_query_components_when_single_pcoll_queried(self):
+ p = beam.Pipeline()
+ target = p | beam.Create([1, 2, 3])
+ ib.watch(locals())
+ query = 'SELECT * FROM target where a=1'
+ found = {'target': target}
+
+ with patch('apache_beam.runners.interactive.sql.beam_sql_magics.'
+ 'pcoll_from_file_cache',
+ lambda a,
+ b,
+ c,
+ d: target):
+ processed_query, sql_source = _build_query_components(query, found)
+
+ self.assertEqual(processed_query, 'SELECT * FROM PCOLLECTION where a=1')
+ self.assertIsInstance(sql_source, beam.PCollection)
+
+ def test_build_query_components_when_multiple_pcolls_queried(self):
+ p = beam.Pipeline()
+ pcoll_1 = p | 'Create 1' >> beam.Create([1, 2, 3])
+ pcoll_2 = p | 'Create 2' >> beam.Create([4, 5, 6])
+ ib.watch(locals())
+ query = 'SELECT * FROM pcoll_1 JOIN pcoll_2 USING (a)'
+ found = {'pcoll_1': pcoll_1, 'pcoll_2': pcoll_2}
+
+ with patch('apache_beam.runners.interactive.sql.beam_sql_magics.'
+ 'pcoll_from_file_cache',
+ lambda a,
+ b,
+ c,
+ d: pcoll_1):
+ processed_query, sql_source = _build_query_components(query, found)
+
+ self.assertEqual(processed_query, query)
+ self.assertIsInstance(sql_source, dict)
+ self.assertIn('pcoll_1', sql_source)
+ self.assertIn('pcoll_2', sql_source)
+
+ def test_build_query_components_when_unbounded_pcolls_queried(self):
+ p = beam.Pipeline()
+ pcoll = p | beam.io.ReadFromPubSub(
+ subscription='projects/fake-project/subscriptions/fake_sub')
+ ib.watch(locals())
+ query = 'SELECT * FROM pcoll'
+ found = {'pcoll': pcoll}
+
+ with patch('apache_beam.runners.interactive.sql.beam_sql_magics.'
+ 'pcolls_from_streaming_cache',
+ lambda a,
+ b,
+ c,
+ d,
+ e: found):
+ _, sql_source = _build_query_components(query, found)
+ self.assertIs(sql_source, pcoll)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/sdks/python/apache_beam/runners/interactive/sql/utils.py b/sdks/python/apache_beam/runners/interactive/sql/utils.py
new file mode 100644
index 0000000..355b6e6
--- /dev/null
+++ b/sdks/python/apache_beam/runners/interactive/sql/utils.py
@@ -0,0 +1,125 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Module of utilities for SQL magics.
+
+For internal use only; no backward-compatibility guarantees.
+"""
+
+# pytype: skip-file
+
+import logging
+from typing import Dict
+from typing import NamedTuple
+
+import apache_beam as beam
+from apache_beam.runners.interactive import interactive_beam as ib
+from apache_beam.runners.interactive import interactive_environment as ie
+
+_LOGGER = logging.getLogger(__name__)
+
+
+def is_namedtuple(cls: type) -> bool:
+ """Determines if a class is built from typing.NamedTuple."""
+ return (
+ isinstance(cls, type) and issubclass(cls, tuple) and
+ hasattr(cls, '_fields') and hasattr(cls, '_field_types'))
+
+
+def register_coder_for_schema(schema: NamedTuple) -> None:
+ """Registers a RowCoder for the given schema if hasn't.
+
+ Notifies the user of what code has been implicitly executed.
+ """
+ assert is_namedtuple(schema), (
+ 'Schema %s is not a typing.NamedTuple.' % schema)
+ coder = beam.coders.registry.get_coder(schema)
+ if not isinstance(coder, beam.coders.RowCoder):
+ _LOGGER.warning(
+ 'Schema %s has not been registered to use a RowCoder. '
+ 'Automatically registering it by running: '
+ 'beam.coders.registry.register_coder(%s, '
+ 'beam.coders.RowCoder)',
+ schema.__name__,
+ schema.__name__)
+ beam.coders.registry.register_coder(schema, beam.coders.RowCoder)
+
+
+def pcolls_by_name() -> Dict[str, beam.PCollection]:
+ """Finds all PCollections by their variable names defined in the notebook."""
+ inspectables = ie.current_env().inspector.inspectables
+ pcolls = {}
+ for _, inspectable in inspectables.items():
+ metadata = inspectable['metadata']
+ if metadata['type'] == 'pcollection':
+ pcolls[metadata['name']] = inspectable['value']
+ return pcolls
+
+
+def find_pcolls(
+ sql: str, pcolls: Dict[str,
+ beam.PCollection]) -> Dict[str, beam.PCollection]:
+ """Finds all PCollections used in the given sql query.
+
+ It does a simple word by word match and calls ib.collect for each PCollection
+ found.
+ """
+ found = {}
+ for word in sql.split():
+ if word in pcolls:
+ found[word] = pcolls[word]
+ if found:
+ _LOGGER.info('Found PCollections used in the magic: %s.', found)
+ _LOGGER.info('Collecting data...')
+ for name, pcoll in found.items():
+ try:
+ _ = ib.collect(pcoll)
+ except (KeyboardInterrupt, SystemExit):
+ raise
+ except:
+ _LOGGER.error(
+ 'Cannot collect data for PCollection %s. Please make sure the '
+ 'PCollections queried in the sql "%s" are all from a single '
+ 'pipeline using an InteractiveRunner. Make sure there is no '
+ 'ambiguity, for example, same named PCollections from multiple '
+ 'pipelines or notebook re-executions.',
+ name,
+ sql)
+ raise
+ _LOGGER.info('Done collecting data.')
+ return found
+
+
+def replace_single_pcoll_token(sql: str, pcoll_name: str) -> str:
+ """Replaces the pcoll_name used in the sql with 'PCOLLECTION'.
+
+ For sql query using only a single PCollection, the PCollection needs to be
+ referred to as 'PCOLLECTION' instead of its variable/tag name.
+ """
+ words = sql.split()
+ token_locations = []
+ i = 0
+ for word in words:
+ if word.lower() == 'from':
+ token_locations.append(i + 1)
+ i += 2
+ continue
+ i += 1
+ for token_location in token_locations:
+ if token_location < len(words) and words[token_location] == pcoll_name:
+ words[token_location] = 'PCOLLECTION'
+ return ' '.join(words)
diff --git a/sdks/python/apache_beam/runners/interactive/sql/utils_test.py b/sdks/python/apache_beam/runners/interactive/sql/utils_test.py
new file mode 100644
index 0000000..ed52cad
--- /dev/null
+++ b/sdks/python/apache_beam/runners/interactive/sql/utils_test.py
@@ -0,0 +1,90 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Tests for utils module."""
+
+# pytype: skip-file
+
+import unittest
+from typing import NamedTuple
+from unittest.mock import patch
+
+import apache_beam as beam
+from apache_beam.runners.interactive import interactive_beam as ib
+from apache_beam.runners.interactive.sql.utils import find_pcolls
+from apache_beam.runners.interactive.sql.utils import is_namedtuple
+from apache_beam.runners.interactive.sql.utils import pcolls_by_name
+from apache_beam.runners.interactive.sql.utils import register_coder_for_schema
+from apache_beam.runners.interactive.sql.utils import replace_single_pcoll_token
+
+
+class ANamedTuple(NamedTuple):
+ a: int
+ b: str
+
+
+class UtilsTest(unittest.TestCase):
+ def test_is_namedtuple(self):
+ class AType:
+ pass
+
+ a_type = AType
+ a_tuple = type((1, 2, 3))
+
+ a_namedtuple = ANamedTuple
+
+ self.assertTrue(is_namedtuple(a_namedtuple))
+ self.assertFalse(is_namedtuple(a_type))
+ self.assertFalse(is_namedtuple(a_tuple))
+
+ def test_register_coder_for_schema(self):
+ self.assertNotIsInstance(
+ beam.coders.registry.get_coder(ANamedTuple), beam.coders.RowCoder)
+ register_coder_for_schema(ANamedTuple)
+ self.assertIsInstance(
+ beam.coders.registry.get_coder(ANamedTuple), beam.coders.RowCoder)
+
+ def test_pcolls_by_name(self):
+ p = beam.Pipeline()
+ pcoll = p | beam.Create([1])
+ ib.watch({'p': p, 'pcoll': pcoll})
+
+ name_to_pcoll = pcolls_by_name()
+ self.assertIn('pcoll', name_to_pcoll)
+
+ def test_find_pcolls(self):
+ with patch('apache_beam.runners.interactive.interactive_beam.collect',
+ lambda _: None):
+ found = find_pcolls(
+ """SELECT * FROM pcoll_1 JOIN pcoll_2
+ USING (common_column)""", {
+ 'pcoll_1': None, 'pcoll_2': None
+ })
+ self.assertIn('pcoll_1', found)
+ self.assertIn('pcoll_2', found)
+
+ def test_replace_single_pcoll_token(self):
+ sql = 'SELECT * FROM abc WHERE a=1 AND b=2'
+ replaced_sql = replace_single_pcoll_token(sql, 'wow')
+ self.assertEqual(replaced_sql, sql)
+ replaced_sql = replace_single_pcoll_token(sql, 'abc')
+ self.assertEqual(
+ replaced_sql, 'SELECT * FROM PCOLLECTION WHERE a=1 AND b=2')
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/sdks/python/apache_beam/runners/interactive/utils.py b/sdks/python/apache_beam/runners/interactive/utils.py
index 3e85145..cb0b7db 100644
--- a/sdks/python/apache_beam/runners/interactive/utils.py
+++ b/sdks/python/apache_beam/runners/interactive/utils.py
@@ -34,6 +34,15 @@ from apache_beam.typehints.schemas import named_fields_from_element_type
_LOGGER = logging.getLogger(__name__)
+# Add line breaks to the IPythonLogHandler's HTML output.
+_INTERACTIVE_LOG_STYLE = """
+ <style>
+ div.alert {
+ white-space: pre-line;
+ }
+ </style>
+"""
+
def to_element_list(
reader, # type: Generator[Union[TestStreamPayload.Event, WindowedValueHolder]]
@@ -169,6 +178,7 @@ class IPythonLogHandler(logging.Handler):
from html import escape
from IPython.core.display import HTML
from IPython.core.display import display
+ display(HTML(_INTERACTIVE_LOG_STYLE))
display(
HTML(
self.log_template.format(
diff --git a/sdks/python/scripts/generate_pydoc.sh b/sdks/python/scripts/generate_pydoc.sh
index 6b4b344..fb0c415 100755
--- a/sdks/python/scripts/generate_pydoc.sh
+++ b/sdks/python/scripts/generate_pydoc.sh
@@ -218,6 +218,9 @@ ignore_identifiers = [
'google.cloud.datastore.batch.Batch',
'is_in_ipython',
'doctest.TestResults',
+
+ # IPython Magics py:class reference target not found
+ 'IPython.core.magic.Magics',
]
ignore_references = [
'BeamIOError',
diff --git a/sdks/python/setup.py b/sdks/python/setup.py
index 170fbff..83826e1 100644
--- a/sdks/python/setup.py
+++ b/sdks/python/setup.py
@@ -201,7 +201,7 @@ GCP_REQUIREMENTS = [
INTERACTIVE_BEAM = [
'facets-overview>=1.0.0,<2',
- 'ipython>=5.8.0,<8',
+ 'ipython>=7,<8',
'ipykernel>=5.2.0,<6',
# Skip version 6.1.13 due to
# https://github.com/jupyter/jupyter_client/issues/637