You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by el...@apache.org on 2022/02/01 17:28:19 UTC

[airflow] branch main updated: Fixes Docker xcom functionality (#21175)

This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 2f4a3d4d Fixes Docker xcom functionality (#21175)
2f4a3d4d is described below

commit 2f4a3d4d4008a95fc36971802c514fef68e8a5d4
Author: D. Ferruzzi <fe...@amazon.com>
AuthorDate: Tue Feb 1 09:27:31 2022 -0800

    Fixes Docker xcom functionality (#21175)
    
    * Fixes Docker xcom functionality
---
 airflow/providers/docker/operators/docker.py    | 50 +++++++++++++------------
 tests/providers/docker/operators/test_docker.py | 33 ++++++++++++++--
 2 files changed, 57 insertions(+), 26 deletions(-)

diff --git a/airflow/providers/docker/operators/docker.py b/airflow/providers/docker/operators/docker.py
index 2126e16..ceb40f1 100644
--- a/airflow/providers/docker/operators/docker.py
+++ b/airflow/providers/docker/operators/docker.py
@@ -275,32 +275,40 @@ class DockerOperator(BaseOperator):
             working_dir=self.working_dir,
             tty=self.tty,
         )
-        lines = self.cli.attach(container=self.container['Id'], stdout=True, stderr=True, stream=True)
+        logstream = self.cli.attach(container=self.container['Id'], stdout=True, stderr=True, stream=True)
         try:
             self.cli.start(self.container['Id'])
 
-            line = ''
-            res_lines = []
-            return_value = None
-            for line in lines:
-                if hasattr(line, 'decode'):
+            log_lines = []
+            for log_chunk in logstream:
+                if hasattr(log_chunk, 'decode'):
                     # Note that lines returned can also be byte sequences so we have to handle decode here
-                    line = line.decode('utf-8')
-                line = line.strip()
-                res_lines.append(line)
-                self.log.info(line)
+                    log_chunk = log_chunk.decode('utf-8', errors='surrogateescape')
+                log_chunk = log_chunk.strip()
+                log_lines.append(log_chunk)
+                self.log.info("%s", log_chunk)
+
             result = self.cli.wait(self.container['Id'])
             if result['StatusCode'] != 0:
-                res_lines = "\n".join(res_lines)
-                raise AirflowException('docker container failed: ' + repr(result) + f"lines {res_lines}")
-            if self.retrieve_output and not return_value:
-                return_value = self._attempt_to_retrieve_result()
-            ret = None
+                joined_log_lines = "\n".join(log_lines)
+                raise AirflowException(f'Docker container failed: {repr(result)} lines {joined_log_lines}')
+
             if self.retrieve_output:
-                ret = return_value
+                return self._attempt_to_retrieve_result()
             elif self.do_xcom_push:
-                ret = self._get_return_value_from_logs(res_lines, line)
-            return ret
+                log_parameters = {
+                    'container': self.container['Id'],
+                    'stdout': True,
+                    'stderr': True,
+                    'stream': True,
+                }
+
+                return (
+                    self.cli.logs(**log_parameters)
+                    if self.xcom_all
+                    else self.cli.logs(**log_parameters, tail=1)
+                )
+            return None
         finally:
             if self.auto_remove:
                 self.cli.remove_container(self.container['Id'])
@@ -326,14 +334,10 @@ class DockerOperator(BaseOperator):
             return lib.loads(file.read())
 
         try:
-            return_value = copy_from_docker(self.container['Id'], self.retrieve_output_path)
-            return return_value
+            return copy_from_docker(self.container['Id'], self.retrieve_output_path)
         except APIError:
             return None
 
-    def _get_return_value_from_logs(self, res_lines, line):
-        return res_lines if self.xcom_all else line
-
     def execute(self, context: 'Context') -> Optional[str]:
         self.cli = self._get_cli()
         if not self.cli:
diff --git a/tests/providers/docker/operators/test_docker.py b/tests/providers/docker/operators/test_docker.py
index 58ddc2f..cda1cf2 100644
--- a/tests/providers/docker/operators/test_docker.py
+++ b/tests/providers/docker/operators/test_docker.py
@@ -47,10 +47,17 @@ class TestDockerOperator(unittest.TestCase):
         self.client_mock = mock.Mock(spec=APIClient)
         self.client_mock.create_container.return_value = {'Id': 'some_id'}
         self.client_mock.images.return_value = []
-        self.client_mock.attach.return_value = ['container log 1', 'container log 2']
         self.client_mock.pull.return_value = {"status": "pull log"}
         self.client_mock.wait.return_value = {"StatusCode": 0}
         self.client_mock.create_host_config.return_value = mock.Mock()
+        self.log_messages = ['container log 1', 'container log 2']
+        self.client_mock.attach.return_value = self.log_messages
+
+        # If logs() is called with tail then only return the last value, otherwise return the whole log.
+        self.client_mock.logs.side_effect = (
+            lambda **kwargs: self.log_messages[-kwargs['tail']] if 'tail' in kwargs else self.log_messages
+        )
+
         self.client_class_patcher = mock.patch(
             'airflow.providers.docker.operators.docker.APIClient',
             return_value=self.client_mock,
@@ -117,6 +124,9 @@ class TestDockerOperator(unittest.TestCase):
         self.client_mock.attach.assert_called_once_with(
             container='some_id', stdout=True, stderr=True, stream=True
         )
+        self.client_mock.logs.assert_called_once_with(
+            container='some_id', stdout=True, stderr=True, stream=True, tail=1
+        )
         self.client_mock.pull.assert_called_once_with('ubuntu:latest', stream=True, decode=True)
         self.client_mock.wait.assert_called_once_with('some_id')
         assert (
@@ -179,6 +189,9 @@ class TestDockerOperator(unittest.TestCase):
         self.client_mock.attach.assert_called_once_with(
             container='some_id', stdout=True, stderr=True, stream=True
         )
+        self.client_mock.logs.assert_called_once_with(
+            container='some_id', stdout=True, stderr=True, stream=True, tail=1
+        )
         self.client_mock.pull.assert_called_once_with('ubuntu:latest', stream=True, decode=True)
         self.client_mock.wait.assert_called_once_with('some_id')
         assert (
@@ -283,6 +296,9 @@ class TestDockerOperator(unittest.TestCase):
         self.client_mock.attach.assert_called_once_with(
             container='some_id', stdout=True, stderr=True, stream=True
         )
+        self.client_mock.logs.assert_called_once_with(
+            container='some_id', stdout=True, stderr=True, stream=True, tail=1
+        )
         self.client_mock.pull.assert_called_once_with('ubuntu:latest', stream=True, decode=True)
         self.client_mock.wait.assert_called_once_with('some_id')
         assert (
@@ -339,11 +355,22 @@ class TestDockerOperator(unittest.TestCase):
             print_exception_mock.assert_not_called()
 
     def test_execute_container_fails(self):
-        self.client_mock.wait.return_value = {"StatusCode": 1}
+        failed_msg = {'StatusCode': 1}
+        log_line = ['unicode container log 😁   ', b'byte string container log']
+        expected_message = 'Docker container failed: {failed_msg} lines {expected_log_output}'
+        self.client_mock.attach.return_value = log_line
+        self.client_mock.wait.return_value = failed_msg
+
         operator = DockerOperator(image='ubuntu', owner='unittest', task_id='unittest')
-        with pytest.raises(AirflowException):
+
+        with pytest.raises(AirflowException) as raised_exception:
             operator.execute(None)
 
+        assert str(raised_exception.value) == expected_message.format(
+            failed_msg=failed_msg,
+            expected_log_output=f'{log_line[0].strip()}\n{log_line[1].decode("utf-8")}',
+        )
+
     def test_auto_remove_container_fails(self):
         self.client_mock.wait.return_value = {"StatusCode": 1}
         operator = DockerOperator(image='ubuntu', owner='unittest', task_id='unittest', auto_remove=True)