You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by GitBox <gi...@apache.org> on 2018/12/19 21:41:55 UTC

[GitHub] kaxil closed pull request #4314: [AIRFLOW-3398] Google Cloud Spanner instance database query operator

kaxil closed pull request #4314: [AIRFLOW-3398] Google Cloud Spanner instance database query operator
URL: https://github.com/apache/incubator-airflow/pull/4314
 
 
   

This PR was merged from a forked repository.
Because GitHub hides the original diff of a fork's pull request once it
is merged, the diff is reproduced below for the sake of provenance:

diff --git a/airflow/contrib/example_dags/example_gcp_spanner.py b/airflow/contrib/example_dags/example_gcp_spanner.py
index dd8b8c52b9..cec3dcb855 100644
--- a/airflow/contrib/example_dags/example_gcp_spanner.py
+++ b/airflow/contrib/example_dags/example_gcp_spanner.py
@@ -18,18 +18,18 @@
 # under the License.
 
 """
-Example Airflow DAG that creates, updates and deletes a Cloud Spanner instance.
+Example Airflow DAG that creates, updates, queries and deletes a Cloud Spanner instance.
 
 This DAG relies on the following environment variables
-* PROJECT_ID - Google Cloud Platform project for the Cloud Spanner instance.
-* INSTANCE_ID - Cloud Spanner instance ID.
-* CONFIG_NAME - The name of the instance's configuration. Values are of the form
+* SPANNER_PROJECT_ID - Google Cloud Platform project for the Cloud Spanner instance.
+* SPANNER_INSTANCE_ID - Cloud Spanner instance ID.
+* SPANNER_CONFIG_NAME - The name of the instance's configuration. Values are of the form
     projects/<project>/instanceConfigs/<configuration>.
     See also:
         https://cloud.google.com/spanner/docs/reference/rest/v1/projects.instanceConfigs#InstanceConfig
         https://cloud.google.com/spanner/docs/reference/rest/v1/projects.instanceConfigs/list#google.spanner.admin.instance.v1.InstanceAdmin.ListInstanceConfigs
-* NODE_COUNT - Number of nodes allocated to the instance.
-* DISPLAY_NAME - The descriptive name for this instance as it appears in UIs.
+* SPANNER_NODE_COUNT - Number of nodes allocated to the instance.
+* SPANNER_DISPLAY_NAME - The descriptive name for this instance as it appears in UIs.
     Must be unique per project and between 4 and 30 characters in length.
 """
 
@@ -38,15 +38,17 @@
 import airflow
 from airflow import models
 from airflow.contrib.operators.gcp_spanner_operator import \
-    CloudSpannerInstanceDeployOperator, CloudSpannerInstanceDeleteOperator
+    CloudSpannerInstanceDeployOperator, CloudSpannerInstanceDatabaseQueryOperator, \
+    CloudSpannerInstanceDeleteOperator
 
 # [START howto_operator_spanner_arguments]
