You are viewing a plain-text version of this content. (The canonical link to the original message, shown as "here" in the HTML archive, is not preserved in this plain-text copy.)
Posted to commits@airflow.apache.org by ur...@apache.org on 2022/12/20 08:05:14 UTC

[airflow] branch main updated: Add link for EMR Steps Sensor logs (#28180)

This is an automated email from the ASF dual-hosted git repository.

uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new fefcb1d567 Add link for EMR Steps Sensor logs (#28180)
fefcb1d567 is described below

commit fefcb1d567d8d605f7ec9b7d408831d656736541
Author: Syed Hussaain <10...@users.noreply.github.com>
AuthorDate: Tue Dec 20 00:05:05 2022 -0800

    Add link for EMR Steps Sensor logs (#28180)
---
 airflow/providers/amazon/aws/hooks/s3.py           |  2 +-
 airflow/providers/amazon/aws/links/emr.py          |  8 +++++++
 airflow/providers/amazon/aws/sensors/emr.py        | 26 ++++++++++++++++------
 airflow/providers/amazon/provider.yaml             |  1 +
 tests/providers/amazon/aws/hooks/test_s3.py        |  8 +++++++
 .../providers/amazon/aws/sensors/test_emr_base.py  |  3 ++-
 .../amazon/aws/sensors/test_emr_job_flow.py        |  9 +++++---
 7 files changed, 45 insertions(+), 12 deletions(-)

diff --git a/airflow/providers/amazon/aws/hooks/s3.py b/airflow/providers/amazon/aws/hooks/s3.py
index e5b1bad580..1e2235f24a 100644
--- a/airflow/providers/amazon/aws/hooks/s3.py
+++ b/airflow/providers/amazon/aws/hooks/s3.py
@@ -152,7 +152,7 @@ class S3Hook(AwsBaseHook):
         :return: the parsed bucket name and key
         """
         format = s3url.split("//")
-        if format[0].lower() == "s3:":
+        if re.match(r"s3[na]?:", format[0], re.IGNORECASE):
             parsed_url = urlsplit(s3url)
             if not parsed_url.netloc:
                 raise AirflowException(f'Please provide a bucket name using a valid format: "{s3url}"')
diff --git a/airflow/providers/amazon/aws/links/emr.py b/airflow/providers/amazon/aws/links/emr.py
index aa739567fb..83e190f663 100644
--- a/airflow/providers/amazon/aws/links/emr.py
+++ b/airflow/providers/amazon/aws/links/emr.py
@@ -27,3 +27,11 @@ class EmrClusterLink(BaseAwsLink):
     format_str = (
         BASE_AWS_CONSOLE_LINK + "/elasticmapreduce/home?region={region_name}#cluster-details:{job_flow_id}"
     )
+
+
+class EmrLogsLink(BaseAwsLink):
+    """Helper class for constructing AWS EMR Logs Link"""
+
+    name = "EMR Cluster Logs"
+    key = "emr_logs"
+    format_str = BASE_AWS_CONSOLE_LINK + "/s3/buckets/{log_uri}?region={region_name}&prefix={job_flow_id}/"
diff --git a/airflow/providers/amazon/aws/sensors/emr.py b/airflow/providers/amazon/aws/sensors/emr.py
index 04e2897716..d1cd0949e0 100644
--- a/airflow/providers/amazon/aws/sensors/emr.py
+++ b/airflow/providers/amazon/aws/sensors/emr.py
@@ -21,6 +21,8 @@ from typing import TYPE_CHECKING, Any, Iterable, Sequence
 
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook
+from airflow.providers.amazon.aws.hooks.s3 import S3Hook
+from airflow.providers.amazon.aws.links.emr import EmrLogsLink
 from airflow.sensors.base import BaseSensorOperator, poke_mode_only
 
 if TYPE_CHECKING:
@@ -61,7 +63,7 @@ class EmrBaseSensor(BaseSensorOperator):
         return self.hook
 
     def poke(self, context: Context):
-        response = self.get_emr_response()
+        response = self.get_emr_response(context=context)
 
         if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
             self.log.info("Bad HTTP response: %s", response)
@@ -78,7 +80,7 @@ class EmrBaseSensor(BaseSensorOperator):
 
         return False
 
-    def get_emr_response(self) -> dict[str, Any]:
+    def get_emr_response(self, context: Context) -> dict[str, Any]:
         """
         Make an API call with boto3 and get response.
 
@@ -329,7 +331,7 @@ class EmrNotebookExecutionSensor(EmrBaseSensor):
         self.target_states = target_states or self.COMPLETED_STATES
         self.failed_states = failed_states or self.FAILURE_STATES
 
-    def get_emr_response(self) -> dict[str, Any]:
+    def get_emr_response(self, context: Context) -> dict[str, Any]:
         emr_client = self.get_hook().get_conn()
         self.log.info("Poking notebook %s", self.notebook_execution_id)
 
@@ -382,6 +384,7 @@ class EmrJobFlowSensor(EmrBaseSensor):
 
     template_fields: Sequence[str] = ("job_flow_id", "target_states", "failed_states")
     template_ext: Sequence[str] = ()
+    operator_extra_links = (EmrLogsLink(),)
 
     def __init__(
         self,
@@ -396,7 +399,7 @@ class EmrJobFlowSensor(EmrBaseSensor):
         self.target_states = target_states or ["TERMINATED"]
         self.failed_states = failed_states or ["TERMINATED_WITH_ERRORS"]
 
-    def get_emr_response(self) -> dict[str, Any]:
+    def get_emr_response(self, context: Context) -> dict[str, Any]:
         """
         Make an API call with boto3 and get cluster-level details.
 
@@ -406,9 +409,18 @@ class EmrJobFlowSensor(EmrBaseSensor):
         :return: response
         """
         emr_client = self.get_hook().get_conn()
-
         self.log.info("Poking cluster %s", self.job_flow_id)
-        return emr_client.describe_cluster(ClusterId=self.job_flow_id)
+        response = emr_client.describe_cluster(ClusterId=self.job_flow_id)
+        log_uri = S3Hook.parse_s3_url(response["Cluster"]["LogUri"])
+        EmrLogsLink.persist(
+            context=context,
+            operator=self,
+            region_name=self.get_hook().conn_region_name,
+            aws_partition=self.get_hook().conn_partition,
+            job_flow_id=self.job_flow_id,
+            log_uri="/".join(log_uri),
+        )
+        return response
 
     @staticmethod
     def state_from_response(response: dict[str, Any]) -> str:
@@ -476,7 +488,7 @@ class EmrStepSensor(EmrBaseSensor):
         self.target_states = target_states or ["COMPLETED"]
         self.failed_states = failed_states or ["CANCELLED", "FAILED", "INTERRUPTED"]
 
-    def get_emr_response(self) -> dict[str, Any]:
+    def get_emr_response(self, context: Context) -> dict[str, Any]:
         """
         Make an API call with boto3 and get details about the cluster step.
 
diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml
index 3f7b793a1c..230bc135ae 100644
--- a/airflow/providers/amazon/provider.yaml
+++ b/airflow/providers/amazon/provider.yaml
@@ -558,6 +558,7 @@ extra-links:
   - airflow.providers.amazon.aws.links.batch.BatchJobDetailsLink
   - airflow.providers.amazon.aws.links.batch.BatchJobQueueLink
   - airflow.providers.amazon.aws.links.emr.EmrClusterLink
+  - airflow.providers.amazon.aws.links.emr.EmrLogsLink
   - airflow.providers.amazon.aws.links.logs.CloudWatchEventsLink
 
 connection-types:
diff --git a/tests/providers/amazon/aws/hooks/test_s3.py b/tests/providers/amazon/aws/hooks/test_s3.py
index 48cfb9912f..ff231482da 100644
--- a/tests/providers/amazon/aws/hooks/test_s3.py
+++ b/tests/providers/amazon/aws/hooks/test_s3.py
@@ -71,6 +71,14 @@ class TestAwsS3Hook:
         parsed = S3Hook.parse_s3_url("s3://test/this/is/not/a-real-key.txt")
         assert parsed == ("test", "this/is/not/a-real-key.txt"), "Incorrect parsing of the s3 url"
 
+    def test_parse_s3_url_s3a_style(self):
+        parsed = S3Hook.parse_s3_url("s3a://test/this/is/not/a-real-key.txt")
+        assert parsed == ("test", "this/is/not/a-real-key.txt"), "Incorrect parsing of the s3 url"
+
+    def test_parse_s3_url_s3n_style(self):
+        parsed = S3Hook.parse_s3_url("s3n://test/this/is/not/a-real-key.txt")
+        assert parsed == ("test", "this/is/not/a-real-key.txt"), "Incorrect parsing of the s3 url"
+
     def test_parse_s3_url_path_style(self):
         parsed = S3Hook.parse_s3_url("https://s3.us-west-2.amazonaws.com/DOC-EXAMPLE-BUCKET1/test.jpg")
         assert parsed == ("DOC-EXAMPLE-BUCKET1", "test.jpg"), "Incorrect parsing of the s3 url"
diff --git a/tests/providers/amazon/aws/sensors/test_emr_base.py b/tests/providers/amazon/aws/sensors/test_emr_base.py
index b0dfd66233..f6b4351833 100644
--- a/tests/providers/amazon/aws/sensors/test_emr_base.py
+++ b/tests/providers/amazon/aws/sensors/test_emr_base.py
@@ -21,6 +21,7 @@ import pytest
 
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.sensors.emr import EmrBaseSensor
+from airflow.utils.context import Context
 
 TARGET_STATE = "TARGET_STATE"
 FAILED_STATE = "FAILED_STATE"
@@ -40,7 +41,7 @@ class EmrBaseSensorSubclass(EmrBaseSensor):
         self.failed_states = [FAILED_STATE]
         self.response = {}  # will be set in tests
 
-    def get_emr_response(self):
+    def get_emr_response(self, context: Context):
         return self.response
 
     @staticmethod
diff --git a/tests/providers/amazon/aws/sensors/test_emr_job_flow.py b/tests/providers/amazon/aws/sensors/test_emr_job_flow.py
index 87a80d6a01..a10e68abb8 100644
--- a/tests/providers/amazon/aws/sensors/test_emr_job_flow.py
+++ b/tests/providers/amazon/aws/sensors/test_emr_job_flow.py
@@ -199,6 +199,9 @@ class TestEmrJobFlowSensor:
         # Mock out the emr_client creator
         self.boto3_session_mock = MagicMock(return_value=mock_emr_session)
 
+        # Mock context used in execute function
+        self.mock_ctx = MagicMock()
+
     def test_execute_calls_with_the_job_flow_id_until_it_reaches_a_target_state(self):
         self.mock_emr_client.describe_cluster.side_effect = [
             DESCRIBE_CLUSTER_STARTING_RETURN,
@@ -210,7 +213,7 @@ class TestEmrJobFlowSensor:
                 task_id="test_task", poke_interval=0, job_flow_id="j-8989898989", aws_conn_id="aws_default"
             )
 
-            operator.execute(None)
+            operator.execute(self.mock_ctx)
 
             # make sure we called twice
             assert self.mock_emr_client.describe_cluster.call_count == 3
@@ -230,7 +233,7 @@ class TestEmrJobFlowSensor:
             )
 
             with pytest.raises(AirflowException):
-                operator.execute(None)
+                operator.execute(self.mock_ctx)
 
                 # make sure we called twice
                 assert self.mock_emr_client.describe_cluster.call_count == 2
@@ -256,7 +259,7 @@ class TestEmrJobFlowSensor:
                 target_states=["RUNNING", "WAITING"],
             )
 
-            operator.execute(None)
+            operator.execute(self.mock_ctx)
 
             # make sure we called twice
             assert self.mock_emr_client.describe_cluster.call_count == 3