You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink is not preserved in this plain-text rendering).
Posted to commits@airflow.apache.org by po...@apache.org on 2021/12/04 11:24:14 UTC
[airflow] branch main updated: Add support to specify kernel name in PapermillOperator (#20035)
This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new d3f4456 Add support to specify kernel name in PapermillOperator (#20035)
d3f4456 is described below
commit d3f445636394743b9298cae99c174cb4ac1fc30c
Author: Le Minh Thong <55...@users.noreply.github.com>
AuthorDate: Sat Dec 4 18:23:35 2021 +0700
Add support to specify kernel name in PapermillOperator (#20035)
---
airflow/providers/papermill/operators/papermill.py | 8 +++++++-
tests/providers/papermill/operators/test_papermill.py | 13 +++++++++++--
2 files changed, 18 insertions(+), 3 deletions(-)
diff --git a/airflow/providers/papermill/operators/papermill.py b/airflow/providers/papermill/operators/papermill.py
index ebbda66..7c2e8d3 100644
--- a/airflow/providers/papermill/operators/papermill.py
+++ b/airflow/providers/papermill/operators/papermill.py
@@ -44,11 +44,14 @@ class PapermillOperator(BaseOperator):
:type output_nb: str
:param parameters: the notebook parameters to set
:type parameters: dict
+ :param kernel_name: (optional) name of kernel to execute the notebook against
+ (ignores kernel name in the notebook document metadata)
+ :type kernel_name: str
"""
supports_lineage = True
- template_fields = ('input_nb', 'output_nb', 'parameters')
+ template_fields = ('input_nb', 'output_nb', 'parameters', 'kernel_name')
def __init__(
self,
@@ -56,6 +59,7 @@ class PapermillOperator(BaseOperator):
input_nb: Optional[str] = None,
output_nb: Optional[str] = None,
parameters: Optional[Dict] = None,
+ kernel_name: Optional[str] = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -63,6 +67,7 @@ class PapermillOperator(BaseOperator):
self.input_nb = input_nb
self.output_nb = output_nb
self.parameters = parameters
+ self.kernel_name = kernel_name
if input_nb:
self.inlets.append(NoteBook(url=input_nb, parameters=self.parameters))
if output_nb:
@@ -79,4 +84,5 @@ class PapermillOperator(BaseOperator):
parameters=item.parameters,
progress_bar=False,
report_mode=True,
+ kernel_name=self.kernel_name,
)
diff --git a/tests/providers/papermill/operators/test_papermill.py b/tests/providers/papermill/operators/test_papermill.py
index 7185932..24d44e3 100644
--- a/tests/providers/papermill/operators/test_papermill.py
+++ b/tests/providers/papermill/operators/test_papermill.py
@@ -30,6 +30,7 @@ class TestPapermillOperator(unittest.TestCase):
def test_execute(self, mock_papermill):
in_nb = "/tmp/does_not_exist"
out_nb = "/tmp/will_not_exist"
+ kernel_name = "python3"
parameters = {"msg": "hello_world", "train": 1}
op = PapermillOperator(
@@ -37,14 +38,20 @@ class TestPapermillOperator(unittest.TestCase):
output_nb=out_nb,
parameters=parameters,
task_id="papermill_operator_test",
+ kernel_name=kernel_name,
dag=None,
)
- op.pre_execute(context={}) # make sure to have the inlets
+ op.pre_execute(context={}) # Make sure to have the inlets
op.execute(context={})
mock_papermill.execute_notebook.assert_called_once_with(
- in_nb, out_nb, parameters=parameters, progress_bar=False, report_mode=True
+ in_nb,
+ out_nb,
+ parameters=parameters,
+ kernel_name=kernel_name,
+ progress_bar=False,
+ report_mode=True,
)
def test_render_template(self):
@@ -56,6 +63,7 @@ class TestPapermillOperator(unittest.TestCase):
input_nb="/tmp/{{ dag.dag_id }}.ipynb",
output_nb="/tmp/out-{{ dag.dag_id }}.ipynb",
parameters={"msgs": "dag id is {{ dag.dag_id }}!"},
+ kernel_name="python3",
dag=dag,
)
@@ -66,3 +74,4 @@ class TestPapermillOperator(unittest.TestCase):
assert "/tmp/test_render_template.ipynb" == getattr(operator, 'input_nb')
assert '/tmp/out-test_render_template.ipynb' == getattr(operator, 'output_nb')
assert {"msgs": "dag id is test_render_template!"} == getattr(operator, 'parameters')
+ assert "python3" == getattr(operator, 'kernel_name')