Posted to commits@airflow.apache.org by GitBox <gi...@apache.org> on 2020/05/21 05:11:00 UTC

[GitHub] [airflow] samuelkhtu opened a new pull request #8942: Add SQL Branch Operator #8525

samuelkhtu opened a new pull request #8942:
URL: https://github.com/apache/airflow/pull/8942


   SQL Branch Operator allows users to execute a SQL query in any supported backend to decide which
   branch of the DAG to follow.
   
   The SQL branch operator expects the SQL query to return one of the following:
   - Boolean: True/False 
   - Integer: 0/1 
   - String: true/y/yes/1/on/false/n/no/0/off
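The return-value rules above can be sketched in Python as follows. This is a minimal illustration under stated assumptions, not the operator's actual implementation: `to_bool` and `choose_branch` are hypothetical helpers, and `records` is assumed to be the list of row tuples returned by a DB-API style `get_records()` call, of which only the first column of the first row is read. Following the operator's docstring, any nonzero integer counts as true. The sketch does its own string matching instead of using `distutils.util.strtobool` (imported in the diffs below), because `distutils` was removed in Python 3.12.

```python
from typing import Any, Sequence

# Hypothetical helper names; the accepted strings mirror the list above.
TRUE_STRINGS = {"true", "y", "yes", "1", "on"}
FALSE_STRINGS = {"false", "n", "no", "0", "off"}


def to_bool(value: Any) -> bool:
    """Interpret a single query result as a branch decision."""
    # Check bool before int: bool is a subclass of int in Python.
    if isinstance(value, bool):
        return value
    if isinstance(value, int):
        return value != 0
    if isinstance(value, str):
        lowered = value.strip().lower()
        if lowered in TRUE_STRINGS:
            return True
        if lowered in FALSE_STRINGS:
            return False
    raise ValueError("Unexpected query result: %r" % (value,))


def choose_branch(records: Sequence[Sequence[Any]], if_true, if_false):
    """Return the task id(s) to follow, based on the first cell of the first row."""
    if not records:
        raise ValueError("Query returned no rows")
    return if_true if to_bool(records[0][0]) else if_false


print(choose_branch([(1,)], ["task_a"], ["task_b"]))  # ['task_a']
print(choose_branch([("off",)], "task_a", "task_b"))  # task_b
```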
   
   ---
   Make sure to mark the boxes below before creating PR: [x]
   
   - [x] Description above provides context of the change
   - [x] Unit tests coverage for changes (not needed for documentation changes)
   - [x] Target Github ISSUE in description if exists
   - [x] Commits follow "[How to write a good git commit message](http://chris.beams.io/posts/git-commit/)"
   - [x] Relevant documentation is updated including usage instructions.
   - [x] I will engage committers as explained in [Contribution Workflow Example](https://github.com/apache/airflow/blob/master/CONTRIBUTING.rst#contribution-workflow-example).
   
   ---
   In case of fundamental code change, Airflow Improvement Proposal ([AIP](https://cwiki.apache.org/confluence/display/AIRFLOW/Airflow+Improvements+Proposals)) is needed.
   In case of a new dependency, check compliance with the [ASF 3rd Party License Policy](https://www.apache.org/legal/resolved.html#category-x).
   In case of backwards incompatible changes please leave a note in [UPDATING.md](https://github.com/apache/airflow/blob/master/UPDATING.md).
   Read the [Pull Request Guidelines](https://github.com/apache/airflow/blob/master/CONTRIBUTING.rst#pull-request-guidelines) for more information.
   


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [airflow] boring-cyborg[bot] commented on pull request #8942: Add SQL Branch Operator #8525

Posted by GitBox <gi...@apache.org>.
boring-cyborg[bot] commented on pull request #8942:
URL: https://github.com/apache/airflow/pull/8942#issuecomment-631884466


   Congratulations on your first Pull Request and welcome to the Apache Airflow community! If you have any issues or are unsure about anything please check our Contribution Guide (https://github.com/apache/airflow/blob/master/CONTRIBUTING.rst)
   Here are some useful points:
   - Pay attention to the quality of your code (flake8, pylint and type annotations). Our [pre-commits](https://github.com/apache/airflow/blob/master/STATIC_CODE_CHECKS.rst#prerequisites-for-pre-commit-hooks) will help you with that.
   - In case of a new feature add useful documentation (in docstrings or in `docs/` directory). Adding a new operator? Check this short [guide](https://github.com/apache/airflow/blob/master/docs/howto/custom-operator.rst). Consider adding an example DAG that shows how users should use it.
   - Consider using the [Breeze environment](https://github.com/apache/airflow/blob/master/BREEZE.rst) for local testing; it's a heavy Docker image, but it ships with a working Airflow and a lot of integrations.
   - Be patient and persistent. It might take some time to get a review or get the final approval from Committers.
   - Be sure to read the [Airflow Coding style](https://github.com/apache/airflow/blob/master/CONTRIBUTING.rst#coding-style-and-best-practices).
   Apache Airflow is a community-driven project and together we are making it better 🚀.
   In case of doubts contact the developers at:
   Mailing List: dev@airflow.apache.org
   Slack: https://apache-airflow-slack.herokuapp.com/
   





[GitHub] [airflow] potiuk merged pull request #8942: #8525 Add SQL Branch Operator

potiuk merged pull request #8942:
URL: https://github.com/apache/airflow/pull/8942


   





[GitHub] [airflow] samuelkhtu commented on pull request #8942: #8525 Add SQL Branch Operator

samuelkhtu commented on pull request #8942:
URL: https://github.com/apache/airflow/pull/8942#issuecomment-634753579


   > @samuelkhtu we can do it in a separate PR. This will allow us to build a better git history.
   
   Thanks @mik-laj and @eladkal, can someone help approve this PR? I am ready to move on to the next issue. Thanks.
   
   





[GitHub] [airflow] samuelkhtu commented on a change in pull request #8942: #8525 Add SQL Branch Operator

samuelkhtu commented on a change in pull request #8942:
URL: https://github.com/apache/airflow/pull/8942#discussion_r429663340



##########
File path: tests/operators/test_sql_branch_operator.py
##########
@@ -0,0 +1,482 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# import os
+# import unittest
+# from unittest import mock

Review comment:
       Roger that! Removed







[GitHub] [airflow] samuelkhtu commented on a change in pull request #8942: #8525 Add SQL Branch Operator

samuelkhtu commented on a change in pull request #8942:
URL: https://github.com/apache/airflow/pull/8942#discussion_r432884263



##########
File path: airflow/operators/sql_branch_operator.py
##########
@@ -0,0 +1,174 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from distutils.util import strtobool
+from typing import Dict, Iterable, List, Mapping, Optional, Union
+
+from airflow.exceptions import AirflowException
+from airflow.hooks.base_hook import BaseHook
+from airflow.models import BaseOperator, SkipMixin
+from airflow.utils.decorators import apply_defaults
+
+
+class BranchSqlOperator(BaseOperator, SkipMixin):
+    """
+    Executes sql code in a specific database
+
+    :param sql: the sql code to be executed. (templated)
+    :type sql: Can receive a str representing a sql statement or reference to a template file.
+               Template references are recognized by a str ending in '.sql'.
+               The SQL query is expected to return a Boolean (True/False), an integer
+               (0 = False, otherwise True) or a string (true/y/yes/1/on/false/n/no/0/off).
+    :param follow_task_ids_if_true: task id or task ids to follow if the query returns true
+    :type follow_task_ids_if_true: str or list
+    :param follow_task_ids_if_false: task id or task ids to follow if the query returns false
+    :type follow_task_ids_if_false: str or list
+    :param conn_id: reference to a specific database
+    :type conn_id: str
+    :param database: name of the database, which overrides the one defined in the connection
+    :param parameters: (optional) the parameters to render the SQL query with.
+    :type parameters: mapping or iterable
+    """
+
+    template_fields = ("sql",)
+    template_ext = (".sql",)
+    ui_color = "#a22034"
+    ui_fgcolor = "#F7F7F7"
+
+    @apply_defaults
+    def __init__(
+        self,
+        sql: str,
+        follow_task_ids_if_true: List[str],
+        follow_task_ids_if_false: List[str],
+        conn_id: str = "default_conn_id",
+        database: Optional[str] = None,
+        parameters: Optional[Union[Mapping, Iterable]] = None,
+        *args,
+        **kwargs,
+    ) -> None:
+        super().__init__(*args, **kwargs)
+        self.conn_id = conn_id
+        self.sql = sql
+        self.parameters = parameters
+        self.follow_task_ids_if_true = follow_task_ids_if_true
+        self.follow_task_ids_if_false = follow_task_ids_if_false
+        self.database = database
+        self._hook = None
+
+    def _get_hook(self):
+        self.log.debug("Get connection for %s", self.conn_id)
+        conn = BaseHook.get_connection(self.conn_id)
+
+        allowed_conn_type = {
+            "google_cloud_platform",
+            "jdbc",
+            "mssql",
+            "mysql",
+            "odbc",
+            "oracle",
+            "postgres",
+            "presto",
+            "sqlite",
+            "vertica",
+        }
+        if conn.conn_type not in allowed_conn_type:
+            raise AirflowException(
+                "The connection type is not supported by BranchSqlOperator. "
+                + "Supported connection types: {}".format(list(allowed_conn_type))
+            )
+
+        if not self._hook:
+            self._hook = conn.get_hook()
+            if self.database:
+                self._hook.schema = self.database
+
+        return self._hook
+
+    def execute(self, context: Dict):
+        # get supported hook
+        self._hook = self._get_hook()
+

Review comment:
       Hi @potiuk, if I add the following code to check whether the hook is a DbApiHook instance, my unit test's mock breaks. Any suggestion? I searched the Airflow code base but couldn't find an example that resolves this issue. Thanks!
   
   ```Python
   from airflow.hooks.dbapi_hook import DbApiHook  # module-level import
   
   if not isinstance(self._hook, DbApiHook):
       raise AirflowException(
           "Unexpected type returned '%s' expected DbApiHook" % type(self._hook))
   
   # In the unit test, this check fails with:
   # airflow.exceptions.AirflowException: Unexpected type returned '<class 'unittest.mock.MagicMock'>' expected DbApiHook
   ```
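One common way around this is to give the mock a spec: the `unittest.mock` documentation states that a mock created with `spec=SomeClass` reports that class as its `__class__`, so it passes `isinstance()` checks. The sketch below assumes a local stand-in `DbApiHook` class so it runs without Airflow installed, and `check_hook` is a hypothetical helper mirroring the guard in the comment above. In the real test, the spec'd mock would take the place of whatever the hook lookup returns (for example by patching `BranchSqlOperator._get_hook` with `mock.patch.object`); that wiring is an assumption, not the PR's actual test code.

```python
from unittest import mock


class DbApiHook:
    """Hypothetical stand-in for Airflow's DbApiHook so the sketch runs without Airflow."""

    def get_records(self, sql, parameters=None):
        raise NotImplementedError


def check_hook(hook):
    """Same isinstance guard as in the comment; raises TypeError in this sketch."""
    if not isinstance(hook, DbApiHook):
        raise TypeError(
            "Unexpected type returned '%s' expected DbApiHook" % type(hook))
    return hook


plain = mock.MagicMock()
specced = mock.MagicMock(spec=DbApiHook)
specced.get_records.return_value = [(True,)]

print(isinstance(plain, DbApiHook))    # False: a bare MagicMock fails the guard
print(isinstance(specced, DbApiHook))  # True: spec= makes __class__ report DbApiHook
print(check_hook(specced).get_records("SELECT 1"))  # [(True,)]
```

A spec also restricts the mock to attributes that exist on the spec class, so a typo such as `get_record` raises `AttributeError` instead of silently returning another mock.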
   
   







[GitHub] [airflow] samuelkhtu commented on pull request #8942: #8525 Add SQL Branch Operator

samuelkhtu commented on pull request #8942:
URL: https://github.com/apache/airflow/pull/8942#issuecomment-632474847


   Hi @eladkal , the PR is ready to review again. Please take a look when you have a chance. Thanks!





[GitHub] [airflow] samuelkhtu commented on pull request #8942: #8525 Add SQL Branch Operator

samuelkhtu commented on pull request #8942:
URL: https://github.com/apache/airflow/pull/8942#issuecomment-637024265


   > Nice @samuelkhtu !
   
   Thank you for everyone's help! @potiuk @mik-laj @eladkal 





[GitHub] [airflow] samuelkhtu commented on a change in pull request #8942: #8525 Add SQL Branch Operator

samuelkhtu commented on a change in pull request #8942:
URL: https://github.com/apache/airflow/pull/8942#discussion_r428794914



##########
File path: airflow/operators/sql_branch_operator.py
##########
@@ -0,0 +1,165 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from distutils.util import strtobool
+from typing import Dict, Iterable, Mapping, Optional, Union
+
+from airflow.exceptions import AirflowException
+from airflow.hooks.base_hook import BaseHook
+from airflow.models import BaseOperator, SkipMixin
+from airflow.utils.decorators import apply_defaults
+
+
+class BranchSqlOperator(BaseOperator, SkipMixin):
+    """
+    Executes sql code in a specific database
+
+    :param sql: the sql code to be executed. (templated)
+    :type sql: Can receive a str representing a sql statement or reference to a template file.
+               Template references are recognized by a str ending in '.sql'.
+               The SQL query is expected to return a Boolean (True/False), an integer
+               (0 = False, otherwise True) or a string (true/y/yes/1/on/false/n/no/0/off).
+
+    :param follow_task_ids_if_true: task id or task ids to follow if the query returns true
+    :type follow_task_ids_if_true: str or list
+
+    :param follow_task_ids_if_false: task id or task ids to follow if the query returns false
+    :type follow_task_ids_if_false: str or list
+
+    :param conn_id: reference to a specific database
+    :type conn_id: str
+
+    :param database: name of the database, which overrides the one defined in the connection
+
+    :param parameters: (optional) the parameters to render the SQL query with.
+    :type parameters: mapping or iterable
+
+
+    """
+
+    template_fields = ("sql",)
+    template_ext = (".sql",)
+    ui_color = "#a22034"
+
+    @apply_defaults
+    def __init__(
+        self,
+        sql: str,
+        follow_task_ids_if_true: str,
+        follow_task_ids_if_false: str,
+        conn_id: str = "default_conn_id",
+        database: Optional[str] = None,
+        parameters: Optional[Union[Mapping, Iterable]] = None,
+        *args,
+        **kwargs,
+    ) -> None:
+        super().__init__(*args, **kwargs)
+        self.conn_id = conn_id
+        self.sql = sql
+        self.parameters = parameters
+        self.follow_task_ids_if_true = follow_task_ids_if_true
+        self.follow_task_ids_if_false = follow_task_ids_if_false
+        self.database = database
+        self._hook = None
+
+    def _get_hook(self):
+        self.log.debug("Get connection for %s", self.conn_id)
+        conn = BaseHook.get_connection(self.conn_id)
+
+        allowed_conn_type = {
+            "google_cloud_platform",
+            "jdbc",
+            "mssql",
+            "mysql",
+            "odbc",
+            "oracle",
+            "postgres",
+            "presto",
+            "sqlite",
+            "vertica",
+        }
+        if conn.conn_type not in allowed_conn_type:
+            raise AirflowException(
+                "The connection type is not supported by BranchSqlOperator. "
+                + "Supported connection types: {}".format(list(allowed_conn_type))
+            )
+
+        if not self._hook:
+            self._hook = conn.get_hook()
+            if self.database:
+                self._hook.schema = self.database
+
+        return self._hook
+
+    def execute(self, context: Dict):
+        # get supported hook
+        self._hook = self._get_hook()
+
+        if self._hook is None:
+            raise AirflowException(
+                "Failed to establish connection to '%s'" % self.conn_id
+            )
+
+        if self.follow_task_ids_if_true is None:
+            raise AirflowException(
+                "Expected task id or task ids assigned to follow_task_ids_if_true"
+            )
+
+        if self.follow_task_ids_if_false is None:
+            raise AirflowException(
+                "Expected task id or task ids assigned to follow_task_ids_if_false"
+            )
+
+        self.log.info(
+            "Executing: %s (with parameters %s) with connection: %s",
+            self.sql,
+            self.parameters,
+            self._hook,
+        )
+        records = self._hook.get_records(self.sql, self.parameters)

Review comment:
       Yes, thank you. This makes perfect sense.

##########
File path: airflow/operators/sql_branch_operator.py
##########
@@ -0,0 +1,165 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from distutils.util import strtobool
+from typing import Dict, Iterable, Mapping, Optional, Union
+
+from airflow.exceptions import AirflowException
+from airflow.hooks.base_hook import BaseHook
+from airflow.models import BaseOperator, SkipMixin
+from airflow.utils.decorators import apply_defaults
+
+
+class BranchSqlOperator(BaseOperator, SkipMixin):
+    """
+    Executes sql code in a specific database
+
+    :param sql: the sql code to be executed. (templated)
+    :type sql: Can receive a str representing a sql statement or reference to a template file.
+               Template references are recognized by a str ending in '.sql'.
+               The SQL query is expected to return a Boolean (True/False), an integer
+               (0 = False, otherwise True) or a string (true/y/yes/1/on/false/n/no/0/off).
+
+    :param follow_task_ids_if_true: task id or task ids to follow if the query returns true
+    :type follow_task_ids_if_true: str or list
+
+    :param follow_task_ids_if_false: task id or task ids to follow if the query returns false
+    :type follow_task_ids_if_false: str or list
+
+    :param conn_id: reference to a specific database
+    :type conn_id: str
+
+    :param database: name of the database, which overrides the one defined in the connection
+
+    :param parameters: (optional) the parameters to render the SQL query with.
+    :type parameters: mapping or iterable
+
+
+    """
+
+    template_fields = ("sql",)
+    template_ext = (".sql",)
+    ui_color = "#a22034"
+
+    @apply_defaults
+    def __init__(
+        self,
+        sql: str,
+        follow_task_ids_if_true: str,

Review comment:
       Thanks, I will update the type hints.
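The later revisions of this diff type `follow_task_ids_if_true` and `follow_task_ids_if_false` as `List[str]`, while the docstring still says "str or list". One way to make the hints match the docstring is sketched below; `TaskIds` and `normalize_task_ids` are hypothetical names used only for illustration and are not part of the PR.

```python
from typing import Iterable, List, Union

# Hypothetical alias matching the docstring's "str or list".
TaskIds = Union[str, Iterable[str]]


def normalize_task_ids(task_ids: TaskIds) -> List[str]:
    """Return a list of task ids whether a single id or a collection was given."""
    # A str is itself iterable, so it must be checked first; otherwise
    # "task_a" would be split into single characters.
    if isinstance(task_ids, str):
        return [task_ids]
    return list(task_ids)


print(normalize_task_ids("task_a"))              # ['task_a']
print(normalize_task_ids(["task_a", "task_b"]))  # ['task_a', 'task_b']
```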

##########
File path: airflow/operators/sql_branch_operator.py
##########
@@ -0,0 +1,165 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from distutils.util import strtobool
+from typing import Dict, Iterable, Mapping, Optional, Union
+
+from airflow.exceptions import AirflowException
+from airflow.hooks.base_hook import BaseHook
+from airflow.models import BaseOperator, SkipMixin
+from airflow.utils.decorators import apply_defaults
+
+
+class BranchSqlOperator(BaseOperator, SkipMixin):
+    """
+    Executes sql code in a specific database
+
+    :param sql: the sql code to be executed. (templated)
+    :type sql: Can receive a str representing a sql statement or reference to a template file.
+               Template references are recognized by a str ending in '.sql'.
+               The SQL query is expected to return a Boolean (True/False), an integer
+               (0 = False, otherwise True) or a string (true/y/yes/1/on/false/n/no/0/off).
+
+    :param follow_task_ids_if_true: task id or task ids to follow if the query returns true
+    :type follow_task_ids_if_true: str or list
+
+    :param follow_task_ids_if_false: task id or task ids to follow if the query returns false
+    :type follow_task_ids_if_false: str or list
+
+    :param conn_id: reference to a specific database
+    :type conn_id: str
+
+    :param database: name of the database, which overrides the one defined in the connection
+
+    :param parameters: (optional) the parameters to render the SQL query with.
+    :type parameters: mapping or iterable
+
+
+    """
+
+    template_fields = ("sql",)
+    template_ext = (".sql",)
+    ui_color = "#a22034"
+
+    @apply_defaults
+    def __init__(
+        self,
+        sql: str,
+        follow_task_ids_if_true: str,
+        follow_task_ids_if_false: str,

Review comment:
       Absolutely







[GitHub] [airflow] samuelkhtu commented on a change in pull request #8942: #8525 Add SQL Branch Operator

samuelkhtu commented on a change in pull request #8942:
URL: https://github.com/apache/airflow/pull/8942#discussion_r432887660



##########
File path: airflow/operators/sql_branch_operator.py
##########
@@ -0,0 +1,174 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from distutils.util import strtobool
+from typing import Dict, Iterable, List, Mapping, Optional, Union
+
+from airflow.exceptions import AirflowException
+from airflow.hooks.base_hook import BaseHook
+from airflow.models import BaseOperator, SkipMixin
+from airflow.utils.decorators import apply_defaults
+
+
+class BranchSqlOperator(BaseOperator, SkipMixin):
+    """
+    Executes sql code in a specific database
+
+    :param sql: the sql code to be executed. (templated)
+    :type sql: Can receive a str representing a sql statement or reference to a template file.
+               Template references are recognized by a str ending in '.sql'.
+               The SQL query is expected to return a Boolean (True/False), an integer
+               (0 = False, otherwise True) or a string (true/y/yes/1/on/false/n/no/0/off).
+    :param follow_task_ids_if_true: task id or task ids to follow if the query returns true
+    :type follow_task_ids_if_true: str or list
+    :param follow_task_ids_if_false: task id or task ids to follow if the query returns false
+    :type follow_task_ids_if_false: str or list
+    :param conn_id: reference to a specific database
+    :type conn_id: str
+    :param database: name of the database, which overrides the one defined in the connection
+    :param parameters: (optional) the parameters to render the SQL query with.
+    :type parameters: mapping or iterable
+    """
+
+    template_fields = ("sql",)
+    template_ext = (".sql",)
+    ui_color = "#a22034"
+    ui_fgcolor = "#F7F7F7"
+
+    @apply_defaults
+    def __init__(
+        self,
+        sql: str,
+        follow_task_ids_if_true: List[str],
+        follow_task_ids_if_false: List[str],
+        conn_id: str = "default_conn_id",
+        database: Optional[str] = None,
+        parameters: Optional[Union[Mapping, Iterable]] = None,
+        *args,
+        **kwargs,
+    ) -> None:
+        super().__init__(*args, **kwargs)
+        self.conn_id = conn_id
+        self.sql = sql
+        self.parameters = parameters
+        self.follow_task_ids_if_true = follow_task_ids_if_true
+        self.follow_task_ids_if_false = follow_task_ids_if_false
+        self.database = database
+        self._hook = None
+
+    def _get_hook(self):
+        self.log.debug("Get connection for %s", self.conn_id)
+        conn = BaseHook.get_connection(self.conn_id)
+
+        allowed_conn_type = {
+            "google_cloud_platform",
+            "jdbc",
+            "mssql",
+            "mysql",
+            "odbc",
+            "oracle",
+            "postgres",
+            "presto",
+            "sqlite",
+            "vertica",
+        }
+        if conn.conn_type not in allowed_conn_type:
+            raise AirflowException(
+                "The connection type is not supported by BranchSqlOperator. "
+                + "Supported connection types: {}".format(list(allowed_conn_type))
+            )
+
+        if not self._hook:
+            self._hook = conn.get_hook()
+            if self.database:
+                self._hook.schema = self.database
+
+        return self._hook
+
+    def execute(self, context: Dict):
+        # get supported hook
+        self._hook = self._get_hook()
+

Review comment:
       I also checked other operators to see if any of them verify the DbApiHook return type but I couldn't find any example of that check.







[GitHub] [airflow] eladkal commented on a change in pull request #8942: #8525 Add SQL Branch Operator

eladkal commented on a change in pull request #8942:
URL: https://github.com/apache/airflow/pull/8942#discussion_r429570989



##########
File path: airflow/operators/sql_branch_operator.py
##########
@@ -0,0 +1,175 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from distutils.util import strtobool
+from typing import Dict, Iterable, List, Mapping, Optional, Union
+
+from airflow.exceptions import AirflowException
+from airflow.hooks.base_hook import BaseHook
+from airflow.models import BaseOperator, SkipMixin
+from airflow.utils.decorators import apply_defaults
+
+
+class BranchSqlOperator(BaseOperator, SkipMixin):
+    """
+    Executes sql code in a specific database
+
+    :param sql: the sql code to be executed. (templated)
+    :type sql: Can receive a str representing a sql statement or reference to a template file.
+               Template references are recognized by a str ending in '.sql'.
+               The SQL query is expected to return a Boolean (True/False), an integer
+               (0 = False, otherwise True) or a string (true/y/yes/1/on/false/n/no/0/off).
+
+    :param follow_task_ids_if_true: task id or task ids to follow if the query returns true
+    :type follow_task_ids_if_true: str or list
+
+    :param follow_task_ids_if_false: task id or task ids to follow if the query returns false
+    :type follow_task_ids_if_false: str or list
+
+    :param conn_id: reference to a specific database
+    :type conn_id: str
+
+    :param database: name of the database, which overrides the one defined in the connection
+
+    :param parameters: (optional) the parameters to render the SQL query with.
+    :type parameters: mapping or iterable
+
+
+    """
+
+    template_fields = ("sql",)
+    template_ext = (".sql",)
+    ui_color = "#a22034"
+    ui_fgcolor = "#F7F7F7"
+
+    @apply_defaults
+    def __init__(
+        self,
+        sql: str,
+        follow_task_ids_if_true: List[str],
+        follow_task_ids_if_false: List[str],
+        conn_id: str = "default_conn_id",
+        database: Optional[str] = None,
+        parameters: Optional[Union[Mapping, Iterable]] = None,
+        *args,
+        **kwargs,
+    ) -> None:
+        super().__init__(*args, **kwargs)
+        self.conn_id = conn_id
+        self.sql = sql
+        self.parameters = parameters
+        self.follow_task_ids_if_true = follow_task_ids_if_true
+        self.follow_task_ids_if_false = follow_task_ids_if_false
+        self.database = database
+        self._hook = None
+
+    def _get_hook(self):
+        self.log.debug("Get connection for %s", self.conn_id)
+        conn = BaseHook.get_connection(self.conn_id)
+
+        allowed_conn_type = {
+            "google_cloud_platform",
+            "jdbc",
+            "mssql",
+            "mysql",
+            "odbc",
+            "oracle",
+            "postgres",
+            "presto",
+            "sqlite",
+            "vertica",
+        }
+        if conn.conn_type not in allowed_conn_type:
+            raise AirflowException(
+                "The connection type is not supported by BranchSqlOperator. "
+                + "Supported connection types: {}".format(list(allowed_conn_type))
+            )
+
+        if not self._hook:
+            self._hook = conn.get_hook()
+            if self.database:
+                self._hook.schema = self.database
+
+        return self._hook
+
+    def execute(self, context: Dict):
+        # get supported hook
+        self._hook = self._get_hook()
+
+        if self._hook is None:
+            raise AirflowException(
+                "Failed to establish connection to '%s'" % self.conn_id
+            )
+
+        if self.follow_task_ids_if_true is None:
+            raise AirflowException(
+                "Expected task id or task ids assigned to follow_task_ids_if_true"
+            )
+
+        if self.follow_task_ids_if_false is None:
+            raise AirflowException(
+                "Expected task id or task ids assigned to follow_task_ids_if_false"
+            )
+
+        self.log.info(
+            "Executing: %s (with parameters %s) with connection: %s",
+            self.sql,
+            self.parameters,
+            self._hook,
+        )
+        record = self._hook.get_first(self.sql, self.parameters)
+        if not record:
+            raise AirflowException(
+                "No rows returned from sql query. Operator expected True or False return value."
+            )
+
+        if isinstance(record, list):
+            query_result = record[0][0]

Review comment:
       I get an error on this line:
   ```
     File "/home/airflow/dags/testbranch.py", line 126, in execute
       query_result = record[0][0]
   TypeError: 'bool' object is not subscriptable
   ```
   
   Did it work for you when you tested against other DBs?
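The traceback fits a row-shape problem. `get_first` returns a single row (Airflow's `DbApiHook.get_first` returns `cursor.fetchone()`). If the driver hands that row back as a list such as `[True]`, the `isinstance(record, list)` branch treats the row as a list of rows: `record[0]` is already the scalar `True`, and the second `[0]` raises. Below is a minimal sketch of a normalization step. It assumes the decision sits in the first column of the first row. `extract_first_column` and `to_branch_bool` are hypothetical helper names and are not part of the PR. The string table mirrors the values accepted by `distutils.util.strtobool`, which also accepts `t` and `f`. It is written out here instead of importing `distutils`, because PEP 632 deprecated that module and Python 3.12 removed it.

```python
from typing import Any

# Mirrors the string values accepted by distutils.util.strtobool.
_TRUE_STRINGS = frozenset({"y", "yes", "t", "true", "on", "1"})
_FALSE_STRINGS = frozenset({"n", "no", "f", "false", "off", "0"})


def extract_first_column(record: Any) -> Any:
    """Return the first column of a single row; pass scalars through unchanged."""
    if isinstance(record, (list, tuple)):
        if not record:
            raise ValueError("Query returned an empty row")
        # A row is only one level deep: index it once, not twice.
        return record[0]
    return record


def to_branch_bool(value: Any) -> bool:
    """Interpret a bool, int or truthy/falsy string as a branch decision."""
    if isinstance(value, bool):  # check bool first: bool is a subclass of int
        return value
    if isinstance(value, int):
        return value != 0
    if isinstance(value, str):
        normalized = value.strip().lower()
        if normalized in _TRUE_STRINGS:
            return True
        if normalized in _FALSE_STRINGS:
            return False
    raise ValueError(f"Unexpected query result for branching: {value!r}")


if __name__ == "__main__":
    # A list-shaped row such as [True] no longer triggers the TypeError above.
    print(to_branch_bool(extract_first_column([True])))    # prints True
    print(to_branch_bool(extract_first_column(("off",))))  # prints False
```

If you pass the `SUPPORTED_TRUE_VALUES` and `SUPPORTED_FALSE_VALUES` entries from the test module through both helpers, you get `True` and `False` respectively. Those lists mix one-column rows (`["on"]`, `[1]`) with bare scalars, and the helpers handle both shapes.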




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [airflow] eladkal commented on a change in pull request #8942: #8525 Add SQL Branch Operator

Posted by GitBox <gi...@apache.org>.
eladkal commented on a change in pull request #8942:
URL: https://github.com/apache/airflow/pull/8942#discussion_r429615250



##########
File path: airflow/operators/sql_branch_operator.py
##########
@@ -0,0 +1,178 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from distutils.util import strtobool
+from typing import Dict, Iterable, List, Mapping, Optional, Union
+
+from airflow.exceptions import AirflowException
+from airflow.hooks.base_hook import BaseHook
+from airflow.models import BaseOperator, SkipMixin
+from airflow.utils.decorators import apply_defaults
+
+
+class BranchSqlOperator(BaseOperator, SkipMixin):
+    """
+    Executes sql code in a specific database
+
+    :param sql: the sql code to be executed. (templated)
+    :type sql: Can receive a str representing a sql statement or reference to a template file.
+               Template reference are recognized by str ending in '.sql'.
+               Expected SQL query to return Boolean (True/False), integer (0 = False, Otherwise = 1)
+               or string (true/y/yes/1/on/false/n/no/0/off).
+
+    :param follow_task_ids_if_true: task id or task ids to follow if query return true
+    :type follow_task_ids_if_true: str or list
+
+    :param follow_task_ids_if_false: task id or task ids to follow if query return true
+    :type follow_task_ids_if_false: str or list
+
+    :param conn_id: reference to a specific database
+    :type conn_id: str
+
+    :param database: name of database which overwrite defined one in connection
+
+    :param parameters: (optional) the parameters to render the SQL query with.
+    :type parameters: mapping or iterable
+
+

Review comment:
       Not sure you need blank lines between each param. Check other operators, for example:
   https://github.com/apache/airflow/blob/1d36b0303b8632fce6de78ca4e782ae26ee06fea/airflow/providers/mysql/operators/mysql.py#L27
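Below is a sketch of the parameter block without blank lines, using the same `:param:`/`:type:` layout as the linked MySQL operator. The class is a hypothetical docstring-only stub, not the operator itself. It also corrects the `follow_task_ids_if_false` description, which in the diff above says the task ids are followed when the query returns true.

```python
class BranchSqlOperatorDocSketch:
    """
    Executes SQL code in a specific database and chooses which downstream branch to follow.

    :param sql: the SQL code to be executed (templated). Can be a SQL statement or a
        reference to a template file ending in '.sql'. The query is expected to return
        a boolean (True/False), an integer (0 = False, any other value = True) or a
        string (true/y/yes/1/on or false/n/no/0/off).
    :type sql: str
    :param follow_task_ids_if_true: task id or task ids to follow if the query returns true
    :type follow_task_ids_if_true: str or list
    :param follow_task_ids_if_false: task id or task ids to follow if the query returns false
    :type follow_task_ids_if_false: str or list
    :param conn_id: reference to a specific database
    :type conn_id: str
    :param database: name of the database, overriding the one defined in the connection
    :type database: str
    :param parameters: (optional) the parameters to render the SQL query with
    :type parameters: mapping or iterable
    """
```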







[GitHub] [airflow] eladkal commented on a change in pull request #8942: #8525 Add SQL Branch Operator

Posted by GitBox <gi...@apache.org>.
eladkal commented on a change in pull request #8942:
URL: https://github.com/apache/airflow/pull/8942#discussion_r433322424



##########
File path: tests/operators/test_sql_branch_operator.py
##########
@@ -0,0 +1,479 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import datetime
+import unittest
+from unittest import mock
+
+import pytest
+
+from airflow.exceptions import AirflowException
+from airflow.models import DAG, DagRun, TaskInstance as TI
+from airflow.operators.dummy_operator import DummyOperator
+from airflow.operators.sql_branch_operator import BranchSqlOperator
+from airflow.utils import timezone
+from airflow.utils.session import create_session
+from airflow.utils.state import State
+from tests.providers.apache.hive import TestHiveEnvironment
+
+TEST_DAG_ID = "unit_test_sql_dag"
+
+DEFAULT_DATE = timezone.datetime(2016, 1, 1)
+INTERVAL = datetime.timedelta(hours=12)
+
+SUPPORTED_TRUE_VALUES = [
+    ["True"],
+    ["true"],
+    ["1"],
+    ["on"],
+    [1],
+    True,
+    "true",
+    "1",
+    "on",
+    1,
+]
+SUPPORTED_FALSE_VALUES = [
+    ["False"],
+    ["false"],
+    ["0"],
+    ["off"],
+    [0],
+    False,
+    "false",
+    "0",
+    "off",
+    0,
+]
+
+
+class TestSqlBranch(TestHiveEnvironment, unittest.TestCase):
+    """
+    Test for SQL Branch Operator
+    """
+
+    @classmethod
+    def setUpClass(cls):
+        super().setUpClass()
+
+        with create_session() as session:
+            session.query(DagRun).delete()
+            session.query(TI).delete()
+
+    def setUp(self):
+        super().setUp()
+        self.dag = DAG(
+            "sql_branch_operator_test",
+            default_args={"owner": "airflow", "start_date": DEFAULT_DATE},
+            schedule_interval=INTERVAL,
+        )
+        # self.dag = DAG(TEST_DAG_ID, default_args=args)

Review comment:
       Do we need this comment?







[GitHub] [airflow] mik-laj commented on pull request #8942: #8525 Add SQL Branch Operator

Posted by GitBox <gi...@apache.org>.
mik-laj commented on pull request #8942:
URL: https://github.com/apache/airflow/pull/8942#issuecomment-634582759


   @samuelkhtu we can do it in a separate PR. This will allow us to build a better git history.





[GitHub] [airflow] potiuk commented on a change in pull request #8942: #8525 Add SQL Branch Operator

Posted by GitBox <gi...@apache.org>.
potiuk commented on a change in pull request #8942:
URL: https://github.com/apache/airflow/pull/8942#discussion_r432845070



##########
File path: airflow/operators/sql_branch_operator.py
##########
@@ -0,0 +1,174 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from distutils.util import strtobool
+from typing import Dict, Iterable, List, Mapping, Optional, Union
+
+from airflow.exceptions import AirflowException
+from airflow.hooks.base_hook import BaseHook
+from airflow.models import BaseOperator, SkipMixin
+from airflow.utils.decorators import apply_defaults
+
+
+class BranchSqlOperator(BaseOperator, SkipMixin):
+    """
+    Executes sql code in a specific database
+
+    :param sql: the sql code to be executed. (templated)
+    :type sql: Can receive a str representing a sql statement or reference to a template file.
+               Template reference are recognized by str ending in '.sql'.
+               Expected SQL query to return Boolean (True/False), integer (0 = False, Otherwise = 1)
+               or string (true/y/yes/1/on/false/n/no/0/off).
+    :param follow_task_ids_if_true: task id or task ids to follow if query return true
+    :type follow_task_ids_if_true: str or list
+    :param follow_task_ids_if_false: task id or task ids to follow if query return true
+    :type follow_task_ids_if_false: str or list
+    :param conn_id: reference to a specific database
+    :type conn_id: str
+    :param database: name of database which overwrite defined one in connection
+    :param parameters: (optional) the parameters to render the SQL query with.
+    :type parameters: mapping or iterable
+    """
+
+    template_fields = ("sql",)
+    template_ext = (".sql",)
+    ui_color = "#a22034"
+    ui_fgcolor = "#F7F7F7"
+
+    @apply_defaults
+    def __init__(
+        self,
+        sql: str,
+        follow_task_ids_if_true: List[str],
+        follow_task_ids_if_false: List[str],
+        conn_id: str = "default_conn_id",
+        database: Optional[str] = None,
+        parameters: Optional[Union[Mapping, Iterable]] = None,
+        *args,
+        **kwargs,
+    ) -> None:
+        super().__init__(*args, **kwargs)
+        self.conn_id = conn_id
+        self.sql = sql
+        self.parameters = parameters
+        self.follow_task_ids_if_true = follow_task_ids_if_true
+        self.follow_task_ids_if_false = follow_task_ids_if_false
+        self.database = database
+        self._hook = None
+
+    def _get_hook(self):
+        self.log.debug("Get connection for %s", self.conn_id)
+        conn = BaseHook.get_connection(self.conn_id)
+
+        allowed_conn_type = {

Review comment:
       Should you make it a CONSTANT?
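One way to follow that suggestion, shown as a sketch, is to hoist the set to a module-level `frozenset`. That way it is built once at import time and cannot be mutated. The sketch also sorts the set when building the error message, so the listed types come out in a stable order; iteration order over a set of strings can change between interpreter runs because of hash randomization. `ALLOWED_CONN_TYPE` and `check_conn_type` are illustrative names. `ValueError` stands in for `AirflowException` so the snippet runs without Airflow installed.

```python
# Module-level constant: built once at import time and immutable.
ALLOWED_CONN_TYPE = frozenset(
    {
        "google_cloud_platform",
        "jdbc",
        "mssql",
        "mysql",
        "odbc",
        "oracle",
        "postgres",
        "presto",
        "sqlite",
        "vertica",
    }
)


def check_conn_type(conn_type: str) -> None:
    """Raise if the connection type is not one the branch operator supports."""
    if conn_type not in ALLOWED_CONN_TYPE:
        # ValueError stands in for AirflowException in this standalone sketch.
        raise ValueError(
            "The connection type is not supported by BranchSqlOperator. "
            f"Supported connection types: {sorted(ALLOWED_CONN_TYPE)}"
        )
```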

##########
File path: airflow/operators/sql_branch_operator.py
##########
@@ -0,0 +1,174 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from distutils.util import strtobool
+from typing import Dict, Iterable, List, Mapping, Optional, Union
+
+from airflow.exceptions import AirflowException
+from airflow.hooks.base_hook import BaseHook
+from airflow.models import BaseOperator, SkipMixin
+from airflow.utils.decorators import apply_defaults
+
+
+class BranchSqlOperator(BaseOperator, SkipMixin):
+    """
+    Executes sql code in a specific database
+
+    :param sql: the sql code to be executed. (templated)
+    :type sql: Can receive a str representing a sql statement or reference to a template file.
+               Template reference are recognized by str ending in '.sql'.
+               Expected SQL query to return Boolean (True/False), integer (0 = False, Otherwise = 1)
+               or string (true/y/yes/1/on/false/n/no/0/off).
+    :param follow_task_ids_if_true: task id or task ids to follow if query return true
+    :type follow_task_ids_if_true: str or list
+    :param follow_task_ids_if_false: task id or task ids to follow if query return true
+    :type follow_task_ids_if_false: str or list
+    :param conn_id: reference to a specific database
+    :type conn_id: str
+    :param database: name of database which overwrite defined one in connection
+    :param parameters: (optional) the parameters to render the SQL query with.
+    :type parameters: mapping or iterable
+    """
+
+    template_fields = ("sql",)
+    template_ext = (".sql",)
+    ui_color = "#a22034"
+    ui_fgcolor = "#F7F7F7"
+
+    @apply_defaults
+    def __init__(
+        self,
+        sql: str,
+        follow_task_ids_if_true: List[str],
+        follow_task_ids_if_false: List[str],
+        conn_id: str = "default_conn_id",
+        database: Optional[str] = None,
+        parameters: Optional[Union[Mapping, Iterable]] = None,
+        *args,
+        **kwargs,
+    ) -> None:
+        super().__init__(*args, **kwargs)
+        self.conn_id = conn_id
+        self.sql = sql
+        self.parameters = parameters
+        self.follow_task_ids_if_true = follow_task_ids_if_true
+        self.follow_task_ids_if_false = follow_task_ids_if_false
+        self.database = database
+        self._hook = None
+
+    def _get_hook(self):
+        self.log.debug("Get connection for %s", self.conn_id)
+        conn = BaseHook.get_connection(self.conn_id)
+
+        allowed_conn_type = {
+            "google_cloud_platform",
+            "jdbc",
+            "mssql",
+            "mysql",
+            "odbc",
+            "oracle",
+            "postgres",
+            "presto",
+            "sqlite",
+            "vertica",
+        }
+        if conn.conn_type not in allowed_conn_type:
+            raise AirflowException(
+                "The connection type is not supported by BranchSqlOperator. "
+                + "Supported connection types: {}".format(list(allowed_conn_type))
+            )
+
+        if not self._hook:
+            self._hook = conn.get_hook()
+            if self.database:
+                self._hook.schema = self.database
+
+        return self._hook
+
+    def execute(self, context: Dict):
+        # get supported hook
+        self._hook = self._get_hook()
+

Review comment:
       Maybe it's a good idea to check whether the hook is a DbApiHook instance? I think that is an even better check than checking the connection types.
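Below is a sketch of that alternative. It checks the hook returned by `conn.get_hook()` against the DB-API base class instead of against a list of connection types. In Airflow 1.10.x, that base class is `airflow.hooks.dbapi_hook.DbApiHook`. To keep the snippet runnable without Airflow, `DbApiHook`, `PostgresLikeHook` and `HttpLikeHook` are hypothetical stand-ins, and `TypeError` replaces `AirflowException`. The point of this design is that any hook inheriting `get_first` from the shared base class is accepted, so a new SQL provider would not have to be added to a hard-coded list.

```python
from typing import Any


class DbApiHook:
    """Hypothetical stand-in for Airflow's DB-API base hook."""

    def get_first(self, sql: str, parameters: Any = None) -> Any:
        raise NotImplementedError


class PostgresLikeHook(DbApiHook):
    """Hypothetical SQL hook: inherits from the DB-API base class."""

    def get_first(self, sql: str, parameters: Any = None) -> Any:
        return (True,)


class HttpLikeHook:
    """Hypothetical non-SQL hook: no DB-API base class."""


def ensure_dbapi_hook(hook: Any, conn_id: str) -> DbApiHook:
    """Return the hook if it is a DB-API hook, otherwise fail with a clear message."""
    if not isinstance(hook, DbApiHook):
        # TypeError stands in for AirflowException in this standalone sketch.
        raise TypeError(
            f"Connection '{conn_id}' resolved to {type(hook).__name__}, "
            "which is not a DbApiHook; BranchSqlOperator needs a SQL hook."
        )
    return hook
```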







[GitHub] [airflow] samuelkhtu commented on pull request #8942: #8525 Add SQL Branch Operator

Posted by GitBox <gi...@apache.org>.
samuelkhtu commented on pull request #8942:
URL: https://github.com/apache/airflow/pull/8942#issuecomment-633603940


   > LGTM.
   > 
   > One last observation about the file name `sql_branch_operator.py`. By accepting this PR, Airflow will have 2 SQL-related operator files in the core (`check_operator.py`, `sql_branch_operator.py`), so maybe the file should be renamed `sql_branch_operator.py` -> `sql.py` so that in the future [check_operator.py](https://github.com/apache/airflow/blob/master/airflow/operators/check_operator.py) can be deprecated by moving its operators into `sql.py`, similar to the deprecation of [python_operator.py](https://github.com/apache/airflow/blob/master/airflow/operators/python_operator.py) by moving its classes to [python.py](https://github.com/apache/airflow/blob/master/airflow/operators/python.py).
   > Let's wait to see what others think of that.
   
   Thanks @eladkal. If everyone decides to merge the operators into one, shouldn't we do that in a separate PR instead? I am new to the community, so I am not sure what the next step is. Is there anyone I need to talk to in order to finish this PR? Thanks.





[GitHub] [airflow] potiuk commented on a change in pull request #8942: #8525 Add SQL Branch Operator

Posted by GitBox <gi...@apache.org>.
potiuk commented on a change in pull request #8942:
URL: https://github.com/apache/airflow/pull/8942#discussion_r433044093



##########
File path: airflow/operators/sql_branch_operator.py
##########
@@ -0,0 +1,174 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from distutils.util import strtobool
+from typing import Dict, Iterable, List, Mapping, Optional, Union
+
+from airflow.exceptions import AirflowException
+from airflow.hooks.base_hook import BaseHook
+from airflow.models import BaseOperator, SkipMixin
+from airflow.utils.decorators import apply_defaults
+
+
+class BranchSqlOperator(BaseOperator, SkipMixin):
+    """
+    Executes sql code in a specific database
+
+    :param sql: the sql code to be executed. (templated)
+    :type sql: Can receive a str representing a sql statement or reference to a template file.
+               Template reference are recognized by str ending in '.sql'.
+               Expected SQL query to return Boolean (True/False), integer (0 = False, Otherwise = 1)
+               or string (true/y/yes/1/on/false/n/no/0/off).
+    :param follow_task_ids_if_true: task id or task ids to follow if query return true
+    :type follow_task_ids_if_true: str or list
+    :param follow_task_ids_if_false: task id or task ids to follow if query return true
+    :type follow_task_ids_if_false: str or list
+    :param conn_id: reference to a specific database
+    :type conn_id: str
+    :param database: name of database which overwrite defined one in connection
+    :param parameters: (optional) the parameters to render the SQL query with.
+    :type parameters: mapping or iterable
+    """
+
+    template_fields = ("sql",)
+    template_ext = (".sql",)
+    ui_color = "#a22034"
+    ui_fgcolor = "#F7F7F7"
+
+    @apply_defaults
+    def __init__(
+        self,
+        sql: str,
+        follow_task_ids_if_true: List[str],
+        follow_task_ids_if_false: List[str],
+        conn_id: str = "default_conn_id",
+        database: Optional[str] = None,
+        parameters: Optional[Union[Mapping, Iterable]] = None,
+        *args,
+        **kwargs,
+    ) -> None:
+        super().__init__(*args, **kwargs)
+        self.conn_id = conn_id
+        self.sql = sql
+        self.parameters = parameters
+        self.follow_task_ids_if_true = follow_task_ids_if_true
+        self.follow_task_ids_if_false = follow_task_ids_if_false
+        self.database = database
+        self._hook = None
+
+    def _get_hook(self):
+        self.log.debug("Get connection for %s", self.conn_id)
+        conn = BaseHook.get_connection(self.conn_id)
+
+        allowed_conn_type = {
+            "google_cloud_platform",
+            "jdbc",
+            "mssql",
+            "mysql",
+            "odbc",
+            "oracle",
+            "postgres",
+            "presto",
+            "sqlite",
+            "vertica",
+        }
+        if conn.conn_type not in allowed_conn_type:
+            raise AirflowException(
+                "The connection type is not supported by BranchSqlOperator. "
+                + "Supported connection types: {}".format(list(allowed_conn_type))
+            )
+
+        if not self._hook:
+            self._hook = conn.get_hook()
+            if self.database:
+                self._hook.schema = self.database
+
+        return self._hook
+
+    def execute(self, context: Dict):
+        # get supported hook
+        self._hook = self._get_hook()
+

Review comment:
       OK. Fine for me. Indeed that is quite OK this way :)







[GitHub] [airflow] samuelkhtu commented on a change in pull request #8942: #8525 Add SQL Branch Operator

Posted by GitBox <gi...@apache.org>.
samuelkhtu commented on a change in pull request #8942:
URL: https://github.com/apache/airflow/pull/8942#discussion_r429663469



##########
File path: airflow/operators/sql_branch_operator.py
##########
@@ -0,0 +1,178 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from distutils.util import strtobool
+from typing import Dict, Iterable, List, Mapping, Optional, Union
+
+from airflow.exceptions import AirflowException
+from airflow.hooks.base_hook import BaseHook
+from airflow.models import BaseOperator, SkipMixin
+from airflow.utils.decorators import apply_defaults
+
+
+class BranchSqlOperator(BaseOperator, SkipMixin):
+    """
+    Executes sql code in a specific database
+
+    :param sql: the sql code to be executed. (templated)
+    :type sql: Can receive a str representing a sql statement or reference to a template file.
+               Template reference are recognized by str ending in '.sql'.
+               Expected SQL query to return Boolean (True/False), integer (0 = False, Otherwise = 1)
+               or string (true/y/yes/1/on/false/n/no/0/off).
+
+    :param follow_task_ids_if_true: task id or task ids to follow if query return true
+    :type follow_task_ids_if_true: str or list
+
+    :param follow_task_ids_if_false: task id or task ids to follow if query return true
+    :type follow_task_ids_if_false: str or list
+
+    :param conn_id: reference to a specific database
+    :type conn_id: str
+
+    :param database: name of database which overwrite defined one in connection
+
+    :param parameters: (optional) the parameters to render the SQL query with.
+    :type parameters: mapping or iterable
+
+

Review comment:
       Removed blank lines







[GitHub] [airflow] eladkal commented on a change in pull request #8942: #8525 Add SQL Branch Operator

Posted by GitBox <gi...@apache.org>.
eladkal commented on a change in pull request #8942:
URL: https://github.com/apache/airflow/pull/8942#discussion_r428493813



##########
File path: airflow/operators/sql_branch_operator.py
##########
@@ -0,0 +1,165 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from distutils.util import strtobool
+from typing import Dict, Iterable, Mapping, Optional, Union
+
+from airflow.exceptions import AirflowException
+from airflow.hooks.base_hook import BaseHook
+from airflow.models import BaseOperator, SkipMixin
+from airflow.utils.decorators import apply_defaults
+
+
+class BranchSqlOperator(BaseOperator, SkipMixin):
+    """
+    Executes sql code in a specific database
+
+    :param sql: the sql code to be executed. (templated)
+    :type sql: Can receive a str representing a sql statement or reference to a template file.
+               Template reference are recognized by str ending in '.sql'.
+               Expected SQL query to return Boolean (True/False), integer (0 = False, Otherwise = 1)
+               or string (true/y/yes/1/on/false/n/no/0/off).
+
+    :param follow_task_ids_if_true: task id or task ids to follow if query return true
+    :type follow_task_ids_if_true: str or list
+
+    :param follow_task_ids_if_false: task id or task ids to follow if query return true
+    :type follow_task_ids_if_false: str or list
+
+    :param conn_id: reference to a specific database
+    :type conn_id: str
+
+    :param database: name of database which overwrite defined one in connection
+
+    :param parameters: (optional) the parameters to render the SQL query with.
+    :type parameters: mapping or iterable
+
+
+    """
+
+    template_fields = ("sql",)
+    template_ext = (".sql",)
+    ui_color = "#a22034"
+
+    @apply_defaults
+    def __init__(
+        self,
+        sql: str,
+        follow_task_ids_if_true: str,
+        follow_task_ids_if_false: str,
+        conn_id: str = "default_conn_id",
+        database: Optional[str] = None,
+        parameters: Optional[Union[Mapping, Iterable]] = None,
+        *args,
+        **kwargs,
+    ) -> None:
+        super().__init__(*args, **kwargs)
+        self.conn_id = conn_id
+        self.sql = sql
+        self.parameters = parameters
+        self.follow_task_ids_if_true = follow_task_ids_if_true
+        self.follow_task_ids_if_false = follow_task_ids_if_false
+        self.database = database
+        self._hook = None
+
+    def _get_hook(self):
+        self.log.debug("Get connection for %s", self.conn_id)
+        conn = BaseHook.get_connection(self.conn_id)
+
+        allowed_conn_type = {
+            "google_cloud_platform",
+            "jdbc",
+            "mssql",
+            "mysql",
+            "odbc",
+            "oracle",
+            "postgres",
+            "presto",
+            "sqlite",
+            "vertica",
+        }
+        if conn.conn_type not in allowed_conn_type:
+            raise AirflowException(
+                "The connection type is not supported by BranchSqlOperator. "
+                + "Supported connection types: {}".format(list(allowed_conn_type))
+            )
+
+        if not self._hook:
+            self._hook = conn.get_hook()
+            if self.database:
+                self._hook.schema = self.database
+
+        return self._hook
+
+    def execute(self, context: Dict):
+        # get supported hook
+        self._hook = self._get_hook()
+
+        if self._hook is None:
+            raise AirflowException(
+                "Failed to establish connection to '%s'" % self.conn_id
+            )
+
+        if self.follow_task_ids_if_true is None:
+            raise AirflowException(
+                "Expected task id or task ids assigned to follow_task_ids_if_true"
+            )
+
+        if self.follow_task_ids_if_false is None:
+            raise AirflowException(
+                "Expected task id or task ids assigned to follow_task_ids_if_false"
+            )
+
+        self.log.info(
+            "Executing: %s (with parameters %s) with connection: %s",
+            self.sql,
+            self.parameters,
+            self._hook,
+        )
+        records = self._hook.get_records(self.sql, self.parameters)

Review comment:
       Why `get_records`? Are you expecting more than one row?
   Shouldn't it use `get_first`?
   ```suggestion
           records = self._hook.get_first(self.sql, self.parameters)
   ```
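To illustrate the difference the reviewer is pointing at, here is a minimal sketch that uses the standard-library `sqlite3` module as a stand-in for a DB-API hook. It assumes a hook's `get_records` behaves like `cursor.fetchall()` (a list of row tuples) and `get_first` like `cursor.fetchone()` (a single row tuple, or `None`). The functions below are hypothetical stand-ins, not Airflow's API. The sketch also maps the first column to a boolean following the rules in the operator's docstring (bool, 0 vs. non-zero integer, or a true/false-like string).

```python
import sqlite3
from typing import Any, List, Optional, Sequence, Tuple

# Accepted string spellings, taken from the operator's docstring.
TRUE_STRINGS = {"true", "y", "yes", "1", "on"}
FALSE_STRINGS = {"false", "n", "no", "0", "off"}


def get_records(conn: sqlite3.Connection, sql: str, parameters: Sequence = ()) -> List[Tuple]:
    """Hypothetical stand-in for a hook's get_records(): all rows, like fetchall()."""
    return conn.execute(sql, parameters).fetchall()


def get_first(conn: sqlite3.Connection, sql: str, parameters: Sequence = ()) -> Optional[Tuple]:
    """Hypothetical stand-in for a hook's get_first(): one row (or None), like fetchone()."""
    return conn.execute(sql, parameters).fetchone()


def to_branch_flag(value: Any) -> bool:
    """Map a single query value to True/False following the docstring's rules."""
    if isinstance(value, bool):  # check bool first: bool is a subclass of int
        return value
    if isinstance(value, int):
        return value != 0
    if isinstance(value, str):
        normalized = value.strip().lower()
        if normalized in TRUE_STRINGS:
            return True
        if normalized in FALSE_STRINGS:
            return False
    raise ValueError(f"Unexpected query result for branching: {value!r}")


conn = sqlite3.connect(":memory:")
sql = "SELECT ? > 10"  # SQLite returns the comparison as the integer 1 or 0

print(get_records(conn, sql, (42,)))  # [(1,)]  -> a list of rows
first_row = get_first(conn, sql, (42,))
print(first_row)  # (1,)  -> just one row
if first_row is None:
    raise ValueError("Query returned no rows")
print(to_branch_flag(first_row[0]))  # True
```

With `get_records`, the operator would have to unwrap `records[0][0]` itself; `get_first` already returns only the row it needs. The PR uses `distutils.util.strtobool` for the string case; the explicit sets above avoid `distutils`, which is deprecated since Python 3.10 and removed in 3.12.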

##########
File path: airflow/operators/sql_branch_operator.py
##########
@@ -0,0 +1,165 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from distutils.util import strtobool
+from typing import Dict, Iterable, Mapping, Optional, Union
+
+from airflow.exceptions import AirflowException
+from airflow.hooks.base_hook import BaseHook
+from airflow.models import BaseOperator, SkipMixin
+from airflow.utils.decorators import apply_defaults
+
+
+class BranchSqlOperator(BaseOperator, SkipMixin):
+    """
+    Executes sql code in a specific database
+
+    :param sql: the sql code to be executed. (templated)
+    :type sql: Can receive a str representing a sql statement or reference to a template file.
+               Template reference are recognized by str ending in '.sql'.
+               Expected SQL query to return Boolean (True/False), integer (0 = False, Otherwise = 1)
+               or string (true/y/yes/1/on/false/n/no/0/off).
+
+    :param follow_task_ids_if_true: task id or task ids to follow if query return true
+    :type follow_task_ids_if_true: str or list
+
+    :param follow_task_ids_if_false: task id or task ids to follow if query return true
+    :type follow_task_ids_if_false: str or list
+
+    :param conn_id: reference to a specific database
+    :type conn_id: str
+
+    :param database: name of database which overwrite defined one in connection
+
+    :param parameters: (optional) the parameters to render the SQL query with.
+    :type parameters: mapping or iterable
+
+
+    """
+
+    template_fields = ("sql",)
+    template_ext = (".sql",)
+    ui_color = "#a22034"
+
+    @apply_defaults
+    def __init__(
+        self,
+        sql: str,
+        follow_task_ids_if_true: str,

Review comment:
       It can be a list of str if you have more than one task.
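A minimal sketch of what this suggestion could look like: annotate the parameter as accepting either a single task id or several, and normalize it to a list before use. `normalize_task_ids`, `choose_branch`, and the task ids in the example are hypothetical names for illustration, not part of the PR or of Airflow.

```python
from typing import Iterable, List, Union

# A single task id, or any iterable of task ids.
TaskIds = Union[str, Iterable[str]]


def normalize_task_ids(task_ids: TaskIds) -> List[str]:
    """Accept one task id or many; always return a list."""
    if isinstance(task_ids, str):  # a bare str is iterable, so check it first
        return [task_ids]
    return list(task_ids)


def choose_branch(flag: bool, if_true: TaskIds, if_false: TaskIds) -> List[str]:
    """Return the task ids to follow for a given boolean query result."""
    return normalize_task_ids(if_true if flag else if_false)


print(choose_branch(True, "load_full", ["load_delta", "notify"]))   # ['load_full']
print(choose_branch(False, "load_full", ["load_delta", "notify"]))  # ['load_delta', 'notify']
```

The `str` check comes first because a plain string is itself iterable; without it, `"load_full"` would be split into single characters.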

##########
File path: airflow/operators/sql_branch_operator.py
##########
@@ -0,0 +1,165 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from distutils.util import strtobool
+from typing import Dict, Iterable, Mapping, Optional, Union
+
+from airflow.exceptions import AirflowException
+from airflow.hooks.base_hook import BaseHook
+from airflow.models import BaseOperator, SkipMixin
+from airflow.utils.decorators import apply_defaults
+
+
+class BranchSqlOperator(BaseOperator, SkipMixin):
+    """
+    Executes sql code in a specific database
+
+    :param sql: the sql code to be executed. (templated)
+    :type sql: Can receive a str representing a sql statement or reference to a template file.
+               Template reference are recognized by str ending in '.sql'.
+               Expected SQL query to return Boolean (True/False), integer (0 = False, Otherwise = 1)
+               or string (true/y/yes/1/on/false/n/no/0/off).
+
+    :param follow_task_ids_if_true: task id or task ids to follow if query return true
+    :type follow_task_ids_if_true: str or list
+
+    :param follow_task_ids_if_false: task id or task ids to follow if query return true
+    :type follow_task_ids_if_false: str or list
+
+    :param conn_id: reference to a specific database
+    :type conn_id: str
+
+    :param database: name of database which overwrite defined one in connection
+
+    :param parameters: (optional) the parameters to render the SQL query with.
+    :type parameters: mapping or iterable
+
+
+    """
+
+    template_fields = ("sql",)
+    template_ext = (".sql",)
+    ui_color = "#a22034"
+
+    @apply_defaults
+    def __init__(
+        self,
+        sql: str,
+        follow_task_ids_if_true: str,
+        follow_task_ids_if_false: str,

Review comment:
       same







[GitHub] [airflow] potiuk commented on a change in pull request #8942: #8525 Add SQL Branch Operator

Posted by GitBox <gi...@apache.org>.
potiuk commented on a change in pull request #8942:
URL: https://github.com/apache/airflow/pull/8942#discussion_r433044218



##########
File path: airflow/operators/sql_branch_operator.py
##########
@@ -0,0 +1,174 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from distutils.util import strtobool
+from typing import Dict, Iterable, List, Mapping, Optional, Union
+
+from airflow.exceptions import AirflowException
+from airflow.hooks.base_hook import BaseHook
+from airflow.models import BaseOperator, SkipMixin
+from airflow.utils.decorators import apply_defaults
+
+
+class BranchSqlOperator(BaseOperator, SkipMixin):
+    """
+    Executes sql code in a specific database
+
+    :param sql: the sql code to be executed. (templated)
+    :type sql: Can receive a str representing a sql statement or reference to a template file.
+               Template reference are recognized by str ending in '.sql'.
+               Expected SQL query to return Boolean (True/False), integer (0 = False, Otherwise = 1)
+               or string (true/y/yes/1/on/false/n/no/0/off).
+    :param follow_task_ids_if_true: task id or task ids to follow if query return true
+    :type follow_task_ids_if_true: str or list
+    :param follow_task_ids_if_false: task id or task ids to follow if query return true
+    :type follow_task_ids_if_false: str or list
+    :param conn_id: reference to a specific database
+    :type conn_id: str
+    :param database: name of database which overwrite defined one in connection
+    :param parameters: (optional) the parameters to render the SQL query with.
+    :type parameters: mapping or iterable
+    """
+
+    template_fields = ("sql",)
+    template_ext = (".sql",)
+    ui_color = "#a22034"
+    ui_fgcolor = "#F7F7F7"
+
+    @apply_defaults
+    def __init__(
+        self,
+        sql: str,
+        follow_task_ids_if_true: List[str],
+        follow_task_ids_if_false: List[str],
+        conn_id: str = "default_conn_id",
+        database: Optional[str] = None,
+        parameters: Optional[Union[Mapping, Iterable]] = None,
+        *args,
+        **kwargs,
+    ) -> None:
+        super().__init__(*args, **kwargs)
+        self.conn_id = conn_id
+        self.sql = sql
+        self.parameters = parameters
+        self.follow_task_ids_if_true = follow_task_ids_if_true
+        self.follow_task_ids_if_false = follow_task_ids_if_false
+        self.database = database
+        self._hook = None
+
+    def _get_hook(self):
+        self.log.debug("Get connection for %s", self.conn_id)
+        conn = BaseHook.get_connection(self.conn_id)
+
+        allowed_conn_type = {
+            "google_cloud_platform",
+            "jdbc",
+            "mssql",
+            "mysql",
+            "odbc",
+            "oracle",
+            "postgres",
+            "presto",
+            "sqlite",
+            "vertica",
+        }
+        if conn.conn_type not in allowed_conn_type:
+            raise AirflowException(
+                "The connection type is not supported by BranchSqlOperator. "
+                + "Supported connection types: {}".format(list(allowed_conn_type))
+            )
+
+        if not self._hook:
+            self._hook = conn.get_hook()
+            if self.database:
+                self._hook.schema = self.database
+
+        return self._hook
+
+    def execute(self, context: Dict):
+        # get supported hook
+        self._hook = self._get_hook()
+

Review comment:
       Can you please rebase to latest master? That should fix the failure problem.







[GitHub] [airflow] potiuk commented on a change in pull request #8942: #8525 Add SQL Branch Operator

Posted by GitBox <gi...@apache.org>.
potiuk commented on a change in pull request #8942:
URL: https://github.com/apache/airflow/pull/8942#discussion_r432884511



##########
File path: airflow/operators/sql_branch_operator.py
##########
@@ -0,0 +1,174 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from distutils.util import strtobool
+from typing import Dict, Iterable, List, Mapping, Optional, Union
+
+from airflow.exceptions import AirflowException
+from airflow.hooks.base_hook import BaseHook
+from airflow.models import BaseOperator, SkipMixin
+from airflow.utils.decorators import apply_defaults
+
+
+class BranchSqlOperator(BaseOperator, SkipMixin):
+    """
+    Executes sql code in a specific database
+
+    :param sql: the sql code to be executed. (templated)
+    :type sql: Can receive a str representing a sql statement or reference to a template file.
+               Template reference are recognized by str ending in '.sql'.
+               Expected SQL query to return Boolean (True/False), integer (0 = False, Otherwise = 1)
+               or string (true/y/yes/1/on/false/n/no/0/off).
+    :param follow_task_ids_if_true: task id or task ids to follow if query return true
+    :type follow_task_ids_if_true: str or list
+    :param follow_task_ids_if_false: task id or task ids to follow if query return true
+    :type follow_task_ids_if_false: str or list
+    :param conn_id: reference to a specific database
+    :type conn_id: str
+    :param database: name of database which overwrite defined one in connection
+    :param parameters: (optional) the parameters to render the SQL query with.
+    :type parameters: mapping or iterable
+    """
+
+    template_fields = ("sql",)
+    template_ext = (".sql",)
+    ui_color = "#a22034"
+    ui_fgcolor = "#F7F7F7"
+
+    @apply_defaults
+    def __init__(
+        self,
+        sql: str,
+        follow_task_ids_if_true: List[str],
+        follow_task_ids_if_false: List[str],
+        conn_id: str = "default_conn_id",
+        database: Optional[str] = None,
+        parameters: Optional[Union[Mapping, Iterable]] = None,
+        *args,
+        **kwargs,
+    ) -> None:
+        super().__init__(*args, **kwargs)
+        self.conn_id = conn_id
+        self.sql = sql
+        self.parameters = parameters
+        self.follow_task_ids_if_true = follow_task_ids_if_true
+        self.follow_task_ids_if_false = follow_task_ids_if_false
+        self.database = database
+        self._hook = None
+
+    def _get_hook(self):
+        self.log.debug("Get connection for %s", self.conn_id)
+        conn = BaseHook.get_connection(self.conn_id)
+
+        allowed_conn_type = {
+            "google_cloud_platform",
+            "jdbc",
+            "mssql",
+            "mysql",
+            "odbc",
+            "oracle",
+            "postgres",
+            "presto",
+            "sqlite",
+            "vertica",
+        }
+        if conn.conn_type not in allowed_conn_type:
+            raise AirflowException(
+                "The connection type is not supported by BranchSqlOperator. "
+                + "Supported connection types: {}".format(list(allowed_conn_type))
+            )
+
+        if not self._hook:
+            self._hook = conn.get_hook()
+            if self.database:
+                self._hook.schema = self.database
+
+        return self._hook
+
+    def execute(self, context: Dict):
+        # get supported hook
+        self._hook = self._get_hook()
+

Review comment:
       Here is some info on that: https://github.com/apache/airflow/blob/master/CONTRIBUTING.rst#how-to-rebase-pr







[GitHub] [airflow] potiuk commented on a change in pull request #8942: #8525 Add SQL Branch Operator

Posted by GitBox <gi...@apache.org>.
potiuk commented on a change in pull request #8942:
URL: https://github.com/apache/airflow/pull/8942#discussion_r432884462



##########
File path: airflow/operators/sql_branch_operator.py
##########
@@ -0,0 +1,174 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from distutils.util import strtobool
+from typing import Dict, Iterable, List, Mapping, Optional, Union
+
+from airflow.exceptions import AirflowException
+from airflow.hooks.base_hook import BaseHook
+from airflow.models import BaseOperator, SkipMixin
+from airflow.utils.decorators import apply_defaults
+
+
+class BranchSqlOperator(BaseOperator, SkipMixin):
+    """
+    Executes sql code in a specific database
+
+    :param sql: the sql code to be executed. (templated)
+    :type sql: Can receive a str representing a sql statement or reference to a template file.
+               Template reference are recognized by str ending in '.sql'.
+               Expected SQL query to return Boolean (True/False), integer (0 = False, Otherwise = 1)
+               or string (true/y/yes/1/on/false/n/no/0/off).
+    :param follow_task_ids_if_true: task id or task ids to follow if query return true
+    :type follow_task_ids_if_true: str or list
+    :param follow_task_ids_if_false: task id or task ids to follow if query return true
+    :type follow_task_ids_if_false: str or list
+    :param conn_id: reference to a specific database
+    :type conn_id: str
+    :param database: name of database which overwrite defined one in connection
+    :param parameters: (optional) the parameters to render the SQL query with.
+    :type parameters: mapping or iterable
+    """
+
+    template_fields = ("sql",)
+    template_ext = (".sql",)
+    ui_color = "#a22034"
+    ui_fgcolor = "#F7F7F7"
+
+    @apply_defaults
+    def __init__(
+        self,
+        sql: str,
+        follow_task_ids_if_true: List[str],
+        follow_task_ids_if_false: List[str],
+        conn_id: str = "default_conn_id",
+        database: Optional[str] = None,
+        parameters: Optional[Union[Mapping, Iterable]] = None,
+        *args,
+        **kwargs,
+    ) -> None:
+        super().__init__(*args, **kwargs)
+        self.conn_id = conn_id
+        self.sql = sql
+        self.parameters = parameters
+        self.follow_task_ids_if_true = follow_task_ids_if_true
+        self.follow_task_ids_if_false = follow_task_ids_if_false
+        self.database = database
+        self._hook = None
+
+    def _get_hook(self):
+        self.log.debug("Get connection for %s", self.conn_id)
+        conn = BaseHook.get_connection(self.conn_id)
+
+        allowed_conn_type = {
+            "google_cloud_platform",
+            "jdbc",
+            "mssql",
+            "mysql",
+            "odbc",
+            "oracle",
+            "postgres",
+            "presto",
+            "sqlite",
+            "vertica",
+        }
+        if conn.conn_type not in allowed_conn_type:
+            raise AirflowException(
+                "The connection type is not supported by BranchSqlOperator. "
+                + "Supported connection types: {}".format(list(allowed_conn_type))
+            )
+
+        if not self._hook:
+            self._hook = conn.get_hook()
+            if self.database:
+                self._hook.schema = self.database
+
+        return self._hook
+
+    def execute(self, context: Dict):
+        # get supported hook
+        self._hook = self._get_hook()
+

Review comment:
       Please just rebase it onto the latest master. We had some problems with the Kubernetes tests; they should be largely solved now (I am working on the final fix), but at least it should not fail now.







[GitHub] [airflow] samuelkhtu commented on a change in pull request #8942: #8525 Add SQL Branch Operator

Posted by GitBox <gi...@apache.org>.
samuelkhtu commented on a change in pull request #8942:
URL: https://github.com/apache/airflow/pull/8942#discussion_r429664120



##########
File path: airflow/operators/sql_branch_operator.py
##########
@@ -0,0 +1,178 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from distutils.util import strtobool
+from typing import Dict, Iterable, List, Mapping, Optional, Union
+
+from airflow.exceptions import AirflowException
+from airflow.hooks.base_hook import BaseHook
+from airflow.models import BaseOperator, SkipMixin
+from airflow.utils.decorators import apply_defaults
+
+
+class BranchSqlOperator(BaseOperator, SkipMixin):
+    """
+    Executes sql code in a specific database
+
+    :param sql: the sql code to be executed. (templated)
+    :type sql: Can receive a str representing a sql statement or reference to a template file.
+               Template reference are recognized by str ending in '.sql'.
+               Expected SQL query to return Boolean (True/False), integer (0 = False, Otherwise = 1)
+               or string (true/y/yes/1/on/false/n/no/0/off).
+
+    :param follow_task_ids_if_true: task id or task ids to follow if query return true
+    :type follow_task_ids_if_true: str or list
+
+    :param follow_task_ids_if_false: task id or task ids to follow if query return true
+    :type follow_task_ids_if_false: str or list
+
+    :param conn_id: reference to a specific database
+    :type conn_id: str
+
+    :param database: name of database which overwrite defined one in connection
+
+    :param parameters: (optional) the parameters to render the SQL query with.
+    :type parameters: mapping or iterable
+
+
+    """
+
+    template_fields = ("sql",)
+    template_ext = (".sql",)
+    ui_color = "#a22034"
+    ui_fgcolor = "#F7F7F7"
+
+    @apply_defaults
+    def __init__(
+        self,
+        sql: str,
+        follow_task_ids_if_true: List[str],
+        follow_task_ids_if_false: List[str],
+        conn_id: str = "default_conn_id",
+        database: Optional[str] = None,
+        parameters: Optional[Union[Mapping, Iterable]] = None,
+        *args,
+        **kwargs,
+    ) -> None:
+        super().__init__(*args, **kwargs)
+        self.conn_id = conn_id
+        self.sql = sql
+        self.parameters = parameters
+        self.follow_task_ids_if_true = follow_task_ids_if_true
+        self.follow_task_ids_if_false = follow_task_ids_if_false
+        self.database = database
+        self._hook = None
+
+    def _get_hook(self):
+        self.log.debug("Get connection for %s", self.conn_id)
+        conn = BaseHook.get_connection(self.conn_id)
+
+        allowed_conn_type = {
+            "google_cloud_platform",
+            "jdbc",
+            "mssql",
+            "mysql",
+            "odbc",
+            "oracle",
+            "postgres",
+            "presto",
+            "sqlite",
+            "vertica",
+        }
+        if conn.conn_type not in allowed_conn_type:
+            raise AirflowException(
+                "The connection type is not supported by BranchSqlOperator. "
+                + "Supported connection types: {}".format(list(allowed_conn_type))
+            )
+
+        if not self._hook:
+            self._hook = conn.get_hook()
+            if self.database:
+                self._hook.schema = self.database
+
+        return self._hook
+
+    def execute(self, context: Dict):
+        # get supported hook
+        self._hook = self._get_hook()
+
+        if self._hook is None:
+            raise AirflowException(
+                "Failed to establish connection to '%s'" % self.conn_id
+            )
+
+        if self.follow_task_ids_if_true is None:
+            raise AirflowException(
+                "Expected task id or task ids assigned to follow_task_ids_if_true"
+            )
+
+        if self.follow_task_ids_if_false is None:
+            raise AirflowException(
+                "Expected task id or task ids assigned to follow_task_ids_if_false"
+            )

Review comment:
       Sounds good. I added a check for the SQL parameter as well.
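A sketch of the kind of up-front validation discussed here, checking `sql` alongside the two follow lists before any query runs. It raises a plain `ValueError` so it runs without Airflow installed (the operator in the diff raises `AirflowException`), and `validate_branch_args` is a hypothetical helper name, not the PR's actual code.

```python
from typing import Iterable, Optional, Union


def validate_branch_args(
    sql: Optional[str],
    follow_task_ids_if_true: Optional[Union[str, Iterable[str]]],
    follow_task_ids_if_false: Optional[Union[str, Iterable[str]]],
) -> None:
    """Fail fast if a required argument is missing (hypothetical helper)."""
    if not sql:  # rejects both None and an empty string
        raise ValueError("Expected 'sql' parameter is missing")
    if follow_task_ids_if_true is None:
        raise ValueError("Expected task id or task ids assigned to follow_task_ids_if_true")
    if follow_task_ids_if_false is None:
        raise ValueError("Expected task id or task ids assigned to follow_task_ids_if_false")


validate_branch_args("SELECT 1", "branch_a", ["branch_b"])  # passes silently

try:
    validate_branch_args(None, "branch_a", "branch_b")
except ValueError as err:
    print(err)  # Expected 'sql' parameter is missing
```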







[GitHub] [airflow] eladkal commented on a change in pull request #8942: #8525 Add SQL Branch Operator

Posted by GitBox <gi...@apache.org>.
eladkal commented on a change in pull request #8942:
URL: https://github.com/apache/airflow/pull/8942#discussion_r429614232



##########
File path: tests/operators/test_sql_branch_operator.py
##########
@@ -0,0 +1,482 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# import os
+# import unittest
+# from unittest import mock

Review comment:
       no need for that

##########
File path: airflow/operators/sql_branch_operator.py
##########
@@ -0,0 +1,178 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from distutils.util import strtobool
+from typing import Dict, Iterable, List, Mapping, Optional, Union
+
+from airflow.exceptions import AirflowException
+from airflow.hooks.base_hook import BaseHook
+from airflow.models import BaseOperator, SkipMixin
+from airflow.utils.decorators import apply_defaults
+
+
+class BranchSqlOperator(BaseOperator, SkipMixin):
+    """
+    Executes sql code in a specific database
+
+    :param sql: the sql code to be executed. (templated)
+    :type sql: Can receive a str representing a sql statement or a reference to a template file.
+               Template references are recognized by a str ending in '.sql'.
+               The SQL query is expected to return a Boolean (True/False), an integer (0 = False, otherwise True)
+               or a string (true/y/yes/1/on/false/n/no/0/off).
+
+    :param follow_task_ids_if_true: task id or task ids to follow if the query returns true
+    :type follow_task_ids_if_true: str or list
+
+    :param follow_task_ids_if_false: task id or task ids to follow if the query returns false
+    :type follow_task_ids_if_false: str or list
+
+    :param conn_id: reference to a specific database
+    :type conn_id: str
+
+    :param database: name of the database, which overrides the one defined in the connection
+
+    :param parameters: (optional) the parameters to render the SQL query with.
+    :type parameters: mapping or iterable
+
+

Review comment:
       Not sure you need to have blank lines between the params.







[GitHub] [airflow] samuelkhtu commented on pull request #8942: #8525 Add SQL Branch Operator

Posted by GitBox <gi...@apache.org>.
samuelkhtu commented on pull request #8942:
URL: https://github.com/apache/airflow/pull/8942#issuecomment-633279449


   Hi @eladkal, thank you again for the review. The latest commit should address the last set of comments. Please take a look when you have a chance. Thanks.





[GitHub] [airflow] boring-cyborg[bot] commented on pull request #8942: #8525 Add SQL Branch Operator

Posted by GitBox <gi...@apache.org>.
boring-cyborg[bot] commented on pull request #8942:
URL: https://github.com/apache/airflow/pull/8942#issuecomment-637023543


   Awesome work, congrats on your first merged pull request!
   





[GitHub] [airflow] samuelkhtu commented on a change in pull request #8942: #8525 Add SQL Branch Operator

Posted by GitBox <gi...@apache.org>.
samuelkhtu commented on a change in pull request #8942:
URL: https://github.com/apache/airflow/pull/8942#discussion_r432884495



##########
File path: airflow/operators/sql_branch_operator.py
##########
@@ -0,0 +1,174 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from distutils.util import strtobool
+from typing import Dict, Iterable, List, Mapping, Optional, Union
+
+from airflow.exceptions import AirflowException
+from airflow.hooks.base_hook import BaseHook
+from airflow.models import BaseOperator, SkipMixin
+from airflow.utils.decorators import apply_defaults
+
+
+class BranchSqlOperator(BaseOperator, SkipMixin):
+    """
+    Executes sql code in a specific database
+
+    :param sql: the sql code to be executed. (templated)
+    :type sql: Can receive a str representing a sql statement or a reference to a template file.
+               Template references are recognized by a str ending in '.sql'.
+               The SQL query is expected to return a Boolean (True/False), an integer (0 = False, otherwise True)
+               or a string (true/y/yes/1/on/false/n/no/0/off).
+    :param follow_task_ids_if_true: task id or task ids to follow if the query returns true
+    :type follow_task_ids_if_true: str or list
+    :param follow_task_ids_if_false: task id or task ids to follow if the query returns false
+    :type follow_task_ids_if_false: str or list
+    :param conn_id: reference to a specific database
+    :type conn_id: str
+    :param database: name of the database, which overrides the one defined in the connection
+    :param parameters: (optional) the parameters to render the SQL query with.
+    :type parameters: mapping or iterable
+    """
+
+    template_fields = ("sql",)
+    template_ext = (".sql",)
+    ui_color = "#a22034"
+    ui_fgcolor = "#F7F7F7"
+
+    @apply_defaults
+    def __init__(
+        self,
+        sql: str,
+        follow_task_ids_if_true: List[str],
+        follow_task_ids_if_false: List[str],
+        conn_id: str = "default_conn_id",
+        database: Optional[str] = None,
+        parameters: Optional[Union[Mapping, Iterable]] = None,
+        *args,
+        **kwargs,
+    ) -> None:
+        super().__init__(*args, **kwargs)
+        self.conn_id = conn_id
+        self.sql = sql
+        self.parameters = parameters
+        self.follow_task_ids_if_true = follow_task_ids_if_true
+        self.follow_task_ids_if_false = follow_task_ids_if_false
+        self.database = database
+        self._hook = None
+
+    def _get_hook(self):
+        self.log.debug("Get connection for %s", self.conn_id)
+        conn = BaseHook.get_connection(self.conn_id)
+
+        allowed_conn_type = {

Review comment:
       I converted the `allowed_conn_type` local variable into an `ALLOWED_CONN_TYPE` constant and moved it outside of the class. What do you think? Thanks
   
   ```Python
   ALLOWED_CONN_TYPE = {
       "google_cloud_platform",
       "jdbc",
       "mssql",
       "mysql",
       "odbc",
       "oracle",
       "postgres",
       "presto",
       "sqlite",
       "vertica",
   }
   
   
   class BranchSqlOperator(BaseOperator, SkipMixin):
       ...  # class body omitted
   ```
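   The module-level constant described in the comment can be sketched as follows. The `check_conn_type` helper, the use of `frozenset`, and raising `ValueError` instead of `AirflowException` are assumptions for illustration, not the PR's code. Sorting the set in the error message keeps the message deterministic, because the iteration order of a set of strings can change between interpreter runs (string hashing is randomized).

```python
# Module-level constant: built once at import time instead of on every call.
ALLOWED_CONN_TYPE = frozenset(
    {
        "google_cloud_platform",
        "jdbc",
        "mssql",
        "mysql",
        "odbc",
        "oracle",
        "postgres",
        "presto",
        "sqlite",
        "vertica",
    }
)


def check_conn_type(conn_type: str) -> None:
    """Hypothetical helper: reject connection types the operator cannot query."""
    if conn_type not in ALLOWED_CONN_TYPE:
        raise ValueError(
            "The connection type is not supported by BranchSqlOperator. "
            "Supported connection types: {}".format(sorted(ALLOWED_CONN_TYPE))
        )


check_conn_type("postgres")  # accepted, no exception
```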







[GitHub] [airflow] samuelkhtu commented on a change in pull request #8942: #8525 Add SQL Branch Operator

Posted by GitBox <gi...@apache.org>.
samuelkhtu commented on a change in pull request #8942:
URL: https://github.com/apache/airflow/pull/8942#discussion_r433382917



##########
File path: airflow/operators/sql_branch_operator.py
##########
@@ -0,0 +1,174 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from distutils.util import strtobool
+from typing import Dict, Iterable, List, Mapping, Optional, Union
+
+from airflow.exceptions import AirflowException
+from airflow.hooks.base_hook import BaseHook
+from airflow.models import BaseOperator, SkipMixin
+from airflow.utils.decorators import apply_defaults
+
+
+class BranchSqlOperator(BaseOperator, SkipMixin):
+    """
+    Executes sql code in a specific database
+
+    :param sql: the sql code to be executed. (templated)
+    :type sql: Can receive a str representing a sql statement or a reference to a template file.
+               Template references are recognized by a str ending in '.sql'.
+               The SQL query is expected to return a Boolean (True/False), an integer (0 = False, otherwise True)
+               or a string (true/y/yes/1/on/false/n/no/0/off).
+    :param follow_task_ids_if_true: task id or task ids to follow if the query returns true
+    :type follow_task_ids_if_true: str or list
+    :param follow_task_ids_if_false: task id or task ids to follow if the query returns false
+    :type follow_task_ids_if_false: str or list
+    :param conn_id: reference to a specific database
+    :type conn_id: str
+    :param database: name of the database, which overrides the one defined in the connection
+    :param parameters: (optional) the parameters to render the SQL query with.
+    :type parameters: mapping or iterable
+    """
+
+    template_fields = ("sql",)
+    template_ext = (".sql",)
+    ui_color = "#a22034"
+    ui_fgcolor = "#F7F7F7"
+
+    @apply_defaults
+    def __init__(
+        self,
+        sql: str,
+        follow_task_ids_if_true: List[str],
+        follow_task_ids_if_false: List[str],
+        conn_id: str = "default_conn_id",
+        database: Optional[str] = None,
+        parameters: Optional[Union[Mapping, Iterable]] = None,
+        *args,
+        **kwargs,
+    ) -> None:
+        super().__init__(*args, **kwargs)
+        self.conn_id = conn_id
+        self.sql = sql
+        self.parameters = parameters
+        self.follow_task_ids_if_true = follow_task_ids_if_true
+        self.follow_task_ids_if_false = follow_task_ids_if_false
+        self.database = database
+        self._hook = None
+
+    def _get_hook(self):
+        self.log.debug("Get connection for %s", self.conn_id)
+        conn = BaseHook.get_connection(self.conn_id)
+
+        allowed_conn_type = {
+            "google_cloud_platform",
+            "jdbc",
+            "mssql",
+            "mysql",
+            "odbc",
+            "oracle",
+            "postgres",
+            "presto",
+            "sqlite",
+            "vertica",
+        }
+        if conn.conn_type not in allowed_conn_type:
+            raise AirflowException(
+                "The connection type is not supported by BranchSqlOperator. "
+                + "Supported connection types: {}".format(list(allowed_conn_type))
+            )
+
+        if not self._hook:
+            self._hook = conn.get_hook()
+            if self.database:
+                self._hook.schema = self.database
+
+        return self._hook
+
+    def execute(self, context: Dict):
+        # get supported hook
+        self._hook = self._get_hook()
+

Review comment:
       Thank you so much for reviewing this PR. Yes, I will rebase my PR. Thanks!







[GitHub] [airflow] potiuk commented on pull request #8942: #8525 Add SQL Branch Operator

Posted by GitBox <gi...@apache.org>.
potiuk commented on pull request #8942:
URL: https://github.com/apache/airflow/pull/8942#issuecomment-637023708


   Nice, @samuelkhtu!





[GitHub] [airflow] samuelkhtu commented on a change in pull request #8942: #8525 Add SQL Branch Operator

Posted by GitBox <gi...@apache.org>.
samuelkhtu commented on a change in pull request #8942:
URL: https://github.com/apache/airflow/pull/8942#discussion_r433387332



##########
File path: tests/operators/test_sql_branch_operator.py
##########
@@ -0,0 +1,479 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import datetime
+import unittest
+from unittest import mock
+
+import pytest
+
+from airflow.exceptions import AirflowException
+from airflow.models import DAG, DagRun, TaskInstance as TI
+from airflow.operators.dummy_operator import DummyOperator
+from airflow.operators.sql_branch_operator import BranchSqlOperator
+from airflow.utils import timezone
+from airflow.utils.session import create_session
+from airflow.utils.state import State
+from tests.providers.apache.hive import TestHiveEnvironment
+
+TEST_DAG_ID = "unit_test_sql_dag"
+
+DEFAULT_DATE = timezone.datetime(2016, 1, 1)
+INTERVAL = datetime.timedelta(hours=12)
+
+SUPPORTED_TRUE_VALUES = [
+    ["True"],
+    ["true"],
+    ["1"],
+    ["on"],
+    [1],
+    True,
+    "true",
+    "1",
+    "on",
+    1,
+]
+SUPPORTED_FALSE_VALUES = [
+    ["False"],
+    ["false"],
+    ["0"],
+    ["off"],
+    [0],
+    False,
+    "false",
+    "0",
+    "off",
+    0,
+]
+
+
+class TestSqlBranch(TestHiveEnvironment, unittest.TestCase):
+    """
+    Test for SQL Branch Operator
+    """
+
+    @classmethod
+    def setUpClass(cls):
+        super().setUpClass()
+
+        with create_session() as session:
+            session.query(DagRun).delete()
+            session.query(TI).delete()
+
+    def setUp(self):
+        super().setUp()
+        self.dag = DAG(
+            "sql_branch_operator_test",
+            default_args={"owner": "airflow", "start_date": DEFAULT_DATE},
+            schedule_interval=INTERVAL,
+        )
+        # self.dag = DAG(TEST_DAG_ID, default_args=args)

Review comment:
       Comment removed. I also removed unused TEST_DAG_ID constant from the test. Thanks.







[GitHub] [airflow] samuelkhtu edited a comment on pull request #8942: #8525 Add SQL Branch Operator

Posted by GitBox <gi...@apache.org>.
samuelkhtu edited a comment on pull request #8942:
URL: https://github.com/apache/airflow/pull/8942#issuecomment-633279449


   Hi @eladkal, thank you again for the review. The latest commit includes changes for the last set of comments. Please take a look when you have a chance. Thanks.





[GitHub] [airflow] samuelkhtu commented on a change in pull request #8942: #8525 Add SQL Branch Operator

Posted by GitBox <gi...@apache.org>.
samuelkhtu commented on a change in pull request #8942:
URL: https://github.com/apache/airflow/pull/8942#discussion_r429583242



##########
File path: airflow/operators/sql_branch_operator.py
##########
@@ -0,0 +1,175 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from distutils.util import strtobool
+from typing import Dict, Iterable, List, Mapping, Optional, Union
+
+from airflow.exceptions import AirflowException
+from airflow.hooks.base_hook import BaseHook
+from airflow.models import BaseOperator, SkipMixin
+from airflow.utils.decorators import apply_defaults
+
+
+class BranchSqlOperator(BaseOperator, SkipMixin):
+    """
+    Executes sql code in a specific database
+
+    :param sql: the sql code to be executed. (templated)
+    :type sql: Can receive a str representing a sql statement or a reference to a template file.
+               Template references are recognized by a str ending in '.sql'.
+               The SQL query is expected to return a Boolean (True/False), an integer (0 = False, otherwise True)
+               or a string (true/y/yes/1/on/false/n/no/0/off).
+
+    :param follow_task_ids_if_true: task id or task ids to follow if the query returns true
+    :type follow_task_ids_if_true: str or list
+
+    :param follow_task_ids_if_false: task id or task ids to follow if the query returns false
+    :type follow_task_ids_if_false: str or list
+
+    :param conn_id: reference to a specific database
+    :type conn_id: str
+
+    :param database: name of the database, which overrides the one defined in the connection
+
+    :param parameters: (optional) the parameters to render the SQL query with.
+    :type parameters: mapping or iterable
+
+
+    """
+
+    template_fields = ("sql",)
+    template_ext = (".sql",)
+    ui_color = "#a22034"
+    ui_fgcolor = "#F7F7F7"
+
+    @apply_defaults
+    def __init__(
+        self,
+        sql: str,
+        follow_task_ids_if_true: List[str],
+        follow_task_ids_if_false: List[str],
+        conn_id: str = "default_conn_id",
+        database: Optional[str] = None,
+        parameters: Optional[Union[Mapping, Iterable]] = None,
+        *args,
+        **kwargs,
+    ) -> None:
+        super().__init__(*args, **kwargs)
+        self.conn_id = conn_id
+        self.sql = sql
+        self.parameters = parameters
+        self.follow_task_ids_if_true = follow_task_ids_if_true
+        self.follow_task_ids_if_false = follow_task_ids_if_false
+        self.database = database
+        self._hook = None
+
+    def _get_hook(self):
+        self.log.debug("Get connection for %s", self.conn_id)
+        conn = BaseHook.get_connection(self.conn_id)
+
+        allowed_conn_type = {
+            "google_cloud_platform",
+            "jdbc",
+            "mssql",
+            "mysql",
+            "odbc",
+            "oracle",
+            "postgres",
+            "presto",
+            "sqlite",
+            "vertica",
+        }
+        if conn.conn_type not in allowed_conn_type:
+            raise AirflowException(
+                "The connection type is not supported by BranchSqlOperator. "
+                + "Supported connection types: {}".format(list(allowed_conn_type))
+            )
+
+        if not self._hook:
+            self._hook = conn.get_hook()
+            if self.database:
+                self._hook.schema = self.database
+
+        return self._hook
+
+    def execute(self, context: Dict):
+        # get supported hook
+        self._hook = self._get_hook()
+
+        if self._hook is None:
+            raise AirflowException(
+                "Failed to establish connection to '%s'" % self.conn_id
+            )
+
+        if self.follow_task_ids_if_true is None:
+            raise AirflowException(
+                "Expected task id or task ids assigned to follow_task_ids_if_true"
+            )
+
+        if self.follow_task_ids_if_false is None:
+            raise AirflowException(
+                "Expected task id or task ids assigned to follow_task_ids_if_false"
+            )
+
+        self.log.info(
+            "Executing: %s (with parameters %s) with connection: %s",
+            self.sql,
+            self.parameters,
+            self._hook,
+        )
+        record = self._hook.get_first(self.sql, self.parameters)
+        if not record:
+            raise AirflowException(
+                "No rows returned from sql query. Operator expected True or False return value."
+            )
+
+        if isinstance(record, list):
+            query_result = record[0][0]

Review comment:
       @eladkal Thank you for reviewing again. This was unexpected. I found an issue with the unit test and fixed it as well. The PR should be ready for review again.
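       One way to turn the first result row into a branch decision is sketched below. It assumes the row is a sequence whose first column holds the value (as a DB-API `fetchone()` row would be) and re-implements the truth table from the PR description instead of calling `distutils.util.strtobool`, since `distutils` is deprecated and was removed in Python 3.12. `to_bool` and `choose_branch` are hypothetical names, and `ValueError` stands in for `AirflowException`.

```python
from typing import Any, List, Optional, Sequence, Union

# Accepted string spellings, taken from the PR description.
TRUE_STRINGS = {"true", "y", "yes", "1", "on"}
FALSE_STRINGS = {"false", "n", "no", "0", "off"}


def to_bool(value: Any) -> bool:
    """Hypothetical helper: map a single query result value to True/False."""
    # bool must be checked before int, because bool is a subclass of int.
    if isinstance(value, bool):
        return value
    if isinstance(value, int):
        return value != 0  # 0 = False, any other integer = True
    if isinstance(value, str):
        lowered = value.strip().lower()
        if lowered in TRUE_STRINGS:
            return True
        if lowered in FALSE_STRINGS:
            return False
    raise ValueError("Unexpected query result: {!r}".format(value))


def choose_branch(
    row: Optional[Sequence[Any]],
    if_true: Union[str, List[str]],
    if_false: Union[str, List[str]],
) -> List[str]:
    """Pick the task ids to follow from the first column of a result row."""
    if not row:
        raise ValueError("No rows returned from sql query")
    follow = if_true if to_bool(row[0]) else if_false
    # Accept a single task id or a list of ids, as the docstring allows.
    return [follow] if isinstance(follow, str) else list(follow)


print(choose_branch(("yes",), "branch_a", ["branch_b"]))  # ['branch_a']
print(choose_branch((0,), "branch_a", ["branch_b"]))  # ['branch_b']
```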







[GitHub] [airflow] eladkal commented on a change in pull request #8942: #8525 Add SQL Branch Operator

Posted by GitBox <gi...@apache.org>.
eladkal commented on a change in pull request #8942:
URL: https://github.com/apache/airflow/pull/8942#discussion_r429615642



##########
File path: airflow/operators/sql_branch_operator.py
##########
@@ -0,0 +1,178 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from distutils.util import strtobool
+from typing import Dict, Iterable, List, Mapping, Optional, Union
+
+from airflow.exceptions import AirflowException
+from airflow.hooks.base_hook import BaseHook
+from airflow.models import BaseOperator, SkipMixin
+from airflow.utils.decorators import apply_defaults
+
+
+class BranchSqlOperator(BaseOperator, SkipMixin):
+    """
+    Executes sql code in a specific database
+
+    :param sql: the sql code to be executed. (templated)
+    :type sql: Can receive a str representing a sql statement or a reference to a template file.
+               Template references are recognized by a str ending in '.sql'.
+               The SQL query is expected to return a Boolean (True/False), an integer (0 = False, otherwise True)
+               or a string (true/y/yes/1/on/false/n/no/0/off).
+
+    :param follow_task_ids_if_true: task id or task ids to follow if the query returns true
+    :type follow_task_ids_if_true: str or list
+
+    :param follow_task_ids_if_false: task id or task ids to follow if the query returns false
+    :type follow_task_ids_if_false: str or list
+
+    :param conn_id: reference to a specific database
+    :type conn_id: str
+
+    :param database: name of the database, which overrides the one defined in the connection
+
+    :param parameters: (optional) the parameters to render the SQL query with.
+    :type parameters: mapping or iterable
+
+
+    """
+
+    template_fields = ("sql",)
+    template_ext = (".sql",)
+    ui_color = "#a22034"
+    ui_fgcolor = "#F7F7F7"
+
+    @apply_defaults
+    def __init__(
+        self,
+        sql: str,
+        follow_task_ids_if_true: List[str],
+        follow_task_ids_if_false: List[str],
+        conn_id: str = "default_conn_id",
+        database: Optional[str] = None,
+        parameters: Optional[Union[Mapping, Iterable]] = None,
+        *args,
+        **kwargs,
+    ) -> None:
+        super().__init__(*args, **kwargs)
+        self.conn_id = conn_id
+        self.sql = sql
+        self.parameters = parameters
+        self.follow_task_ids_if_true = follow_task_ids_if_true
+        self.follow_task_ids_if_false = follow_task_ids_if_false
+        self.database = database
+        self._hook = None
+
+    def _get_hook(self):
+        self.log.debug("Get connection for %s", self.conn_id)
+        conn = BaseHook.get_connection(self.conn_id)
+
+        allowed_conn_type = {
+            "google_cloud_platform",
+            "jdbc",
+            "mssql",
+            "mysql",
+            "odbc",
+            "oracle",
+            "postgres",
+            "presto",
+            "sqlite",
+            "vertica",
+        }
+        if conn.conn_type not in allowed_conn_type:
+            raise AirflowException(
+                "The connection type is not supported by BranchSqlOperator. "
+                + "Supported connection types: {}".format(list(allowed_conn_type))
+            )
+
+        if not self._hook:
+            self._hook = conn.get_hook()
+            if self.database:
+                self._hook.schema = self.database
+
+        return self._hook
+
+    def execute(self, context: Dict):
+        # get supported hook
+        self._hook = self._get_hook()
+
+        if self._hook is None:
+            raise AirflowException(
+                "Failed to establish connection to '%s'" % self.conn_id
+            )
+
+        if self.follow_task_ids_if_true is None:
+            raise AirflowException(
+                "Expected task id or task ids assigned to follow_task_ids_if_true"
+            )
+
+        if self.follow_task_ids_if_false is None:
+            raise AirflowException(
+                "Expected task id or task ids assigned to follow_task_ids_if_false"
+            )

Review comment:
       If you check for mandatory arguments, why not also check the `sql` param?







[GitHub] [airflow] samuelkhtu commented on a change in pull request #8942: #8525 Add SQL Branch Operator

Posted by GitBox <gi...@apache.org>.
samuelkhtu commented on a change in pull request #8942:
URL: https://github.com/apache/airflow/pull/8942#discussion_r432887517



##########
File path: airflow/operators/sql_branch_operator.py
##########
@@ -0,0 +1,174 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from distutils.util import strtobool
+from typing import Dict, Iterable, List, Mapping, Optional, Union
+
+from airflow.exceptions import AirflowException
+from airflow.hooks.base_hook import BaseHook
+from airflow.models import BaseOperator, SkipMixin
+from airflow.utils.decorators import apply_defaults
+
+
+class BranchSqlOperator(BaseOperator, SkipMixin):
+    """
+    Executes sql code in a specific database
+
+    :param sql: the sql code to be executed. (templated)
+    :type sql: Can receive a str representing a sql statement or reference to a template file.
+               Template reference are recognized by str ending in '.sql'.
+               Expected SQL query to return Boolean (True/False), integer (0 = False, Otherwise = 1)
+               or string (true/y/yes/1/on/false/n/no/0/off).
+    :param follow_task_ids_if_true: task id or task ids to follow if query returns true
+    :type follow_task_ids_if_true: str or list
+    :param follow_task_ids_if_false: task id or task ids to follow if query returns false
+    :type follow_task_ids_if_false: str or list
+    :param conn_id: reference to a specific database
+    :type conn_id: str
+    :param database: name of database which overwrite defined one in connection
+    :param parameters: (optional) the parameters to render the SQL query with.
+    :type parameters: mapping or iterable
+    """
+
+    template_fields = ("sql",)
+    template_ext = (".sql",)
+    ui_color = "#a22034"
+    ui_fgcolor = "#F7F7F7"
+
+    @apply_defaults
+    def __init__(
+        self,
+        sql: str,
+        follow_task_ids_if_true: List[str],
+        follow_task_ids_if_false: List[str],
+        conn_id: str = "default_conn_id",
+        database: Optional[str] = None,
+        parameters: Optional[Union[Mapping, Iterable]] = None,
+        *args,
+        **kwargs,
+    ) -> None:
+        super().__init__(*args, **kwargs)
+        self.conn_id = conn_id
+        self.sql = sql
+        self.parameters = parameters
+        self.follow_task_ids_if_true = follow_task_ids_if_true
+        self.follow_task_ids_if_false = follow_task_ids_if_false
+        self.database = database
+        self._hook = None
+
+    def _get_hook(self):
+        self.log.debug("Get connection for %s", self.conn_id)
+        conn = BaseHook.get_connection(self.conn_id)
+
+        allowed_conn_type = {
+            "google_cloud_platform",
+            "jdbc",
+            "mssql",
+            "mysql",
+            "odbc",
+            "oracle",
+            "postgres",
+            "presto",
+            "sqlite",
+            "vertica",
+        }
+        if conn.conn_type not in allowed_conn_type:
+            raise AirflowException(
+                "The connection type is not supported by BranchSqlOperator. "
+                + "Supported connection types: {}".format(list(allowed_conn_type))
+            )
+
+        if not self._hook:
+            self._hook = conn.get_hook()
+            if self.database:
+                self._hook.schema = self.database
+
+        return self._hook
+
+    def execute(self, context: Dict):
+        # get supported hook
+        self._hook = self._get_hook()
+

Review comment:
       Hi @potiuk, I haven't run the Breeze integration tests yet; I have only run the unit tests locally, so I think this is a mocking issue.
   
   ```shell
   pytest ../airflow/tests/operators/test_sql_branch_operator.py
   ```
   In the unit test, I used a mock connection (similar to the existing unit test `test_sql_sensor`):
   ```python
   mock_hook.get_connection("mysql_default").conn_type = "mysql"
   ```
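Why the hook ends up as a `MagicMock` can be reproduced with `unittest.mock` alone. This sketch assumes, as the snippet suggests, that the test patches `BaseHook` with a `MagicMock` named `mock_hook`. Setting `conn_type` on the returned connection does not configure `get_hook()`, so the operator receives another auto-created mock.

```python
from unittest import mock

# Assumption: stands in for the patched BaseHook in the unit test.
mock_hook = mock.MagicMock()
mock_hook.get_connection("mysql_default").conn_type = "mysql"

# A MagicMock method returns the same child mock whatever the arguments,
# so the operator's own get_connection() call sees the conn_type set above.
conn = mock_hook.get_connection("mysql_default")
print(conn.conn_type)  # mysql

# get_hook() was never configured, so it returns yet another MagicMock,
# which is what the isinstance(self._hook, DbApiHook) check then rejects.
hook = conn.get_hook()
print(type(hook).__name__)  # MagicMock
```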
   As a result, during my local pytest run, the new `isinstance(self._hook, DbApiHook)` check failed and raised this error:
   
   ```Python
   def execute(self, context: Dict):
           # get supported hook
           self._hook = self._get_hook()
       
           if not isinstance(self._hook, DbApiHook):
               raise AirflowException(
   >               "Unexpected type returned '%s' expected DbApiHook" % type(self._hook)
               )
   E           airflow.exceptions.AirflowException: Unexpected type returned '<class 'unittest.mock.MagicMock'>' expected DbApiHook
   ```
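One possible way to satisfy the new check in the unit test, offered as a suggestion rather than what the PR does, is to give the mocked hook a `spec`. A mock created with `spec=SomeClass` reports that class as its `__class__`, so it passes `isinstance()`. The `DbApiHook` class below is a local stand-in so the sketch runs on its own, and its `get_first` method is included only as an example attribute. In the real test you would import Airflow's `DbApiHook` instead.

```python
from unittest import mock


class DbApiHook:
    """Local stand-in for Airflow's DbApiHook (assumption: only its type and attribute names matter here)."""

    def get_first(self, sql, parameters=None):
        raise NotImplementedError


# A bare MagicMock fails the operator's isinstance() guard ...
plain_hook = mock.MagicMock()
print(isinstance(plain_hook, DbApiHook))  # False

# ... whereas a mock created with spec=DbApiHook passes it and still records calls.
spec_hook = mock.MagicMock(spec=DbApiHook)
spec_hook.get_first.return_value = [True]
print(isinstance(spec_hook, DbApiHook))  # True
print(spec_hook.get_first("SELECT 1"))  # [True]

# spec also rejects attributes the real class does not have.
try:
    spec_hook.not_a_dbapi_method
except AttributeError:
    print("unknown attribute rejected")
```

In the operator test this would amount to something like `mock_hook.get_connection.return_value.get_hook.return_value = mock.MagicMock(spec=DbApiHook)`. `mock.create_autospec(DbApiHook, instance=True)` is a stricter alternative that also checks call signatures.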



