Posted to commits@airflow.apache.org by ep...@apache.org on 2023/12/05 10:34:05 UTC

(airflow) 20/34: [AIP-44] Introduce Pydantic model for LogTemplate (#36004)

This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch v2-8-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 08188ed880b05d6e74f38c18577dc8d4e68f526e
Author: mhenc <mh...@google.com>
AuthorDate: Fri Dec 1 16:20:32 2023 +0100

    [AIP-44] Introduce Pydantic model for LogTemplate (#36004)
    
    (cherry picked from commit c26aa12bcc91429f0c2dd53e066a8480c00f822c)
---
 airflow/api_internal/endpoints/rpc_api_endpoint.py |  1 +
 airflow/models/dagrun.py                           | 18 ++++++++++---
 airflow/serialization/enums.py                     |  1 +
 airflow/serialization/pydantic/dag_run.py          |  1 +
 airflow/serialization/pydantic/tasklog.py          | 30 ++++++++++++++++++++++
 airflow/serialization/serialized_objects.py        |  7 ++++-
 tests/serialization/test_serialized_objects.py     |  8 ++++++
 7 files changed, 61 insertions(+), 5 deletions(-)

diff --git a/airflow/api_internal/endpoints/rpc_api_endpoint.py b/airflow/api_internal/endpoints/rpc_api_endpoint.py
index f451659cc0..7f1629affa 100644
--- a/airflow/api_internal/endpoints/rpc_api_endpoint.py
+++ b/airflow/api_internal/endpoints/rpc_api_endpoint.py
@@ -76,6 +76,7 @@ def _initialize_map() -> dict[str, Callable]:
         DagRun.get_previous_dagrun,
         DagRun.get_previous_scheduled_dagrun,
         DagRun.fetch_task_instance,
+        DagRun._get_log_template,
         SerializedDagModel.get_serialized_dag,
         TaskInstance._check_and_change_state_before_execution,
         TaskInstance.get_task_instance,
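
_initialize_map builds the allow-list of callables that the internal API endpoint will execute on behalf of remote components; registering DagRun._get_log_template here is what makes the RPC path in the next hunk possible. A hedged sketch of the keying scheme, inferred from the surrounding code (methods are resolved by module path plus qualified name):

    from typing import Callable

    def _initialize_map_sketch(functions: list[Callable]) -> dict[str, Callable]:
        # Illustrative only: the endpoint looks methods up by "module.qualname",
        # e.g. "airflow.models.dagrun.DagRun._get_log_template".
        return {f"{f.__module__}.{f.__qualname__}": f for f in functions}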
diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index b7d9b05e82..a611155dda 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -78,6 +78,7 @@ if TYPE_CHECKING:
     from airflow.models.operator import Operator
     from airflow.serialization.pydantic.dag_run import DagRunPydantic
     from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
+    from airflow.serialization.pydantic.tasklog import LogTemplatePydantic
     from airflow.typing_compat import Literal
     from airflow.utils.types import ArgNotSet
 
@@ -1460,14 +1461,23 @@ class DagRun(Base, LoggingMixin):
         return count
 
     @provide_session
-    def get_log_template(self, *, session: Session = NEW_SESSION) -> LogTemplate:
-        if self.log_template_id is None:  # DagRun created before LogTemplate introduction.
+    def get_log_template(self, *, session: Session = NEW_SESSION) -> LogTemplate | LogTemplatePydantic:
+        return DagRun._get_log_template(log_template_id=self.log_template_id, session=session)
+
+    @staticmethod
+    @internal_api_call
+    @provide_session
+    def _get_log_template(
+        log_template_id: int | None, session: Session = NEW_SESSION
+    ) -> LogTemplate | LogTemplatePydantic:
+        template: LogTemplate | None
+        if log_template_id is None:  # DagRun created before LogTemplate introduction.
             template = session.scalar(select(LogTemplate).order_by(LogTemplate.id).limit(1))
         else:
-            template = session.get(LogTemplate, self.log_template_id)
+            template = session.get(LogTemplate, log_template_id)
         if template is None:
             raise AirflowException(
-                f"No log_template entry found for ID {self.log_template_id!r}. "
+                f"No log_template entry found for ID {log_template_id!r}. "
                 f"Please make sure you set up the metadatabase correctly."
             )
         return template
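
The method is split in two because the internal API can only ship serializable arguments across the wire: get_log_template keeps its original signature but forwards the plain integer log_template_id to a @staticmethod decorated with @internal_api_call, which either runs the query directly (for DB-connected components) or proxies it through the RPC endpoint registered above. A simplified sketch of that dispatch, as an assumption about the decorator's behaviour rather than its actual implementation:

    from functools import wraps

    def internal_api_call_sketch(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            if not use_internal_api():  # hypothetical config check
                # DB-connected components call straight through.
                return func(*args, **kwargs)
            # Otherwise serialize the arguments, POST them to the internal API
            # endpoint, and deserialize the response -- which is why the return
            # type widens to LogTemplate | LogTemplatePydantic.
            return post_to_internal_api(func, args, kwargs)  # hypothetical helper
        return wrapper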
diff --git a/airflow/serialization/enums.py b/airflow/serialization/enums.py
index 744aeb4088..4f95c849c8 100644
--- a/airflow/serialization/enums.py
+++ b/airflow/serialization/enums.py
@@ -56,5 +56,6 @@ class DagAttributeTypes(str, Enum):
     DAG_RUN = "dag_run"
     DAG_MODEL = "dag_model"
     DATA_SET = "data_set"
+    LOG_TEMPLATE = "log_template"
     CONNECTION = "connection"
     ARG_NOT_SET = "arg_not_set"
diff --git a/airflow/serialization/pydantic/dag_run.py b/airflow/serialization/pydantic/dag_run.py
index cd0886ecaf..8faabc5ee4 100644
--- a/airflow/serialization/pydantic/dag_run.py
+++ b/airflow/serialization/pydantic/dag_run.py
@@ -55,6 +55,7 @@ class DagRunPydantic(BaseModelPydantic):
     updated_at: Optional[datetime]
     dag: Optional[PydanticDag]
     consumed_dataset_events: List[DatasetEventPydantic]  # noqa
+    log_template_id: Optional[int]
 
     model_config = ConfigDict(from_attributes=True, arbitrary_types_allowed=True)
 
diff --git a/airflow/serialization/pydantic/tasklog.py b/airflow/serialization/pydantic/tasklog.py
new file mode 100644
index 0000000000..a23204400c
--- /dev/null
+++ b/airflow/serialization/pydantic/tasklog.py
@@ -0,0 +1,30 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from datetime import datetime
+
+from pydantic import BaseModel as BaseModelPydantic, ConfigDict
+
+
+class LogTemplatePydantic(BaseModelPydantic):
+    """Serializable version of the LogTemplate ORM SqlAlchemyModel used by internal API."""
+
+    id: int
+    filename: str
+    elasticsearch_id: str
+    created_at: datetime
+
+    model_config = ConfigDict(from_attributes=True)
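
The model_config line with from_attributes=True is what lets the Pydantic model be populated directly from a SQLAlchemy row. A minimal sketch, assuming pydantic v2 and the model above:

    from datetime import datetime, timezone

    from airflow.models.tasklog import LogTemplate
    from airflow.serialization.pydantic.tasklog import LogTemplatePydantic

    row = LogTemplate(
        id=1,
        filename="dag_id={{ ti.dag_id }}/task_id={{ ti.task_id }}.log",
        elasticsearch_id="{dag_id}-{task_id}-{try_number}",
        created_at=datetime.now(timezone.utc),
    )
    # from_attributes=True reads the fields off the ORM object directly.
    pydantic_row = LogTemplatePydantic.model_validate(row)
    assert pydantic_row.filename == row.filename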
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 6f0e88cae2..48aa595933 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -47,6 +47,7 @@ from airflow.models.expandinput import EXPAND_INPUT_EMPTY, create_expand_input,
 from airflow.models.mappedoperator import MappedOperator
 from airflow.models.param import Param, ParamsDict
 from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance
+from airflow.models.tasklog import LogTemplate
 from airflow.models.xcom_arg import XComArg, deserialize_xcom_arg, serialize_xcom_arg
 from airflow.providers_manager import ProvidersManager
 from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding
@@ -57,6 +58,7 @@ from airflow.serialization.pydantic.dag_run import DagRunPydantic
 from airflow.serialization.pydantic.dataset import DatasetPydantic
 from airflow.serialization.pydantic.job import JobPydantic
 from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
+from airflow.serialization.pydantic.tasklog import LogTemplatePydantic
 from airflow.settings import _ENABLE_AIP_44, DAGS_FOLDER, json
 from airflow.utils.code_utils import get_python_source
 from airflow.utils.docs import get_docs_url
@@ -514,7 +516,8 @@ class BaseSerialization:
                 return cls._encode(_pydantic_model_dump(DatasetPydantic, var), type_=DAT.DATA_SET)
             elif isinstance(var, DagModel):
                 return cls._encode(_pydantic_model_dump(DagModelPydantic, var), type_=DAT.DAG_MODEL)
-
+            elif isinstance(var, LogTemplate):
+                return cls._encode(_pydantic_model_dump(LogTemplatePydantic, var), type_=DAT.LOG_TEMPLATE)
             else:
                 return cls.default_serialization(strict, var)
         elif isinstance(var, ArgNotSet):
@@ -596,6 +599,8 @@ class BaseSerialization:
                 return DagModelPydantic.parse_obj(var)
             elif type_ == DAT.DATA_SET:
                 return DatasetPydantic.parse_obj(var)
+            elif type_ == DAT.LOG_TEMPLATE:
+                return LogTemplatePydantic.parse_obj(var)
         elif type_ == DAT.ARG_NOT_SET:
             return NOTSET
         else:
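
With the encode and decode branches in place, a LogTemplate survives a full serialization round trip as a LogTemplatePydantic. A hedged sketch of that round trip, mirroring the test added below and assuming AIP-44 is enabled (_ENABLE_AIP_44):

    from datetime import datetime

    from airflow.models.tasklog import LogTemplate
    from airflow.serialization.pydantic.tasklog import LogTemplatePydantic
    from airflow.serialization.serialized_objects import BaseSerialization

    template = LogTemplate(id=1, filename="f.log", elasticsearch_id="es", created_at=datetime.now())
    serialized = BaseSerialization.serialize(template, use_pydantic_models=True)
    # Encoded as {"__type": "log_template", "__var": {...model fields...}}.
    restored = BaseSerialization.deserialize(serialized, use_pydantic_models=True)
    assert isinstance(restored, LogTemplatePydantic)
    assert restored.filename == template.filename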
diff --git a/tests/serialization/test_serialized_objects.py b/tests/serialization/test_serialized_objects.py
index 0af29e8ebc..a40e0d01ea 100644
--- a/tests/serialization/test_serialized_objects.py
+++ b/tests/serialization/test_serialized_objects.py
@@ -33,6 +33,7 @@ from airflow.models.dag import DAG, DagModel
 from airflow.models.dagrun import DagRun
 from airflow.models.param import Param
 from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance
+from airflow.models.tasklog import LogTemplate
 from airflow.models.xcom_arg import XComArg
 from airflow.operators.empty import EmptyOperator
 from airflow.operators.python import PythonOperator
@@ -41,6 +42,7 @@ from airflow.serialization.pydantic.dag import DagModelPydantic
 from airflow.serialization.pydantic.dag_run import DagRunPydantic
 from airflow.serialization.pydantic.job import JobPydantic
 from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
+from airflow.serialization.pydantic.tasklog import LogTemplatePydantic
 from airflow.settings import _ENABLE_AIP_44
 from airflow.utils.operator_resources import Resources
 from airflow.utils.state import DagRunState, State
@@ -278,6 +280,12 @@ def test_backcompat_deserialize_connection(conn_uri):
             DAT.DAG_MODEL,
             lambda a, b: a.fileloc == b.fileloc and a.schedule_interval == b.schedule_interval,
         ),
+        (
+            LogTemplate(id=1, filename="test_file", elasticsearch_id="test_id", created_at=datetime.now()),
+            LogTemplatePydantic,
+            DAT.LOG_TEMPLATE,
+            lambda a, b: a.id == b.id and a.filename == b.filename and equal_time(a.created_at, b.created_at),
+        ),
     ],
 )
 def test_serialize_deserialize_pydantic(input, pydantic_class, encoded_type, cmp_func):