-PROJECT_ID = os.environ.get('PROJECT_ID', 'example-project')
-INSTANCE_ID = os.environ.get('INSTANCE_ID', 'testinstance')
-CONFIG_NAME = os.environ.get('CONFIG_NAME',
+PROJECT_ID = os.environ.get('SPANNER_PROJECT_ID', 'example-project')
+INSTANCE_ID = os.environ.get('SPANNER_INSTANCE_ID', 'testinstance')
+DB_ID = os.environ.get('SPANNER_DB_ID', 'db1')
+CONFIG_NAME = os.environ.get('SPANNER_CONFIG_NAME',
                              'projects/example-project/instanceConfigs/eur3')
-NODE_COUNT = os.environ.get('NODE_COUNT', '1')
-DISPLAY_NAME = os.environ.get('DISPLAY_NAME', 'Test Instance')
+NODE_COUNT = os.environ.get('SPANNER_NODE_COUNT', '1')
+DISPLAY_NAME = os.environ.get('SPANNER_DISPLAY_NAME', 'Test Instance')
 # [END howto_operator_spanner_arguments]
 
 default_args = {
@@ -80,6 +82,24 @@
         task_id='spanner_instance_update_task'
     )
 
+    # [START howto_operator_spanner_query]
+    spanner_instance_query = CloudSpannerInstanceDatabaseQueryOperator(
+        project_id=PROJECT_ID,
+        instance_id=INSTANCE_ID,
+        database_id='db1',
+        query="DELETE FROM my_table2 WHERE true",
+        task_id='spanner_instance_query'
+    )
+    # [END howto_operator_spanner_query]
+
+    spanner_instance_query2 = CloudSpannerInstanceDatabaseQueryOperator(
+        project_id=PROJECT_ID,
+        instance_id=INSTANCE_ID,
+        database_id='db1',
+        query="example_gcp_spanner.sql",
+        task_id='spanner_instance_query2'
+    )
+
     # [START howto_operator_spanner_delete]
     spanner_instance_delete_task = CloudSpannerInstanceDeleteOperator(
         project_id=PROJECT_ID,
@@ -89,4 +109,5 @@
     # [END howto_operator_spanner_delete]
 
     spanner_instance_create_task >> spanner_instance_update_task \
+        >> spanner_instance_query >> spanner_instance_query2 \
         >> spanner_instance_delete_task
diff --git a/airflow/contrib/example_dags/example_gcp_spanner.sql b/airflow/contrib/example_dags/example_gcp_spanner.sql
new file mode 100644
index 0000000000..5d5f238022
--- /dev/null
+++ b/airflow/contrib/example_dags/example_gcp_spanner.sql
@@ -0,0 +1,3 @@
+INSERT my_table2 (id, name) VALUES (7, 'Seven');
+INSERT my_table2 (id, name)
+    VALUES (8, 'Eight');
diff --git a/airflow/contrib/hooks/gcp_spanner_hook.py b/airflow/contrib/hooks/gcp_spanner_hook.py
index fc73562e8b..96e8bcb71c 100644
--- a/airflow/contrib/hooks/gcp_spanner_hook.py
+++ b/airflow/contrib/hooks/gcp_spanner_hook.py
@@ -16,12 +16,12 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-from google.longrunning.operations_grpc_pb2 import Operation  # noqa: F401
-from typing import Optional, Callable  # noqa: F401
-
 from google.api_core.exceptions import GoogleAPICallError
 from google.cloud.spanner_v1.client import Client
+from google.cloud.spanner_v1.database import Database
 from google.cloud.spanner_v1.instance import Instance  # noqa: F401
+from google.longrunning.operations_grpc_pb2 import Operation  # noqa: F401
+from typing import Optional, Callable  # noqa: F401
 
 from airflow.contrib.hooks.gcp_api_base_hook import GoogleCloudBaseHook
 
@@ -181,3 +181,28 @@ def delete_instance(self, project_id, instance_id):
         except GoogleAPICallError as e:
             self.log.error('An error occurred: %s. Aborting.', e.message)
             raise e
+
+    def execute_dml(self, project_id, instance_id, database_id, queries):
+        # type: (str, str, str, list) -> None
+        """
+        Executes an arbitrary DML query (INSERT, UPDATE, DELETE).
+
+        :param project_id: The ID of the project which owns the instances, tables and data.
+        :type project_id: str
+        :param instance_id: The ID of the instance.
+        :type instance_id: str
+        :param database_id: The ID of the database.
+        :type database_id: str
+        :param queries: The list of DML statements to be executed in one transaction.
+        :type queries: list[str]
+        """
+        client = self.get_client(project_id)
+        instance = client.instance(instance_id)
+        database = Database(database_id, instance)
+        database.run_in_transaction(lambda transaction:
+                                    self._execute_sql_in_transaction(transaction, queries))
+
+    @staticmethod
+    def _execute_sql_in_transaction(transaction, queries):
+        for sql in queries:
+            transaction.execute_update(sql)
diff --git a/airflow/contrib/operators/gcp_spanner_operator.py b/airflow/contrib/operators/gcp_spanner_operator.py
index 7b329a3849..b803fcc30a 100644
--- a/airflow/contrib/operators/gcp_spanner_operator.py
+++ b/airflow/contrib/operators/gcp_spanner_operator.py
@@ -16,6 +16,8 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import six
+
 from airflow import AirflowException
 from airflow.contrib.hooks.gcp_spanner_hook import CloudSpannerHook
 from airflow.models import BaseOperator
@@ -130,3 +132,68 @@ def execute(self, context):
             self.log.info("Instance '%s' does not exist in project '%s'. "
                           "Aborting delete.", self.instance_id, self.project_id)
             return True
+
+
+class CloudSpannerInstanceDatabaseQueryOperator(BaseOperator):
+    """
+    Executes an arbitrary DML query (INSERT, UPDATE, DELETE).
+
+    :param project_id: The ID of the project which owns the instances, tables and data.
+    :type project_id: str
+    :param instance_id: The ID of the instance.
+    :type instance_id: str
+    :param database_id: The ID of the database.
+    :type database_id: str
+    :param query: The query or list of queries to be executed. Can be a path to a SQL file.
+    :type query: str or list
+    :param gcp_conn_id: The connection ID used to connect to Google Cloud Platform.
+    :type gcp_conn_id: str
+    """
+    # [START gcp_spanner_query_template_fields]
+    template_fields = ('project_id', 'instance_id', 'database_id', 'query', 'gcp_conn_id')
+    template_ext = ('.sql',)
+    # [END gcp_spanner_query_template_fields]
+
+    @apply_defaults
+    def __init__(self,
+                 project_id,
+                 instance_id,
+                 database_id,
+                 query,
+                 gcp_conn_id='google_cloud_default',
+                 *args, **kwargs):
+        self.instance_id = instance_id
+        self.project_id = project_id
+        self.database_id = database_id
+        self.query = query
+        self.gcp_conn_id = gcp_conn_id
+        self._validate_inputs()
+        self._hook = CloudSpannerHook(gcp_conn_id=gcp_conn_id)
+        super(CloudSpannerInstanceDatabaseQueryOperator, self).__init__(*args, **kwargs)
+
+    def _validate_inputs(self):
+        if not self.project_id:
+            raise AirflowException("The required parameter 'project_id' is empty")
+        if not self.instance_id:
+            raise AirflowException("The required parameter 'instance_id' is empty")
+        if not self.database_id:
+            raise AirflowException("The required parameter 'database_id' is empty")
+        if not self.query:
+            raise AirflowException("The required parameter 'query' is empty")
+
+    def execute(self, context):
+        queries = self.query
+        if isinstance(self.query, six.string_types):
+            queries = [x.strip() for x in self.query.split(';')]
+            self.sanitize_queries(queries)
+        self.log.info("Executing DML query(-ies) on "
+                      "projects/%s/instances/%s/databases/%s",
+                      self.project_id, self.instance_id, self.database_id)
+        self.log.info(queries)
+        self._hook.execute_dml(self.project_id, self.instance_id,
+                               self.database_id, queries)
+
+    @staticmethod
+    def sanitize_queries(queries):
+        if len(queries) and queries[-1] == '':
+            del queries[-1]
diff --git a/docs/howto/operator.rst b/docs/howto/operator.rst
index 095553b3ac..221913dec0 100644
--- a/docs/howto/operator.rst
+++ b/docs/howto/operator.rst
@@ -545,6 +545,48 @@ See `Google Cloud Functions API documentation
 Google Cloud Sql Operators
 --------------------------
 
+CloudSpannerInstanceDatabaseQueryOperator
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Executes an arbitrary DML query (INSERT, UPDATE, DELETE).
+
+For parameter definition take a look at
+:class:`~airflow.contrib.operators.gcp_spanner_operator.CloudSpannerInstanceDatabaseQueryOperator`.
+
+Arguments
+"""""""""
+
+Some arguments in the example DAG are taken from environment variables:
+
+.. literalinclude:: ../../airflow/contrib/example_dags/example_gcp_spanner.py
+    :language: python
+    :start-after: [START howto_operator_spanner_arguments]
+    :end-before: [END howto_operator_spanner_arguments]
+
+Using the operator
+""""""""""""""""""
+
+.. literalinclude:: ../../airflow/contrib/example_dags/example_gcp_spanner.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_spanner_query]
+    :end-before: [END howto_operator_spanner_query]
+
+Templating
+""""""""""
+
+.. literalinclude:: ../../airflow/contrib/operators/gcp_spanner_operator.py
+  :language: python
+  :dedent: 4
+  :start-after: [START gcp_spanner_query_template_fields]
+  :end-before: [END gcp_spanner_query_template_fields]
+
+More information
+""""""""""""""""
+
+See Google Cloud Spanner API documentation for `the DML syntax
+<https://cloud.google.com/spanner/docs/dml-syntax>`_.
+
 CloudSpannerInstanceDeployOperator
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
diff --git a/docs/integration.rst b/docs/integration.rst
index 0afe555309..e74d8f662b 100644
--- a/docs/integration.rst
+++ b/docs/integration.rst
@@ -642,10 +642,19 @@ Cloud Spanner
 Cloud Spanner Operators
 """""""""""""""""""""""
 
+- :ref:`CloudSpannerInstanceDatabaseQueryOperator` : executes an arbitrary DML query
+  (INSERT, UPDATE, DELETE).
 - :ref:`CloudSpannerInstanceDeployOperator` : creates a new Cloud Spanner instance or,
   if an instance with the same name exists, updates it.
 - :ref:`CloudSpannerInstanceDeleteOperator` : deletes a Cloud Spanner instance.
 
+.. _CloudSpannerInstanceDatabaseQueryOperator:
+
+CloudSpannerInstanceDatabaseQueryOperator
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+.. autoclass:: airflow.contrib.operators.gcp_spanner_operator.CloudSpannerInstanceDatabaseQueryOperator
+
 .. _CloudSpannerInstanceDeployOperator:
 
 CloudSpannerInstanceDeployOperator
diff --git a/tests/contrib/operators/test_gcp_spanner_operator.py b/tests/contrib/operators/test_gcp_spanner_operator.py
index ff2b82fd16..38ae985f26 100644
--- a/tests/contrib/operators/test_gcp_spanner_operator.py
+++ b/tests/contrib/operators/test_gcp_spanner_operator.py
@@ -22,7 +22,8 @@
 
 from airflow import AirflowException
 from airflow.contrib.operators.gcp_spanner_operator import \
-    CloudSpannerInstanceDeployOperator, CloudSpannerInstanceDeleteOperator
+    CloudSpannerInstanceDeployOperator, CloudSpannerInstanceDeleteOperator, \
+    CloudSpannerInstanceDatabaseQueryOperator
 from tests.contrib.operators.test_gcp_base import BaseGcpIntegrationTestCase, \
     SKIP_TEST_WARNING, GCP_SPANNER_KEY
 
@@ -37,10 +38,15 @@
 
 PROJECT_ID = 'project-id'
 INSTANCE_ID = 'instance-id'
-DB_NAME = 'db1'
+DB_ID = 'db1'
 CONFIG_NAME = 'projects/project-id/instanceConfigs/eur3'
 NODE_COUNT = '1'
 DISPLAY_NAME = 'Test Instance'
+INSERT_QUERY = "INSERT my_table1 (id, name) VALUES (1, 'One')"
+INSERT_QUERY_2 = "INSERT my_table2 (id, name) VALUES (1, 'One')"
+CREATE_QUERY = "CREATE TABLE my_table1 (id INT64, name STRING(MAX)) PRIMARY KEY (id)"
+CREATE_QUERY_2 = "CREATE TABLE my_table2 (id INT64, name STRING(MAX)) PRIMARY KEY (id)"
+QUERY_TYPE = "DML"
 
 
 class CloudSpannerTest(unittest.TestCase):
@@ -164,6 +170,76 @@ def test_instance_delete_ex_if_param_missing(self, project_id, instance_id, exp_
         self.assertIn("The required parameter '{}' is empty".format(exp_msg), str(err))
         mock_hook.assert_not_called()
 
+    @mock.patch("airflow.contrib.operators.gcp_spanner_operator.CloudSpannerHook")
+    def test_instance_query(self, mock_hook):
+        mock_hook.return_value.execute_sql.return_value = None
+        op = CloudSpannerInstanceDatabaseQueryOperator(
+            project_id=PROJECT_ID,
+            instance_id=INSTANCE_ID,
+            database_id=DB_ID,
+            query=INSERT_QUERY,
+            task_id="id"
+        )
+        result = op.execute(None)
+        mock_hook.assert_called_once_with(gcp_conn_id="google_cloud_default")
+        mock_hook.return_value.execute_dml.assert_called_once_with(
+            PROJECT_ID, INSTANCE_ID, DB_ID, [INSERT_QUERY]
+        )
+        self.assertIsNone(result)
+
+    @parameterized.expand([
+        ("", INSTANCE_ID, DB_ID, INSERT_QUERY, "project_id"),
+        (PROJECT_ID, "", DB_ID, INSERT_QUERY, "instance_id"),
+        (PROJECT_ID, INSTANCE_ID, "", INSERT_QUERY, "database_id"),
+        (PROJECT_ID, INSTANCE_ID, DB_ID, "", "query"),
+    ])
+    @mock.patch("airflow.contrib.operators.gcp_spanner_operator.CloudSpannerHook")
+    def test_instance_query_ex_if_param_missing(self, project_id, instance_id,
+                                                database_id, query, exp_msg, mock_hook):
+        with self.assertRaises(AirflowException) as cm:
+            CloudSpannerInstanceDatabaseQueryOperator(
+                project_id=project_id,
+                instance_id=instance_id,
+                database_id=database_id,
+                query=query,
+                task_id="id"
+            )
+        err = cm.exception
+        self.assertIn("The required parameter '{}' is empty".format(exp_msg), str(err))
+        mock_hook.assert_not_called()
+
+    @mock.patch("airflow.contrib.operators.gcp_spanner_operator.CloudSpannerHook")
+    def test_instance_query_dml(self, mock_hook):
+        mock_hook.return_value.execute_dml.return_value = None
+        op = CloudSpannerInstanceDatabaseQueryOperator(
+            project_id=PROJECT_ID,
+            instance_id=INSTANCE_ID,
+            database_id=DB_ID,
+            query=INSERT_QUERY,
+            task_id="id"
+        )
+        op.execute(None)
+        mock_hook.assert_called_once_with(gcp_conn_id="google_cloud_default")
+        mock_hook.return_value.execute_dml.assert_called_once_with(
+            PROJECT_ID, INSTANCE_ID, DB_ID, [INSERT_QUERY]
+        )
+
+    @mock.patch("airflow.contrib.operators.gcp_spanner_operator.CloudSpannerHook")
+    def test_instance_query_dml_list(self, mock_hook):
+        mock_hook.return_value.execute_dml.return_value = None
+        op = CloudSpannerInstanceDatabaseQueryOperator(
+            project_id=PROJECT_ID,
+            instance_id=INSTANCE_ID,
+            database_id=DB_ID,
+            query=[INSERT_QUERY, INSERT_QUERY_2],
+            task_id="id"
+        )
+        op.execute(None)
+        mock_hook.assert_called_once_with(gcp_conn_id="google_cloud_default")
+        mock_hook.return_value.execute_dml.assert_called_once_with(
+            PROJECT_ID, INSTANCE_ID, DB_ID, [INSERT_QUERY, INSERT_QUERY_2]
+        )
+
 
 @unittest.skipIf(
     BaseGcpIntegrationTestCase.skip_check(GCP_SPANNER_KEY), SKIP_TEST_WARNING)


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services