You are viewing a plain text version of this content. The canonical link for it was not preserved in this plain text copy.
Posted to commits@airflow.apache.org by po...@apache.org on 2021/12/04 11:24:14 UTC

[airflow] branch main updated: Add support to specify kernel name in PapermillOperator (#20035)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new d3f4456  Add support to specify kernel name in PapermillOperator (#20035)
d3f4456 is described below

commit d3f445636394743b9298cae99c174cb4ac1fc30c
Author: Le Minh Thong <55...@users.noreply.github.com>
AuthorDate: Sat Dec 4 18:23:35 2021 +0700

    Add support to specify kernel name in PapermillOperator (#20035)
---
 airflow/providers/papermill/operators/papermill.py    |  8 +++++++-
 tests/providers/papermill/operators/test_papermill.py | 13 +++++++++++--
 2 files changed, 18 insertions(+), 3 deletions(-)

diff --git a/airflow/providers/papermill/operators/papermill.py b/airflow/providers/papermill/operators/papermill.py
index ebbda66..7c2e8d3 100644
--- a/airflow/providers/papermill/operators/papermill.py
+++ b/airflow/providers/papermill/operators/papermill.py
@@ -44,11 +44,14 @@ class PapermillOperator(BaseOperator):
     :type output_nb: str
     :param parameters: the notebook parameters to set
     :type parameters: dict
+    :param kernel_name: (optional) name of kernel to execute the notebook against
+        (ignores kernel name in the notebook document metadata)
+    :type kernel_name: str
     """
 
     supports_lineage = True
 
-    template_fields = ('input_nb', 'output_nb', 'parameters')
+    template_fields = ('input_nb', 'output_nb', 'parameters', 'kernel_name')
 
     def __init__(
         self,
@@ -56,6 +59,7 @@ class PapermillOperator(BaseOperator):
         input_nb: Optional[str] = None,
         output_nb: Optional[str] = None,
         parameters: Optional[Dict] = None,
+        kernel_name: Optional[str] = None,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -63,6 +67,7 @@ class PapermillOperator(BaseOperator):
         self.input_nb = input_nb
         self.output_nb = output_nb
         self.parameters = parameters
+        self.kernel_name = kernel_name
         if input_nb:
             self.inlets.append(NoteBook(url=input_nb, parameters=self.parameters))
         if output_nb:
@@ -79,4 +84,5 @@ class PapermillOperator(BaseOperator):
                 parameters=item.parameters,
                 progress_bar=False,
                 report_mode=True,
+                kernel_name=self.kernel_name,
             )
diff --git a/tests/providers/papermill/operators/test_papermill.py b/tests/providers/papermill/operators/test_papermill.py
index 7185932..24d44e3 100644
--- a/tests/providers/papermill/operators/test_papermill.py
+++ b/tests/providers/papermill/operators/test_papermill.py
@@ -30,6 +30,7 @@ class TestPapermillOperator(unittest.TestCase):
     def test_execute(self, mock_papermill):
         in_nb = "/tmp/does_not_exist"
         out_nb = "/tmp/will_not_exist"
+        kernel_name = "python3"
         parameters = {"msg": "hello_world", "train": 1}
 
         op = PapermillOperator(
@@ -37,14 +38,20 @@ class TestPapermillOperator(unittest.TestCase):
             output_nb=out_nb,
             parameters=parameters,
             task_id="papermill_operator_test",
+            kernel_name=kernel_name,
             dag=None,
         )
 
-        op.pre_execute(context={})  # make sure to have the inlets
+        op.pre_execute(context={})  # Make sure to have the inlets
         op.execute(context={})
 
         mock_papermill.execute_notebook.assert_called_once_with(
-            in_nb, out_nb, parameters=parameters, progress_bar=False, report_mode=True
+            in_nb,
+            out_nb,
+            parameters=parameters,
+            kernel_name=kernel_name,
+            progress_bar=False,
+            report_mode=True,
         )
 
     def test_render_template(self):
@@ -56,6 +63,7 @@ class TestPapermillOperator(unittest.TestCase):
             input_nb="/tmp/{{ dag.dag_id }}.ipynb",
             output_nb="/tmp/out-{{ dag.dag_id }}.ipynb",
             parameters={"msgs": "dag id is {{ dag.dag_id }}!"},
+            kernel_name="python3",
             dag=dag,
         )
 
@@ -66,3 +74,4 @@ class TestPapermillOperator(unittest.TestCase):
         assert "/tmp/test_render_template.ipynb" == getattr(operator, 'input_nb')
         assert '/tmp/out-test_render_template.ipynb' == getattr(operator, 'output_nb')
         assert {"msgs": "dag id is test_render_template!"} == getattr(operator, 'parameters')
+        assert "python3" == getattr(operator, 'kernel_name')