You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by GitBox <gi...@apache.org> on 2018/12/18 04:21:39 UTC

[GitHub] stale[bot] closed pull request #3115: [AIRFLOW-2193] Add ROperator for using R

stale[bot] closed pull request #3115: [AIRFLOW-2193] Add ROperator for using R
URL: https://github.com/apache/incubator-airflow/pull/3115
 
 
   

This pull request comes from a forked repository. Because GitHub may not
show the original diff once a pull request is merged or closed, it is
reproduced below for the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (GitHub would not otherwise display it):

diff --git a/airflow/contrib/operators/r_operator.py b/airflow/contrib/operators/r_operator.py
new file mode 100644
index 0000000000..9974061892
--- /dev/null
+++ b/airflow/contrib/operators/r_operator.py
@@ -0,0 +1,129 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from builtins import bytes
+import os
+import signal
+from subprocess import PIPE, STDOUT, Popen
+from tempfile import NamedTemporaryFile
+
+from airflow.exceptions import AirflowException
+from airflow.models import BaseOperator
+from airflow.utils.decorators import apply_defaults
+from airflow.utils.file import TemporaryDirectory
+
+
+class ROperator(BaseOperator):
+    """
+    Execute an R script or command with ``Rscript`` in a subprocess.
+
+    Running R in a separate process means a failing script fails the
+    task (via Rscript's exit code) and the task can be killed cleanly.
+
+    :param r_command: The command or a reference to an R script (must have
+        a '.r' or '.R' extension) to be executed (templated)
+    :type r_command: string
+    :param xcom_push: If xcom_push is True (default: False), the last line
+        of output (stdout, with stderr merged in) is pushed to an XCom
+        (key 'return_value') when the R command completes.
+    :type xcom_push: bool
+    :param output_encoding: encoding output from R (default: 'utf-8')
+    :type output_encoding: string
+    :param rscript_bin: the Rscript executable to run, either a path or a
+        name looked up on PATH (default: 'Rscript')
+    :type rscript_bin: string
+    :param env: If env is not None, it must be a mapping that defines the
+        environment variables for the R process; it replaces, rather than
+        extends, the worker's environment. (templated)
+    :type env: dict
+
+    """
+
+    template_fields = ('r_command', 'env')
+    template_ext = ('.r', '.R')
+    ui_color = '#C8D5E6'
+
+    @apply_defaults
+    def __init__(
+            self,
+            r_command,
+            xcom_push=False,
+            output_encoding='utf-8',
+            rscript_bin='Rscript',
+            env=None,
+            *args, **kwargs):
+
+        super(ROperator, self).__init__(*args, **kwargs)
+        self.r_command = r_command
+        # NOTE(review): this instance attribute shadows the inherited
+        # BaseOperator.xcom_push() method; the name is kept because callers
+        # (and the tests) toggle it directly.
+        self.xcom_push = xcom_push
+        self.output_encoding = output_encoding
+        self.rscript_bin = rscript_bin
+        self.env = env
+        # Handle on the running Rscript process, used by on_kill().
+        self.sp = None
+
+    def execute(self, context):
+        """
+        Run the R command with Rscript in a temporary directory.
+
+        The command is written to a temporary script file; combined
+        stdout/stderr is streamed to the task log line by line.
+
+        :return: the last line of output if ``xcom_push`` is set, else None
+        :raises AirflowException: if Rscript exits with a non-zero code
+        """
+        with TemporaryDirectory(prefix='airflowtmp') as tmp_dir:
+            with NamedTemporaryFile(dir=tmp_dir, prefix=self.task_id) as f:
+
+                f.write(bytes(self.r_command, 'utf_8'))
+                f.flush()
+                fname = f.name
+                script_location = os.path.abspath(fname)
+
+                self.log.info("Temporary script location: %s", script_location)
+                self.log.info("Running command(s):\n%s", self.r_command)
+
+                # Start a new session so on_kill() can signal the whole
+                # process group, including any children R spawns.
+                # A missing rscript_bin raises OSError/FileNotFoundError here.
+                self.sp = Popen(
+                    [self.rscript_bin, fname],
+                    stdout=PIPE, stderr=STDOUT,
+                    cwd=tmp_dir, env=self.env,
+                    preexec_fn=os.setsid)
+
+                self.log.info("Output:")
+                line = ''
+                for raw_line in iter(self.sp.stdout.readline, b''):
+                    line = raw_line.decode(self.output_encoding).rstrip()
+                    self.log.info(line)
+                self.sp.wait()
+                self.sp.stdout.close()
+                self.log.info("Command exited with return code %s",
+                              self.sp.returncode)
+
+                if self.sp.returncode:
+                    raise AirflowException("R command failed")
+
+        if self.xcom_push:
+            return line
+
+    def on_kill(self):
+        """Terminate the Rscript process group if it is still running."""
+        if self.sp is not None and self.sp.poll() is None:
+            self.log.info('Sending SIGTERM signal to Rscript process group')
+            os.killpg(os.getpgid(self.sp.pid), signal.SIGTERM)
diff --git a/tests/contrib/operators/test_r_operator.py b/tests/contrib/operators/test_r_operator.py
new file mode 100644
index 0000000000..13205d9c1e
--- /dev/null
+++ b/tests/contrib/operators/test_r_operator.py
@@ -0,0 +1,185 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function, unicode_literals
+
+import os
+import unittest
+
+try:
+    from shutil import which
+except ImportError:  # py2
+    from distutils.spawn import find_executable as which
+
+from airflow import configuration, DAG
+from airflow.contrib.operators.r_operator import ROperator
+from airflow.models import TaskInstance
+from airflow.utils import timezone
+
+
+DEFAULT_DATE = timezone.datetime(2016, 1, 1)
+
+# Tests that actually run an R script need a working Rscript; skip them
+# (rather than fail) on machines where R is not installed.
+requires_rscript = unittest.skipUnless(
+    which('Rscript'), 'Rscript is not installed')
+
+
+class ROperatorTest(unittest.TestCase):
+    """Test the ROperator"""
+
+    def setUp(self):
+        super(ROperatorTest, self).setUp()
+        configuration.load_test_config()
+        self.dag = DAG(
+            'test_roperator_dag',
+            default_args={
+                'owner': 'airflow',
+                'start_date': DEFAULT_DATE
+            },
+            schedule_interval='@once'
+        )
+
+        self.xcom_test_str = 'Hello Airflow'
+        self.task_xcom = ROperator(
+            task_id='test_r_xcom',
+            r_command='cat("Ignored Line\n{}")'.format(self.xcom_test_str),
+            xcom_push=True,
+            dag=self.dag
+        )
+
+    def test_invalid_rscript_bin(self):
+        """Fail if invalid rscript_bin supplied"""
+
+        try:
+            expected_error = FileNotFoundError
+        except NameError:
+            # py2
+            expected_error = OSError
+
+        task = ROperator(
+            task_id='test_r_bad_rscript',
+            r_command='print(Sys.Date())',
+            rscript_bin='somebadrscript',
+            dag=self.dag
+        )
+
+        self.assertIsNotNone(task)
+
+        ti = TaskInstance(task=task, execution_date=timezone.utcnow())
+
+        with self.assertRaises(expected_error):
+            ti.run()
+
+    @requires_rscript
+    def test_xcom_output(self):
+        """Test whether Xcom output is produced using last line"""
+
+        self.task_xcom.xcom_push = True
+
+        ti = TaskInstance(
+            task=self.task_xcom,
+            execution_date=timezone.utcnow()
+        )
+
+        ti.run()
+        self.assertIsNotNone(ti.duration)
+
+        self.assertEqual(
+            ti.xcom_pull(task_ids=self.task_xcom.task_id, key='return_value'),
+            self.xcom_test_str
+        )
+
+    @requires_rscript
+    def test_xcom_none(self):
+        """Test whether no Xcom output is produced when push=False"""
+
+        self.task_xcom.xcom_push = False
+
+        ti = TaskInstance(
+            task=self.task_xcom,
+            execution_date=timezone.utcnow(),
+        )
+
+        ti.run()
+        self.assertIsNotNone(ti.duration)
+        self.assertIsNone(ti.xcom_pull(task_ids=self.task_xcom.task_id))
+
+    @requires_rscript
+    def test_env_vars(self):
+        """Test whether environment is passed properly"""
+
+        test_var = 'TEST_VALUE_X'
+        test_str = 'Hello Airflow'
+
+        task = ROperator(
+            task_id='test_env_vars',
+            r_command='cat(Sys.getenv("{}"))'.format(test_var),
+            env={test_var: test_str, "PATH": os.environ['PATH']},
+            xcom_push=True,
+            dag=self.dag
+        )
+
+        ti = TaskInstance(task=task, execution_date=timezone.utcnow())
+
+        ti.run()
+        self.assertIsNotNone(ti.duration)
+
+        self.assertEqual(
+            ti.xcom_pull(task_ids=task.task_id, key='return_value'),
+            test_str
+        )
+
+    def test_command_template(self):
+        """Test whether templating works properly with r_command"""
+
+        task = ROperator(
+            task_id='test_cmd_template',
+            r_command='cat("{{ ds }}")',
+            dag=self.dag
+        )
+
+        ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
+        ti.render_templates()
+
+        self.assertEqual(
+            ti.task.r_command,
+            'cat("{}")'.format(DEFAULT_DATE.date().isoformat())
+        )
+
+    @requires_rscript
+    def test_env_templates(self):
+        """Test whether templating works properly with env vars"""
+
+        test_var = 'TEST_CURR_DATE'
+
+        task = ROperator(
+            task_id='test_env_template',
+            r_command='cat(Sys.getenv("{}"))'.format(test_var),
+            env={test_var: "{{ ds }}"},
+            xcom_push=True,
+            dag=self.dag
+        )
+
+        ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
+        ti.run()
+
+        self.assertEqual(
+            ti.xcom_pull(task_ids=task.task_id, key='return_value'),
+            DEFAULT_DATE.date().isoformat()
+        )
+
+
+if __name__ == '__main__':
+    unittest.main()


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log in to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services