Posted to commits@beam.apache.org by ni...@apache.org on 2021/10/05 21:37:10 UTC
[beam] branch master updated: [BEAM-10708] Enable submit beam_sql built jobs to Dataflow
This is an automated email from the ASF dual-hosted git repository.
ningk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 8c0601c [BEAM-10708] Enable submit beam_sql built jobs to Dataflow
new bed6bee Merge pull request #15647 from KevinGG/beam_sql_on_df
8c0601c is described below
commit 8c0601cca39f8350be51084bb98ea53f90c5466a
Author: KevinGG <ka...@gmail.com>
AuthorDate: Fri Oct 1 11:51:14 2021 -0700
[BEAM-10708] Enable submit beam_sql built jobs to Dataflow
1. Added an additional beam_sql option to specify a runner.
2. Added a sql_chain module to track chained beam_sql magics and produce
pipelines for execution when a non-direct runner is specified.
3. Added logic to load schemas defined in the main session without relying
on save_main_session, which might fail.
4. Added an OptionsForm class and a DataflowOptionsForm subclass to guide
users through pipeline options configuration in notebooks.
5. Removed the is_namedtuple utility in favor of the common Beam utility
match_is_named_tuple. Note that dill does not preserve __annotations__
across multiple main sessions; added a workaround until cloudpickle
replaces dill in Beam.
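
For illustration, a minimal notebook sketch of the workflow described above
(the pipeline, schema, and query names here are hypothetical, not part of
this commit):

    from typing import NamedTuple

    import apache_beam as beam
    from apache_beam.runners.interactive import interactive_beam as ib
    from apache_beam.runners.interactive.interactive_runner import InteractiveRunner

    class Word(NamedTuple):
        text: str
        count: int

    # Build a schema-aware PCollection and watch it so beam_sql can find it.
    p = beam.Pipeline(InteractiveRunner())
    words = p | beam.Create([Word('a', 1), Word('b', 2)]).with_output_types(Word)
    ib.watch(locals())

Then, in a later cell (with the magic loaded, e.g. via
%load_ext apache_beam.runners.interactive.sql.beam_sql_magics):

    %%beam_sql -o filtered -r DataflowRunner
    SELECT text, count FROM words WHERE count > 1

With -r DataflowRunner the magic does not execute locally; it displays a
DataflowOptionsForm to gather pipeline options before submitting the job.
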
---
.../runners/interactive/interactive_environment.py | 17 +
.../interactive/interactive_environment_test.py | 32 ++
.../runners/interactive/sql/beam_sql_magics.py | 187 ++++++++---
.../interactive/sql/beam_sql_magics_test.py | 28 +-
.../runners/interactive/sql/sql_chain.py | 226 +++++++++++++
.../runners/interactive/sql/sql_chain_test.py | 109 +++++++
.../apache_beam/runners/interactive/sql/utils.py | 354 +++++++++++++++++++--
.../runners/interactive/sql/utils_test.py | 50 ++-
.../apache_beam/runners/interactive/utils.py | 22 ++
.../apache_beam/runners/interactive/utils_test.py | 8 +
sdks/python/setup.py | 1 +
11 files changed, 941 insertions(+), 93 deletions(-)
diff --git a/sdks/python/apache_beam/runners/interactive/interactive_environment.py b/sdks/python/apache_beam/runners/interactive/interactive_environment.py
index fe10ab1..9f3d66e 100644
--- a/sdks/python/apache_beam/runners/interactive/interactive_environment.py
+++ b/sdks/python/apache_beam/runners/interactive/interactive_environment.py
@@ -36,6 +36,7 @@ from apache_beam.runners import runner
from apache_beam.runners.interactive import cache_manager as cache
from apache_beam.runners.interactive.messaging.interactive_environment_inspector import InteractiveEnvironmentInspector
from apache_beam.runners.interactive.recording_manager import RecordingManager
+from apache_beam.runners.interactive.sql.sql_chain import SqlChain
from apache_beam.runners.interactive.user_pipeline_tracker import UserPipelineTracker
from apache_beam.runners.interactive.utils import register_ipython_log_handler
from apache_beam.utils.interactive_utils import is_in_ipython
@@ -206,6 +207,8 @@ class InteractiveEnvironment(object):
self._inspector_with_synthetic = InteractiveEnvironmentInspector(
ignore_synthetic=False)
+ self.sql_chain = {}
+
@property
def options(self):
"""A reference to the global interactive options.
@@ -651,3 +654,17 @@ class InteractiveEnvironment(object):
Javascript(_HTML_IMPORT_TEMPLATE.format(hrefs=html_hrefs)))
except ImportError:
pass # NOOP if dependencies are not available.
+
+ def get_sql_chain(self, pipeline, set_user_pipeline=False):
+ if pipeline not in self.sql_chain:
+ self.sql_chain[pipeline] = SqlChain()
+ chain = self.sql_chain[pipeline]
+ if set_user_pipeline:
+ if chain.user_pipeline and chain.user_pipeline is not pipeline:
+ raise ValueError(
+ 'The beam_sql magic tries to query PCollections from multiple '
+ 'pipelines: %s and %s',
+ chain.user_pipeline,
+ pipeline)
+ chain.user_pipeline = pipeline
+ return chain
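
A hedged sketch of how this accessor behaves (illustrative, not part of the
diff above):

    import apache_beam as beam
    from apache_beam.runners.interactive import interactive_environment as ie

    env = ie.current_env()
    p = beam.Pipeline()
    chain = env.get_sql_chain(p)  # lazily creates an empty SqlChain for p
    assert chain is env.get_sql_chain(p)  # later calls return the same chain
    # set_user_pipeline=True additionally records p as the chain's
    # user_pipeline; a conflicting pipeline raises ValueError.
    chain = env.get_sql_chain(p, set_user_pipeline=True)
    assert chain.user_pipeline is p
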
diff --git a/sdks/python/apache_beam/runners/interactive/interactive_environment_test.py b/sdks/python/apache_beam/runners/interactive/interactive_environment_test.py
index f08db01..4e0293d 100644
--- a/sdks/python/apache_beam/runners/interactive/interactive_environment_test.py
+++ b/sdks/python/apache_beam/runners/interactive/interactive_environment_test.py
@@ -27,6 +27,7 @@ from apache_beam.runners import runner
from apache_beam.runners.interactive import cache_manager as cache
from apache_beam.runners.interactive import interactive_environment as ie
from apache_beam.runners.interactive.recording_manager import RecordingManager
+from apache_beam.runners.interactive.sql.sql_chain import SqlNode
# The module name is also a variable in module.
_module_name = 'apache_beam.runners.interactive.interactive_environment_test'
@@ -303,6 +304,37 @@ class InteractiveEnvironmentTest(unittest.TestCase):
expected_description = {p1: rm1.describe(), p2: rm2.describe()}
self.assertDictEqual(description, expected_description)
+ def test_get_empty_sql_chain(self):
+ env = ie.InteractiveEnvironment()
+ p = beam.Pipeline()
+ chain = env.get_sql_chain(p)
+ self.assertIsNotNone(chain)
+ self.assertEqual(chain.nodes, {})
+
+ def test_get_sql_chain_with_nodes(self):
+ env = ie.InteractiveEnvironment()
+ p = beam.Pipeline()
+ chain_with_node = env.get_sql_chain(p).append(
+ SqlNode(output_name='name', source=p, query="query"))
+ chain_got = env.get_sql_chain(p)
+ self.assertIs(chain_with_node, chain_got)
+
+ def test_get_sql_chain_setting_user_pipeline(self):
+ env = ie.InteractiveEnvironment()
+ p = beam.Pipeline()
+ chain = env.get_sql_chain(p, set_user_pipeline=True)
+ self.assertIs(chain.user_pipeline, p)
+
+ def test_get_sql_chain_None_when_setting_multiple_user_pipelines(self):
+ env = ie.InteractiveEnvironment()
+ p = beam.Pipeline()
+ chain = env.get_sql_chain(p, set_user_pipeline=True)
+ p2 = beam.Pipeline()
+ # Set the chain for a different pipeline.
+ env.sql_chain[p2] = chain
+ with self.assertRaises(ValueError):
+ env.get_sql_chain(p2, set_user_pipeline=True)
+
if __name__ == '__main__':
unittest.main()
diff --git a/sdks/python/apache_beam/runners/interactive/sql/beam_sql_magics.py b/sdks/python/apache_beam/runners/interactive/sql/beam_sql_magics.py
index bd40f13..d27fc61 100644
--- a/sdks/python/apache_beam/runners/interactive/sql/beam_sql_magics.py
+++ b/sdks/python/apache_beam/runners/interactive/sql/beam_sql_magics.py
@@ -32,24 +32,27 @@ from typing import Union
import apache_beam as beam
from apache_beam.pvalue import PValue
-from apache_beam.runners.interactive import interactive_beam as ib
from apache_beam.runners.interactive import interactive_environment as ie
from apache_beam.runners.interactive.background_caching_job import has_source_to_cache
from apache_beam.runners.interactive.caching.cacheable import CacheKey
from apache_beam.runners.interactive.caching.reify import reify_to_cache
from apache_beam.runners.interactive.caching.reify import unreify_from_cache
from apache_beam.runners.interactive.display.pcoll_visualization import visualize_computed_pcoll
+from apache_beam.runners.interactive.sql.sql_chain import SqlChain
+from apache_beam.runners.interactive.sql.sql_chain import SqlNode
+from apache_beam.runners.interactive.sql.utils import DataflowOptionsForm
from apache_beam.runners.interactive.sql.utils import find_pcolls
-from apache_beam.runners.interactive.sql.utils import is_namedtuple
from apache_beam.runners.interactive.sql.utils import pformat_namedtuple
from apache_beam.runners.interactive.sql.utils import register_coder_for_schema
from apache_beam.runners.interactive.sql.utils import replace_single_pcoll_token
+from apache_beam.runners.interactive.utils import create_var_in_main
from apache_beam.runners.interactive.utils import obfuscate
from apache_beam.runners.interactive.utils import pcoll_by_name
from apache_beam.runners.interactive.utils import progress_indicated
from apache_beam.testing import test_stream
from apache_beam.testing.test_stream_service import TestStreamServiceController
from apache_beam.transforms.sql import SqlTransform
+from apache_beam.typehints.native_type_compatibility import match_is_named_tuple
from IPython.core.magic import Magics
from IPython.core.magic import line_cell_magic
from IPython.core.magic import magics_class
@@ -58,11 +61,11 @@ _LOGGER = logging.getLogger(__name__)
_EXAMPLE_USAGE = """beam_sql magic to execute Beam SQL in notebooks
---------------------------------------------------------
-%%beam_sql [-o OUTPUT_NAME] query
+%%beam_sql [-o OUTPUT_NAME] [-v] [-r RUNNER] query
---------------------------------------------------------
Or
---------------------------------------------------------
-%%%%beam_sql [-o OUTPUT_NAME] query-line#1
+%%%%beam_sql [-o OUTPUT_NAME] [-v] [-r RUNNER] query-line#1
query-line#2
...
query-line#N
@@ -82,6 +85,8 @@ _NOT_SUPPORTED_MSG = """The query was valid and successfully applied.
to build Beam pipelines in a non-interactive manner.
"""
+_SUPPORTED_RUNNERS = ['DirectRunner', 'DataflowRunner']
+
class BeamSqlParser:
"""A parser to parse beam_sql inputs."""
@@ -100,6 +105,14 @@ class BeamSqlParser:
action='store_true',
help='Display more details about the magic execution.')
self._parser.add_argument(
+ '-r',
+ '--runner',
+ dest='runner',
+ help=(
+ 'The runner to run the query. Supported runners are %s. If not '
+ 'provided, DirectRunner is used and results can be inspected '
+ 'locally.' % _SUPPORTED_RUNNERS))
+ self._parser.add_argument(
'query',
type=str,
nargs='*',
@@ -157,8 +170,9 @@ class BeamSqlMagics(Magics):
cell: everything else in the same notebook cell as a string. If None,
beam_sql is used as line magic. Otherwise, cell magic.
- Returns None if running into an error, otherwise a PValue as if a
- SqlTransform is applied.
+ Returns None if running into an error or waiting for user input (running on
+ a selected runner remotely), otherwise a PValue as if a SqlTransform is
+ applied.
"""
input_str = line
if cell:
@@ -170,6 +184,7 @@ class BeamSqlMagics(Magics):
output_name = parsed.output_name
verbose = parsed.verbose
query = parsed.query
+ runner = parsed.runner
if output_name and not output_name.isidentifier() or keyword.iskeyword(
output_name):
@@ -181,11 +196,18 @@ class BeamSqlMagics(Magics):
if not query:
on_error('Please supply the SQL query to be executed.')
return
+ if runner and runner not in _SUPPORTED_RUNNERS:
+ on_error(
+ 'Runner "%s" is not supported. Supported runners are %s.',
+ runner,
+ _SUPPORTED_RUNNERS)
query = ' '.join(query)
found = find_pcolls(query, pcoll_by_name(), verbose=verbose)
+ schemas = set()
+ main_session = importlib.import_module('__main__')
for _, pcoll in found.items():
- if not is_namedtuple(pcoll.element_type):
+ if not match_is_named_tuple(pcoll.element_type):
on_error(
'PCollection %s of type %s is not a NamedTuple. See '
'https://beam.apache.org/documentation/programming-guide/#schemas '
@@ -194,45 +216,93 @@ class BeamSqlMagics(Magics):
pcoll.element_type)
return
register_coder_for_schema(pcoll.element_type, verbose=verbose)
+ # Only care about schemas defined by the user in the main module.
+ if hasattr(main_session, pcoll.element_type.__name__):
+ schemas.add(pcoll.element_type)
+
+ if runner in ('DirectRunner', None):
+ collect_data_for_local_run(query, found)
+ output_name, output, chain = apply_sql(query, output_name, found)
+ chain.current.schemas = schemas
+ cache_output(output_name, output)
+ return output
+
+ output_name, current_node, chain = apply_sql(
+ query, output_name, found, False)
+ current_node.schemas = schemas
+ # TODO(BEAM-10708): Move the options setup and result handling to a
+ # separate module when more runners are supported.
+ if runner == 'DataflowRunner':
+ _ = chain.to_pipeline()
+ _ = DataflowOptionsForm(
+ output_name, pcoll_by_name()[output_name],
+ verbose).display_for_input()
+ return None
+ else:
+ raise ValueError('Unsupported runner %s.', runner)
- output_name, output = apply_sql(query, output_name, found)
- cache_output(output_name, output)
- return output
+
+@progress_indicated
+def collect_data_for_local_run(query: str, found: Dict[str, beam.PCollection]):
+ from apache_beam.runners.interactive import interactive_beam as ib
+ for name, pcoll in found.items():
+ try:
+ _ = ib.collect(pcoll)
+ except (KeyboardInterrupt, SystemExit):
+ raise
+ except:
+ _LOGGER.error(
+ 'Cannot collect data for PCollection %s. Please make sure the '
+ 'PCollections queried in the sql "%s" are all from a single '
+ 'pipeline using an InteractiveRunner. Make sure there is no '
+ 'ambiguity, for example, same named PCollections from multiple '
+ 'pipelines or notebook re-executions.',
+ name,
+ query)
+ raise
@progress_indicated
def apply_sql(
- query: str, output_name: Optional[str],
- found: Dict[str, beam.PCollection]) -> Tuple[str, PValue]:
+ query: str,
+ output_name: Optional[str],
+ found: Dict[str, beam.PCollection],
+ run: bool = True) -> Tuple[str, Union[PValue, SqlNode], SqlChain]:
"""Applies a SqlTransform with the given sql and queried PCollections.
Args:
query: The SQL query executed in the magic.
output_name: (optional) The output variable name in __main__ module.
found: The PCollections with variable names found to be used in the query.
+ run: Whether to prepare the SQL pipeline for a local run or not.
Returns:
- A Tuple[str, PValue]. First str value is the output variable name in
- __main__ module (auto-generated if not provided). Second PValue is
- most likely a PCollection, depending on the query.
+ A tuple of values. First str value is the output variable name in
+ __main__ module, auto-generated if not provided. Second value: if run,
+ it's a PValue; otherwise, a SqlNode tracks the SQL without applying it or
+ executing it. Third value: SqlChain is a chain of SqlNodes that have been
+ applied.
"""
output_name = _generate_output_name(output_name, query, found)
- query, sql_source = _build_query_components(query, found)
- try:
- output = sql_source | SqlTransform(query)
- # Declare a variable with the output_name and output value in the
- # __main__ module so that the user can use the output smoothly.
- setattr(importlib.import_module('__main__'), output_name, output)
- ib.watch({output_name: output})
- _LOGGER.info(
- "The output PCollection variable is %s with element_type %s",
- output_name,
- pformat_namedtuple(output.element_type))
- return output_name, output
- except (KeyboardInterrupt, SystemExit):
- raise
- except Exception as e:
- on_error('Error when applying the Beam SQL: %s', e)
+ query, sql_source, chain = _build_query_components(
+ query, found, output_name, run)
+ if run:
+ try:
+ output = sql_source | SqlTransform(query)
+ # Declare a variable with the output_name and output value in the
+ # __main__ module so that the user can use the output smoothly.
+ output_name, output = create_var_in_main(output_name, output)
+ _LOGGER.info(
+ "The output PCollection variable is %s with element_type %s",
+ output_name,
+ pformat_namedtuple(output.element_type))
+ return output_name, output, chain
+ except (KeyboardInterrupt, SystemExit):
+ raise
+ except Exception as e:
+ on_error('Error when applying the Beam SQL: %s', e)
+ else:
+ return output_name, chain.current, chain
def pcolls_from_streaming_cache(
@@ -304,19 +374,26 @@ def _generate_output_name(
def _build_query_components(
- query: str, found: Dict[str, beam.PCollection]
+ query: str,
+ found: Dict[str, beam.PCollection],
+ output_name: str,
+ run: bool = True
) -> Tuple[str,
- Union[Dict[str, beam.PCollection], beam.PCollection, beam.Pipeline]]:
+ Union[Dict[str, beam.PCollection], beam.PCollection, beam.Pipeline],
+ SqlChain]:
"""Builds necessary components needed to apply the SqlTransform.
Args:
query: The SQL query to be executed by the magic.
found: The PCollections with variable names found to be used by the query.
+ output_name: The output variable name in __main__ module.
+ run: Whether to prepare components for a local run or not.
Returns:
- The processed query to be executed by the magic and a source to apply the
+ The processed query to be executed by the magic; a source to apply the
SqlTransform to: a dictionary of tagged PCollections, or a single
- PCollection, or the pipeline to execute the query.
+ PCollection, or the pipeline to execute the query; the chain of applied
+ beam_sql magics this one belongs to.
"""
if found:
user_pipeline = ie.current_env().user_pipeline(
@@ -324,26 +401,38 @@ def _build_query_components(
sql_pipeline = beam.Pipeline(options=user_pipeline._options)
ie.current_env().add_derived_pipeline(user_pipeline, sql_pipeline)
sql_source = {}
- if has_source_to_cache(user_pipeline):
- sql_source = pcolls_from_streaming_cache(
- user_pipeline, sql_pipeline, found)
+ if run:
+ if has_source_to_cache(user_pipeline):
+ sql_source = pcolls_from_streaming_cache(
+ user_pipeline, sql_pipeline, found)
+ else:
+ cache_manager = ie.current_env().get_cache_manager(
+ user_pipeline, create_if_absent=True)
+ for pcoll_name, pcoll in found.items():
+ cache_key = CacheKey.from_pcoll(pcoll_name, pcoll).to_str()
+ sql_source[pcoll_name] = unreify_from_cache(
+ pipeline=sql_pipeline,
+ cache_key=cache_key,
+ cache_manager=cache_manager,
+ element_type=pcoll.element_type)
else:
- cache_manager = ie.current_env().get_cache_manager(
- user_pipeline, create_if_absent=True)
- for pcoll_name, pcoll in found.items():
- cache_key = CacheKey.from_pcoll(pcoll_name, pcoll).to_str()
- sql_source[pcoll_name] = unreify_from_cache(
- pipeline=sql_pipeline,
- cache_key=cache_key,
- cache_manager=cache_manager,
- element_type=pcoll.element_type)
+ sql_source = found
if len(sql_source) == 1:
query = replace_single_pcoll_token(query, next(iter(sql_source.keys())))
sql_source = next(iter(sql_source.values()))
- else:
+
+ node = SqlNode(
+ output_name=output_name, source=set(found.keys()), query=query)
+ chain = ie.current_env().get_sql_chain(
+ user_pipeline, set_user_pipeline=True).append(node)
+ else: # does not query any existing PCollection
sql_source = beam.Pipeline()
ie.current_env().add_user_pipeline(sql_source)
- return query, sql_source
+
+ # The node should be the root node of the chain created below.
+ node = SqlNode(output_name=output_name, source=sql_source, query=query)
+ chain = ie.current_env().get_sql_chain(sql_source).append(node)
+ return query, sql_source, chain
@progress_indicated
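
As a sketch of the chaining behavior these changes enable (cell contents
hypothetical), consecutive beam_sql cells form one SqlChain:

    %%beam_sql -o small_words
    SELECT text, count FROM words WHERE count < 10

    %%beam_sql -o final_result -r DataflowRunner
    SELECT text FROM small_words WHERE count > 1

The first cell runs on the local DirectRunner path and caches its output; the
second queries that output by name, so chain.to_pipeline() stitches both
nodes into a single pipeline and a DataflowOptionsForm is displayed instead
of running locally.
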
diff --git a/sdks/python/apache_beam/runners/interactive/sql/beam_sql_magics_test.py b/sdks/python/apache_beam/runners/interactive/sql/beam_sql_magics_test.py
index 538abbb..3d843a0 100644
--- a/sdks/python/apache_beam/runners/interactive/sql/beam_sql_magics_test.py
+++ b/sdks/python/apache_beam/runners/interactive/sql/beam_sql_magics_test.py
@@ -59,9 +59,13 @@ class BeamSqlMagicsTest(unittest.TestCase):
query = """SELECT CAST(1 AS INT) AS `id`,
CAST('foo' AS VARCHAR) AS `str`,
CAST(3.14 AS DOUBLE) AS `flt`"""
- processed_query, sql_source = _build_query_components(query, {})
+ processed_query, sql_source, chain = _build_query_components(
+ query, {}, 'output')
self.assertEqual(processed_query, query)
self.assertIsInstance(sql_source, beam.Pipeline)
+ self.assertIsInstance(chain.current.source, beam.Pipeline)
+ self.assertEqual('output', chain.current.output_name)
+ self.assertEqual(query, chain.current.query)
def test_build_query_components_when_single_pcoll_queried(self):
p = beam.Pipeline()
@@ -76,10 +80,14 @@ class BeamSqlMagicsTest(unittest.TestCase):
cache_key,
cache_manager,
element_type: target):
- processed_query, sql_source = _build_query_components(query, found)
-
- self.assertEqual(processed_query, 'SELECT * FROM PCOLLECTION where a=1')
+ processed_query, sql_source, chain = _build_query_components(
+ query, found, 'output')
+ expected_query = 'SELECT * FROM PCOLLECTION where a=1'
+ self.assertEqual(expected_query, processed_query)
self.assertIsInstance(sql_source, beam.PCollection)
+ self.assertIn('target', chain.current.source)
+ self.assertEqual(expected_query, chain.current.query)
+ self.assertEqual('output', chain.current.output_name)
def test_build_query_components_when_multiple_pcolls_queried(self):
p = beam.Pipeline()
@@ -95,12 +103,17 @@ class BeamSqlMagicsTest(unittest.TestCase):
cache_key,
cache_manager,
element_type: pcoll_1):
- processed_query, sql_source = _build_query_components(query, found)
+ processed_query, sql_source, chain = _build_query_components(
+ query, found, 'output')
self.assertEqual(processed_query, query)
self.assertIsInstance(sql_source, dict)
self.assertIn('pcoll_1', sql_source)
self.assertIn('pcoll_2', sql_source)
+ self.assertIn('pcoll_1', chain.current.source)
+ self.assertIn('pcoll_2', chain.current.source)
+ self.assertEqual(query, chain.current.query)
+ self.assertEqual('output', chain.current.output_name)
def test_build_query_components_when_unbounded_pcolls_queried(self):
p = beam.Pipeline()
@@ -115,8 +128,11 @@ class BeamSqlMagicsTest(unittest.TestCase):
lambda a,
b,
c: found):
- _, sql_source = _build_query_components(query, found)
+ _, sql_source, chain = _build_query_components(query, found, 'output')
self.assertIs(sql_source, pcoll)
+ self.assertIn('pcoll', chain.current.source)
+ self.assertEqual('SELECT * FROM PCOLLECTION', chain.current.query)
+ self.assertEqual('output', chain.current.output_name)
def test_cache_output(self):
p_cache_output = beam.Pipeline()
diff --git a/sdks/python/apache_beam/runners/interactive/sql/sql_chain.py b/sdks/python/apache_beam/runners/interactive/sql/sql_chain.py
new file mode 100644
index 0000000..a6f4866
--- /dev/null
+++ b/sdks/python/apache_beam/runners/interactive/sql/sql_chain.py
@@ -0,0 +1,226 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Module for tracking a chain of beam_sql magics applied.
+
+For internal use only; no backwards-compatibility guarantees.
+"""
+
+# pytype: skip-file
+
+import importlib
+import logging
+from dataclasses import dataclass
+from typing import Any
+from typing import Dict
+from typing import Optional
+from typing import Set
+from typing import Union
+
+import apache_beam as beam
+from apache_beam.internal import pickler
+from apache_beam.runners.interactive.sql.utils import register_coder_for_schema
+from apache_beam.runners.interactive.utils import create_var_in_main
+from apache_beam.runners.interactive.utils import pcoll_by_name
+from apache_beam.runners.interactive.utils import progress_indicated
+from apache_beam.transforms.sql import SqlTransform
+from apache_beam.utils.interactive_utils import is_in_ipython
+
+_LOGGER = logging.getLogger(__name__)
+
+
+@dataclass
+class SqlNode:
+ """Each SqlNode represents a beam_sql magic applied.
+
+ Attributes:
+ output_name: the watched unique name of the beam_sql output. Can be used as
+ an identifier.
+ source: the inputs consumed by this node. Can be a pipeline or a set of
+ PCollections represented by their variable names watched. When it's a
+ pipeline, the node computes from raw values in the query, so the output
+ can be consumed by any SqlNode in any SqlChain.
+ query: the SQL query applied by this node.
+ schemas: the schemas (NamedTuple classes) used by this node.
+ evaluated: the pipelines this node has been evaluated for.
+ next: the next SqlNode applied chronologically.
+ execution_count: the execution count if in an IPython env.
+ """
+ output_name: str
+ source: Union[beam.Pipeline, Set[str]]
+ query: str
+ schemas: Set[Any] = None
+ evaluated: Set[beam.Pipeline] = None
+ next: Optional['SqlNode'] = None
+ execution_count: int = 0
+
+ def __post_init__(self):
+ if not self.schemas:
+ self.schemas = set()
+ if not self.evaluated:
+ self.evaluated = set()
+ if is_in_ipython():
+ from IPython import get_ipython
+ self.execution_count = get_ipython().execution_count
+
+ def __hash__(self):
+ return hash(
+ (self.output_name, self.source, self.query, self.execution_count))
+
+ def to_pipeline(self, pipeline: Optional[beam.Pipeline]) -> beam.Pipeline:
+ """Converts the chain into an executable pipeline."""
+ if pipeline not in self.evaluated:
+ # The whole chain should form a single pipeline.
+ source = self.source
+ if isinstance(self.source, beam.Pipeline):
+ if pipeline: # use the known pipeline
+ source = pipeline
+ else: # use the source pipeline
+ pipeline = self.source
+ else:
+ name_to_pcoll = pcoll_by_name()
+ if len(self.source) == 1:
+ source = name_to_pcoll.get(next(iter(self.source)))
+ else:
+ source = {s: name_to_pcoll.get(s) for s in self.source}
+ if isinstance(source, beam.Pipeline):
+ output = source | 'beam_sql_{}_{}'.format(
+ self.output_name, self.execution_count) >> SqlTransform(self.query)
+ else:
+ output = source | 'schema_loaded_beam_sql_{}_{}'.format(
+ self.output_name, self.execution_count
+ ) >> SchemaLoadedSqlTransform(
+ self.output_name, self.query, self.schemas, self.execution_count)
+ _ = create_var_in_main(self.output_name, output)
+ self.evaluated.add(pipeline)
+ if self.next:
+ return self.next.to_pipeline(pipeline)
+ else:
+ return pipeline
+
+
+class SchemaLoadedSqlTransform(beam.PTransform):
+ """PTransform that loads schema before executing SQL.
+
+ When submitting a pipeline to remote runner for execution, schemas defined in
+ the main module are not available without save_main_session. However,
+ save_main_session might fail when there is anything unpicklable. This DoFn
+ makes sure only the schemas needed are pickled locally and restored later on
+ workers.
+ """
+ def __init__(self, output_name, query, schemas, execution_count):
+ self.output_name = output_name
+ self.query = query
+ self.schemas = schemas
+ self.execution_count = execution_count
+ # TODO(BEAM-8123): clean up this attribute or the whole wrapper PTransform.
+ # Dill does not preserve everything. On the other hand, save_main_session
+ # is not stable. Until cloudpickle replaces dill in Beam, we work around
+ # it by explicitly pickling annotations and load schemas in remote main
+ # sessions.
+ self.schema_annotations = [s.__annotations__ for s in self.schemas]
+
+ class _SqlTransformDoFn(beam.DoFn):
+ """The DoFn yields all its input without any transform but a setup to
+ configure the main session."""
+ def __init__(self, schemas, annotations):
+ self.pickled_schemas = [pickler.dumps(s) for s in schemas]
+ self.pickled_annotations = [pickler.dumps(a) for a in annotations]
+
+ def setup(self):
+ main_session = importlib.import_module('__main__')
+ for pickled_schema, pickled_annotation in zip(
+ self.pickled_schemas, self.pickled_annotations):
+ schema = pickler.loads(pickled_schema)
+ schema.__annotations__ = pickler.loads(pickled_annotation)
+ if not hasattr(main_session, schema.__name__) or not hasattr(
+ getattr(main_session, schema.__name__), '__annotations__'):
+ # Restore the schema in the main session on the [remote] worker.
+ setattr(main_session, schema.__name__, schema)
+ register_coder_for_schema(schema)
+
+ def process(self, e):
+ yield e
+
+ def expand(self, source):
+ """Applies the SQL transform. If a PCollection uses a schema defined in
+ the main session, use the additional DoFn to restore it on the worker."""
+ if isinstance(source, dict):
+ schema_loaded = {
+ tag: pcoll | 'load_schemas_{}_tag_{}_{}'.format(
+ self.output_name, tag, self.execution_count) >> beam.ParDo(
+ self._SqlTransformDoFn(self.schemas, self.schema_annotations))
+ if pcoll.element_type in self.schemas else pcoll
+ for tag,
+ pcoll in source.items()
+ }
+ elif isinstance(source, beam.pvalue.PCollection):
+ schema_loaded = source | 'load_schemas_{}_{}'.format(
+ self.output_name, self.execution_count) >> beam.ParDo(
+ self._SqlTransformDoFn(self.schemas, self.schema_annotations)
+ ) if source.element_type in self.schemas else source
+ else:
+ raise ValueError(
+ '{} should be either a single PCollection or a dict of named '
+ 'PCollections.'.format(source))
+ return schema_loaded | 'beam_sql_{}_{}'.format(
+ self.output_name, self.execution_count) >> SqlTransform(self.query)
+
+
+@dataclass
+class SqlChain:
+ """A chain of SqlNodes.
+
+ Attributes:
+ nodes: all nodes by their output_names.
+ root: the first SqlNode applied chronologically.
+ current: the last node applied.
+ user_pipeline: the user defined pipeline this chain originates from. If
+ None, the whole chain just computes from raw values in queries.
+ Otherwise, at least some of the nodes in chain has queried against
+ PCollections.
+ """
+ nodes: Dict[str, SqlNode] = None
+ root: Optional[SqlNode] = None
+ current: Optional[SqlNode] = None
+ user_pipeline: Optional[beam.Pipeline] = None
+
+ def __post_init__(self):
+ if not self.nodes:
+ self.nodes = {}
+
+ @progress_indicated
+ def to_pipeline(self) -> beam.Pipeline:
+ """Converts the chain into a beam pipeline."""
+ pipeline_to_execute = self.root.to_pipeline(self.user_pipeline)
+ # The pipeline definitely contains external transform: SqlTransform.
+ pipeline_to_execute.contains_external_transforms = True
+ return pipeline_to_execute
+
+ def append(self, node: SqlNode) -> 'SqlChain':
+ """Appends a node to the chain."""
+ if self.current:
+ self.current.next = node
+ else:
+ self.root = node
+ self.current = node
+ self.nodes[node.output_name] = node
+ return self
+
+ def get(self, output_name: str) -> Optional[SqlNode]:
+ """Gets a node from the chain based on the given output_name."""
+ return self.nodes.get(output_name, None)
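
A small usage sketch of the data structures above (names hypothetical;
to_pipeline additionally requires the named sources to be watched in the
interactive environment):

    import apache_beam as beam
    from apache_beam.runners.interactive.sql.sql_chain import SqlChain
    from apache_beam.runners.interactive.sql.sql_chain import SqlNode

    p = beam.Pipeline()
    chain = SqlChain().append(
        SqlNode(output_name='base', source=p, query='SELECT 1 AS x'))
    chain.append(
        SqlNode(
            output_name='doubled',
            source={'base'},
            query='SELECT x * 2 AS x FROM base'))
    assert chain.root is chain.get('base')
    assert chain.root.next is chain.current
    # chain.to_pipeline() would replay both nodes onto one executable
    # pipeline, wrapping nodes that query existing PCollections in
    # SchemaLoadedSqlTransform so main-session schemas can be restored
    # on remote workers.
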
diff --git a/sdks/python/apache_beam/runners/interactive/sql/sql_chain_test.py b/sdks/python/apache_beam/runners/interactive/sql/sql_chain_test.py
new file mode 100644
index 0000000..42d0804
--- /dev/null
+++ b/sdks/python/apache_beam/runners/interactive/sql/sql_chain_test.py
@@ -0,0 +1,109 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Tests for sql_chain module."""
+
+# pytype: skip-file
+
+import unittest
+from unittest.mock import patch
+
+import pytest
+
+import apache_beam as beam
+from apache_beam.runners.interactive import interactive_environment as ie
+from apache_beam.runners.interactive.sql.sql_chain import SqlChain
+from apache_beam.runners.interactive.sql.sql_chain import SqlNode
+from apache_beam.runners.interactive.testing.mock_ipython import mock_get_ipython
+
+
+class SqlChainTest(unittest.TestCase):
+ def test_init(self):
+ chain = SqlChain()
+ self.assertEqual({}, chain.nodes)
+ self.assertIsNone(chain.root)
+ self.assertIsNone(chain.current)
+ self.assertIsNone(chain.user_pipeline)
+
+ def test_append_first_node(self):
+ node = SqlNode(output_name='first', source='a', query='q1')
+ chain = SqlChain().append(node)
+ self.assertIs(node, chain.get(node.output_name))
+ self.assertIs(node, chain.root)
+ self.assertIs(node, chain.current)
+
+ def test_append_non_root_node(self):
+ chain = SqlChain().append(
+ SqlNode(output_name='root', source='root', query='q1'))
+ self.assertIsNone(chain.root.next)
+ node = SqlNode(output_name='next_node', source='root', query='q2')
+ chain.append(node)
+ self.assertIs(node, chain.root.next)
+ self.assertIs(node, chain.get(node.output_name))
+
+ @patch(
+ 'apache_beam.runners.interactive.sql.sql_chain.SchemaLoadedSqlTransform.'
+ '__rrshift__')
+ def test_to_pipeline_only_evaluate_once_per_pipeline_and_node(
+ self, mocked_sql_transform):
+ p = beam.Pipeline()
+ ie.current_env().watch({'p': p})
+ pcoll_1 = p | 'create pcoll_1' >> beam.Create([1, 2, 3])
+ pcoll_2 = p | 'create pcoll_2' >> beam.Create([4, 5, 6])
+ ie.current_env().watch({'pcoll_1': pcoll_1, 'pcoll_2': pcoll_2})
+ node = SqlNode(
+ output_name='root', source={'pcoll_1', 'pcoll_2'}, query='q1')
+ chain = SqlChain(user_pipeline=p).append(node)
+ _ = chain.to_pipeline()
+ mocked_sql_transform.assert_called_once()
+ _ = chain.to_pipeline()
+ mocked_sql_transform.assert_called_once()
+
+ @unittest.skipIf(
+ not ie.current_env().is_interactive_ready,
+ '[interactive] dependency is not installed.')
+ @pytest.mark.skipif(
+ not ie.current_env().is_interactive_ready,
+ reason='[interactive] dependency is not installed.')
+ @patch(
+ 'apache_beam.runners.interactive.sql.sql_chain.SchemaLoadedSqlTransform.'
+ '__rrshift__')
+ def test_nodes_with_same_outputs(self, mocked_sql_transform):
+ p = beam.Pipeline()
+ ie.current_env().watch({'p_nodes_with_same_output': p})
+ pcoll = p | 'create pcoll' >> beam.Create([1, 2, 3])
+ ie.current_env().watch({'pcoll': pcoll})
+ chain = SqlChain(user_pipeline=p)
+ output_name = 'output'
+
+ with patch('IPython.get_ipython', new_callable=mock_get_ipython) as cell:
+ with cell:
+ node_cell_1 = SqlNode(output_name, source='pcoll', query='q1')
+ chain.append(node_cell_1)
+ _ = chain.to_pipeline()
+ mocked_sql_transform.assert_called_with(
+ 'schema_loaded_beam_sql_output_1')
+ with cell:
+ node_cell_2 = SqlNode(output_name, source='pcoll', query='q2')
+ chain.append(node_cell_2)
+ _ = chain.to_pipeline()
+ mocked_sql_transform.assert_called_with(
+ 'schema_loaded_beam_sql_output_2')
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/sdks/python/apache_beam/runners/interactive/sql/utils.py b/sdks/python/apache_beam/runners/interactive/sql/utils.py
index fb4e57d..b2e75c8 100644
--- a/sdks/python/apache_beam/runners/interactive/sql/utils.py
+++ b/sdks/python/apache_beam/runners/interactive/sql/utils.py
@@ -23,29 +23,39 @@ For internal use only; no backward-compatibility guarantees.
# pytype: skip-file
import logging
+import os
+import tempfile
+from dataclasses import dataclass
+from typing import Any
+from typing import Callable
from typing import Dict
from typing import NamedTuple
+from typing import Optional
+from typing import Type
+from typing import Union
import apache_beam as beam
-from apache_beam.runners.interactive import interactive_beam as ib
+from apache_beam.io import WriteToText
+from apache_beam.options.pipeline_options import GoogleCloudOptions
+from apache_beam.options.pipeline_options import PipelineOptions
+from apache_beam.options.pipeline_options import SetupOptions
+from apache_beam.options.pipeline_options import WorkerOptions
+from apache_beam.runners.interactive.utils import create_var_in_main
+from apache_beam.runners.interactive.utils import progress_indicated
+from apache_beam.runners.runner import create_runner
+from apache_beam.typehints.native_type_compatibility import match_is_named_tuple
+from apache_beam.utils.interactive_utils import is_in_ipython
_LOGGER = logging.getLogger(__name__)
-def is_namedtuple(cls: type) -> bool:
- """Determines if a class is built from typing.NamedTuple."""
- return (
- isinstance(cls, type) and issubclass(cls, tuple) and
- hasattr(cls, '_fields') and hasattr(cls, '__annotations__'))
-
-
def register_coder_for_schema(
schema: NamedTuple, verbose: bool = False) -> None:
"""Registers a RowCoder for the given schema if hasn't.
Notifies the user of what code has been implicitly executed.
"""
- assert is_namedtuple(schema), (
+ assert match_is_named_tuple(schema), (
'Schema %s is not a typing.NamedTuple.' % schema)
coder = beam.coders.registry.get_coder(schema)
if not isinstance(coder, beam.coders.RowCoder):
@@ -77,21 +87,6 @@ def find_pcolls(
if verbose:
_LOGGER.info('Found PCollections used in the magic: %s.', found)
_LOGGER.info('Collecting data...')
- for name, pcoll in found.items():
- try:
- _ = ib.collect(pcoll)
- except (KeyboardInterrupt, SystemExit):
- raise
- except:
- _LOGGER.error(
- 'Cannot collect data for PCollection %s. Please make sure the '
- 'PCollections queried in the sql "%s" are all from a single '
- 'pipeline using an InteractiveRunner. Make sure there is no '
- 'ambiguity, for example, same named PCollections from multiple '
- 'pipelines or notebook re-executions.',
- name,
- sql)
- raise
return found
@@ -123,3 +118,314 @@ def pformat_namedtuple(schema: NamedTuple) -> str:
'{}: {}'.format(k, v.__name__) for k,
v in schema.__annotations__.items()
]))
+
+
+def pformat_dict(raw_input: Dict[str, Any]) -> str:
+ return '{{\n{}\n}}'.format(
+ ',\n'.join(['{}: {}'.format(k, v) for k, v in raw_input.items()]))
+
+
+@dataclass
+class OptionsEntry:
+ """An entry of PipelineOptions that can be visualized through ipywidgets to
+ take inputs in IPython notebooks interactively.
+
+ Attributes:
+ label: The value of the Label widget.
+ help: The help message of the entry, usually the same to the help in
+ PipelineOptions.
+ cls: The PipelineOptions class/subclass the options belong to.
+ arg_builder: Builds the argument/option. If it's a str, this entry
+ assigns the input ipywidget's value directly to the argument. If it's a
+ Dict, use the corresponding Callable to assign the input value to each
+ argument. If Callable is None, fallback to assign the input value
+ directly. This allows building multiple similar PipelineOptions
+ arguments from a single input, such as staging_location and
+ temp_location in GoogleCloudOptions.
+ default: The default value of the entry, None if absent.
+ """
+ label: str
+ help: str
+ cls: Type[PipelineOptions]
+ arg_builder: Union[str, Dict[str, Optional[Callable]]]
+ default: Optional[str] = None
+
+ def __post_init__(self):
+ # The attribute holds an ipywidget, currently only supports Text.
+ # The str value can be accessed by self.input.value.
+ self.input = None
+
+
+class OptionsForm:
+ """A form visualized to take inputs from users in IPython Notebooks and
+ generate PipelineOptions to run pipelines.
+ """
+ def __init__(self):
+ self.options = PipelineOptions()
+ self.entries = []
+
+ def add(self, entry: OptionsEntry) -> 'OptionsForm':
+ """Adds an OptionsEntry to the form.
+ """
+ self.entries.append(entry)
+ return self
+
+ def to_options(self) -> PipelineOptions:
+ """Builds the PipelineOptions based on user inputs.
+
+ Can only be invoked after display_for_input.
+ """
+ for entry in self.entries:
+ assert entry.input, (
+ 'to_options invoked before display_for_input. '
+ 'Wrong usage.')
+ view = self.options.view_as(entry.cls)
+ if isinstance(entry.arg_builder, str):
+ setattr(view, entry.arg_builder, entry.input.value)
+ else:
+ for arg, builder in entry.arg_builder.items():
+ if builder:
+ setattr(view, arg, builder(entry.input.value))
+ else:
+ setattr(view, arg, entry.input.value)
+ self.additional_options()
+ return self.options
+
+ def additional_options(self):
+ """Alters the self.options with additional config."""
+ pass
+
+ def display_for_input(self) -> 'OptionsForm':
+ """Displays the widgets to take user inputs."""
+ from IPython.display import display
+ from ipywidgets import GridBox
+ from ipywidgets import Label
+ from ipywidgets import Layout
+ from ipywidgets import Text
+ widgets = []
+ for entry in self.entries:
+ text_label = Label(value=entry.label)
+ text_input = entry.input if entry.input else Text(
+ value=entry.default if entry.default else '')
+ text_help = Label(value=entry.help)
+ entry.input = text_input
+ widgets.append(text_label)
+ widgets.append(text_input)
+ widgets.append(text_help)
+ grid = GridBox(widgets, layout=Layout(grid_template_columns='1fr 2fr 6fr'))
+ display(grid)
+ self.display_actions()
+ return self
+
+ def display_actions(self):
+ """Displays actionable widgets to utilize the options, run pipelines and
+ etc."""
+ pass
+
+
+class DataflowOptionsForm(OptionsForm):
+ """A form to take inputs from users in IPython Notebooks to build
+ PipelineOptions to run pipelines on Dataflow.
+
+ Only contains minimum fields needed.
+ """
+ @staticmethod
+ def _build_default_project() -> str:
+ """Builds a default project id."""
+ try:
+ # pylint: disable=c-extension-no-member
+ import google.auth
+ return google.auth.default()[1]
+ except (KeyboardInterrupt, SystemExit):
+ raise
+ except Exception as e:
+ _LOGGER.warning('There is some issue with your gcloud auth: %s', e)
+ return 'your-project-id'
+
+ @staticmethod
+ def _build_req_file_from_pkgs(pkgs) -> Optional[str]:
+ """Builds a requirements file that contains all additional PYPI packages
+ needed."""
+ if pkgs:
+ deps = pkgs.split(',')
+ req_file = os.path.join(
+ tempfile.mkdtemp(prefix='beam-sql-dataflow-'), 'req.txt')
+ with open(req_file, 'a') as f:
+ for dep in deps:
+ f.write(dep.strip() + '\n')
+ return req_file
+ return None
+
+ def __init__(
+ self,
+ output_name: str,
+ output_pcoll: beam.PCollection,
+ verbose: bool = False):
+ """Inits the OptionsForm for setting up Dataflow jobs."""
+ super().__init__()
+ self.p = output_pcoll.pipeline
+ self.output_name = output_name
+ self.output_pcoll = output_pcoll
+ self.verbose = verbose
+ self.notice_shown = False
+ self.add(
+ OptionsEntry(
+ label='Project Id',
+ help='Name of the Cloud project owning the Dataflow job.',
+ cls=GoogleCloudOptions,
+ arg_builder='project',
+ default=DataflowOptionsForm._build_default_project())
+ ).add(
+ OptionsEntry(
+ label='Region',
+ help='The Google Compute Engine region for creating Dataflow job.',
+ cls=GoogleCloudOptions,
+ arg_builder='region',
+ default='us-central1')
+ ).add(
+ OptionsEntry(
+ label='GCS Bucket',
+ help=(
+ 'GCS path to stage code packages needed by workers and save '
+ 'temporary workflow jobs.'),
+ cls=GoogleCloudOptions,
+ arg_builder={
+ 'staging_location': lambda x: x + '/staging',
+ 'temp_location': lambda x: x + '/temp'
+ },
+ default='gs://YOUR_GCS_BUCKET_HERE')
+ ).add(
+ OptionsEntry(
+ label='Additional Packages',
+ help=(
+ 'PYPI packages installed, comma-separated. If None, leave '
+ 'this field empty.'),
+ cls=SetupOptions,
+ arg_builder={
+ 'requirements_file': lambda x: DataflowOptionsForm.
+ _build_req_file_from_pkgs(x)
+ },
+ default=''))
+
+ def additional_options(self):
+ # Use the latest Java SDK by default.
+ sdk_overrides = self.options.view_as(
+ WorkerOptions).sdk_harness_container_image_overrides
+ override = '.*java.*,apache/beam_java11_sdk:latest'
+ if sdk_overrides and override not in sdk_overrides:
+ sdk_overrides.append(override)
+ else:
+ self.options.view_as(
+ WorkerOptions).sdk_harness_container_image_overrides = [override]
+
+ def display_actions(self):
+ from IPython.display import HTML
+ from IPython.display import display
+ from ipywidgets import Button
+ from ipywidgets import GridBox
+ from ipywidgets import Layout
+ from ipywidgets import Output
+ options_output_area = Output()
+ run_output_area = Output()
+ run_btn = Button(
+ description='Run on Dataflow',
+ button_style='success',
+ tooltip=(
+ 'Submit to Dataflow for execution with the configured options. The '
+ 'output PCollection\'s data will be written to the GCS bucket you '
+ 'configure.'))
+ show_options_btn = Button(
+ description='Show Options',
+ button_style='info',
+ tooltip='Show current pipeline options configured.')
+
+ def _run_on_dataflow(btn):
+ with run_output_area:
+ run_output_area.clear_output()
+
+ @progress_indicated
+ def _inner():
+ options = self.to_options()
+ # Caches the output_pcoll to a GCS bucket.
+ try:
+ execution_count = 0
+ if is_in_ipython():
+ from IPython import get_ipython
+ execution_count = get_ipython().execution_count
+ output_location = '{}/{}'.format(
+ options.view_as(GoogleCloudOptions).staging_location,
+ self.output_name)
+ _ = self.output_pcoll | 'WriteOuput{}_{}ToGCS'.format(
+ self.output_name,
+ execution_count) >> WriteToText(output_location)
+ _LOGGER.info(
+ 'Data of output PCollection %s will be written to %s',
+ self.output_name,
+ output_location)
+ except (KeyboardInterrupt, SystemExit):
+ raise
+ except: # pylint: disable=bare-except
+ # The transform has been added before, noop.
+ pass
+ if self.verbose:
+ _LOGGER.info(
+ 'Running the pipeline on Dataflow with pipeline options %s.',
+ pformat_dict(options.display_data()))
+ result = create_runner('DataflowRunner').run_pipeline(self.p, options)
+ cloud_options = options.view_as(GoogleCloudOptions)
+ url = (
+ 'https://console.cloud.google.com/dataflow/jobs/%s/%s?project=%s'
+ % (cloud_options.region, result.job_id(), cloud_options.project))
+ display(
+ HTML(
+ 'Click <a href="%s" target="_new">here</a> for the details '
+ 'of your Dataflow job.' % url))
+ result_name = 'result_{}'.format(self.output_name)
+ create_var_in_main(result_name, result)
+ if self.verbose:
+ _LOGGER.info(
+ 'The pipeline result of the run can be accessed from variable '
+ '%s. The current status is %s.',
+ result_name,
+ result)
+
+ try:
+ btn.disabled = True
+ _inner()
+ finally:
+ btn.disabled = False
+
+ run_btn.on_click(_run_on_dataflow)
+
+ def _show_options(btn):
+ with options_output_area:
+ options_output_area.clear_output()
+ options = self.to_options()
+ options_name = 'options_{}'.format(self.output_name)
+ create_var_in_main(options_name, options)
+ _LOGGER.info(
+ 'The pipeline options configured is: %s.',
+ pformat_dict(options.display_data()))
+
+ show_options_btn.on_click(_show_options)
+ grid = GridBox([run_btn, show_options_btn],
+ layout=Layout(grid_template_columns='repeat(2, 200px)'))
+ display(grid)
+
+ # Implicitly initializes the options variable before 1st time showing
+ # options.
+ options_name_inited, _ = create_var_in_main('options_{}'.format(
+ self.output_name), self.to_options())
+ if not self.notice_shown:
+ _LOGGER.info(
+ 'The pipeline options can be configured through variable %s. You '
+ 'may also add additional options or sink transforms such as write '
+ 'to BigQuery in other notebook cells. Come back to click "Run on '
+ 'Dataflow" button once you complete additional configurations. '
+ 'Optionally, you can chain more beam_sql magics with DataflowRunner '
+ 'and click "Run on Dataflow" in their outputs.',
+ options_name_inited)
+ self.notice_shown = True
+
+ display(options_output_area)
+ display(run_output_area)
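
The arg_builder fan-out described in OptionsEntry can be sketched standalone
(a simplification of what to_options does; the bucket value is hypothetical):

    from apache_beam.options.pipeline_options import GoogleCloudOptions
    from apache_beam.options.pipeline_options import PipelineOptions

    options = PipelineOptions()
    bucket = 'gs://my-bucket'  # what a user would type into the GCS widget
    arg_builder = {
        'staging_location': lambda x: x + '/staging',
        'temp_location': lambda x: x + '/temp',
    }
    view = options.view_as(GoogleCloudOptions)
    for arg, builder in arg_builder.items():
        # Each derived argument is built from the single widget input.
        setattr(view, arg, builder(bucket))
    assert view.staging_location == 'gs://my-bucket/staging'
    assert view.temp_location == 'gs://my-bucket/temp'
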
diff --git a/sdks/python/apache_beam/runners/interactive/sql/utils_test.py b/sdks/python/apache_beam/runners/interactive/sql/utils_test.py
index 01a54c3..16d03f5 100644
--- a/sdks/python/apache_beam/runners/interactive/sql/utils_test.py
+++ b/sdks/python/apache_beam/runners/interactive/sql/utils_test.py
@@ -23,9 +23,15 @@ import unittest
from typing import NamedTuple
from unittest.mock import patch
+import pytest
+
import apache_beam as beam
+from apache_beam.options.pipeline_options import GoogleCloudOptions
+from apache_beam.options.pipeline_options import SetupOptions
+from apache_beam.runners.interactive import interactive_environment as ie
+from apache_beam.runners.interactive.sql.utils import DataflowOptionsForm
from apache_beam.runners.interactive.sql.utils import find_pcolls
-from apache_beam.runners.interactive.sql.utils import is_namedtuple
+from apache_beam.runners.interactive.sql.utils import pformat_dict
from apache_beam.runners.interactive.sql.utils import pformat_namedtuple
from apache_beam.runners.interactive.sql.utils import register_coder_for_schema
from apache_beam.runners.interactive.sql.utils import replace_single_pcoll_token
@@ -37,19 +43,6 @@ class ANamedTuple(NamedTuple):
class UtilsTest(unittest.TestCase):
- def test_is_namedtuple(self):
- class AType:
- pass
-
- a_type = AType
- a_tuple = type((1, 2, 3))
-
- a_namedtuple = ANamedTuple
-
- self.assertTrue(is_namedtuple(a_namedtuple))
- self.assertFalse(is_namedtuple(a_type))
- self.assertFalse(is_namedtuple(a_tuple))
-
def test_register_coder_for_schema(self):
self.assertNotIsInstance(
beam.coders.registry.get_coder(ANamedTuple), beam.coders.RowCoder)
@@ -80,6 +73,35 @@ class UtilsTest(unittest.TestCase):
self.assertEqual(
'ANamedTuple(a: int, b: str)', pformat_namedtuple(ANamedTuple))
+ def test_pformat_dict(self):
+ self.assertEqual('{\na: 1,\nb: 2\n}', pformat_dict({'a': 1, 'b': '2'}))
+
+
+@unittest.skipIf(
+ not ie.current_env().is_interactive_ready,
+ '[interactive] dependency is not installed.')
+@pytest.mark.skipif(
+ not ie.current_env().is_interactive_ready,
+ reason='[interactive] dependency is not installed.')
+class OptionsFormTest(unittest.TestCase):
+ def test_dataflow_options_form(self):
+ p = beam.Pipeline()
+ pcoll = p | beam.Create([1, 2, 3])
+ with patch('google.auth') as ga:
+ ga.default = lambda: ['', 'default_project_id']
+ df_form = DataflowOptionsForm('pcoll', pcoll)
+ df_form.display_for_input()
+ df_form.entries[2].input.value = 'gs://test-bucket'
+ df_form.entries[3].input.value = 'a-pkg'
+ options = df_form.to_options()
+ cloud_options = options.view_as(GoogleCloudOptions)
+ self.assertEqual(cloud_options.project, 'default_project_id')
+ self.assertEqual(cloud_options.region, 'us-central1')
+ self.assertEqual(
+ cloud_options.staging_location, 'gs://test-bucket/staging')
+ self.assertEqual(cloud_options.temp_location, 'gs://test-bucket/temp')
+ self.assertIsNotNone(options.view_as(SetupOptions).requirements_file)
+
if __name__ == '__main__':
unittest.main()
diff --git a/sdks/python/apache_beam/runners/interactive/utils.py b/sdks/python/apache_beam/runners/interactive/utils.py
index 49b87ba..957b65c 100644
--- a/sdks/python/apache_beam/runners/interactive/utils.py
+++ b/sdks/python/apache_beam/runners/interactive/utils.py
@@ -20,9 +20,12 @@
import functools
import hashlib
+import importlib
import json
import logging
+from typing import Any
from typing import Dict
+from typing import Tuple
import pandas as pd
@@ -405,3 +408,22 @@ def unbounded_sources(pipeline):
v = CheckUnboundednessVisitor()
pipeline.visit(v)
return v.unbounded_sources
+
+
+def create_var_in_main(name: str,
+ value: Any,
+ watch: bool = True) -> Tuple[str, Any]:
+ """Declares a variable in the main module.
+
+ Args:
+ name: the variable name in the main module.
+ value: the value of the variable.
+ watch: whether to watch it in the interactive environment.
+ Returns:
+ A 2-entry tuple of the variable name and value.
+ """
+ setattr(importlib.import_module('__main__'), name, value)
+ if watch:
+ from apache_beam.runners.interactive import interactive_environment as ie
+ ie.current_env().watch({name: value})
+ return name, value
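
A hedged sketch of the helper (runnable outside a notebook by disabling the
watch):

    import importlib

    from apache_beam.runners.interactive.utils import create_var_in_main

    name, value = create_var_in_main('sql_output', 42, watch=False)
    main_session = importlib.import_module('__main__')
    assert getattr(main_session, name) == 42  # now visible as a __main__ var
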
diff --git a/sdks/python/apache_beam/runners/interactive/utils_test.py b/sdks/python/apache_beam/runners/interactive/utils_test.py
index 784081e..0915ff2 100644
--- a/sdks/python/apache_beam/runners/interactive/utils_test.py
+++ b/sdks/python/apache_beam/runners/interactive/utils_test.py
@@ -15,6 +15,7 @@
# limitations under the License.
#
+import importlib
import json
import logging
import tempfile
@@ -318,6 +319,13 @@ class GeneralUtilTest(unittest.TestCase):
})
self.assertEqual('pcoll_test_find_pcoll_name', utils.find_pcoll_name(pcoll))
+ def test_create_var_in_main(self):
+ name = 'test_create_var_in_main'
+ value = Record(0, 0, 0)
+ _ = utils.create_var_in_main(name, value)
+ main_session = importlib.import_module('__main__')
+ self.assertIs(getattr(main_session, name, None), value)
+
if __name__ == '__main__':
unittest.main()
diff --git a/sdks/python/setup.py b/sdks/python/setup.py
index 514f4e7..d16618d 100644
--- a/sdks/python/setup.py
+++ b/sdks/python/setup.py
@@ -206,6 +206,7 @@ INTERACTIVE_BEAM = [
'facets-overview>=1.0.0,<2',
'ipython>=7,<8',
'ipykernel>=5.2.0,<6',
+ 'ipywidgets>=7.6.5,<8',
# Skip version 6.1.13 due to
# https://github.com/jupyter/jupyter_client/issues/637
'jupyter-client>=6.1.11,<6.1.13',