Posted to commits@airflow.apache.org by as...@apache.org on 2021/09/07 08:16:32 UTC

[airflow] branch main updated: Change TaskInstance and TaskReschedule PK from execution_date to run_id (#17719)

This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 944dcfb  Change TaskInstance and TaskReschedule PK from execution_date to run_id (#17719)
944dcfb is described below

commit 944dcfbb918050274fd3a1cc51d8fdf460ea2429
Author: Ash Berlin-Taylor <as...@firemirror.com>
AuthorDate: Tue Sep 7 09:16:06 2021 +0100

    Change TaskInstance and TaskReschedule PK from execution_date to run_id (#17719)
    
    Since TaskReschedule had an existing FK to TaskInstance we had to
    change both of these at the same time.
    
    This puts an explicit FK constraint between TaskInstance and DagRun,
    meaning we can remove a lot of the "find TIs without DagRun" cleanup
    code in the scheduler, as that is no longer a possible situation.
    
    This change was made as part of AIP-39
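
    As a rough sketch of the lookup pattern this enables (illustrative
    only: the dag, task and run identifiers below are invented and an
    open SQLAlchemy session is assumed), a task instance is now
    addressed by run_id and joined to its DagRun whenever the
    execution_date is needed:

        from sqlalchemy.orm import eagerload

        from airflow.models import TaskInstance

        # "session" and all identifiers here are placeholders.
        ti = (
            session.query(TaskInstance)
            .filter(
                TaskInstance.dag_id == "example_dag",
                TaskInstance.task_id == "example_task",
                TaskInstance.run_id == "scheduled__2021-09-06T00:00:00+00:00",
            )
            .join(TaskInstance.dag_run)
            .options(eagerload(TaskInstance.dag_run))
            .one_or_none()
        )
        execution_date = ti.dag_run.execution_date if ti else None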
    
    Co-authored-by: Tzu-ping Chung <tp...@astronomer.io>
---
 UPDATING.md                                        |  10 +
 airflow/api/common/experimental/mark_tasks.py      |   5 +-
 airflow/api_connexion/endpoints/log_endpoint.py    |  18 +-
 .../endpoints/task_instance_endpoint.py            |  39 +-
 .../api_connexion/schemas/task_instance_schema.py  |  10 +-
 airflow/cli/commands/dag_command.py                |   4 +-
 airflow/cli/commands/kubernetes_command.py         |   6 +-
 airflow/cli/commands/task_command.py               | 106 ++--
 airflow/config_templates/config.yml                |   8 -
 airflow/config_templates/default_airflow.cfg       |   4 -
 airflow/dag_processing/processor.py                |  16 +-
 airflow/executors/kubernetes_executor.py           |  41 +-
 airflow/jobs/backfill_job.py                       |  46 +-
 airflow/jobs/local_task_job.py                     |   2 +-
 airflow/jobs/scheduler_job.py                      | 127 +---
 airflow/kubernetes/kubernetes_helper_functions.py  |  26 +-
 airflow/kubernetes/pod_generator.py                |  40 +-
 .../7b2661a43ba3_taskinstance_keyed_to_dagrun.py   | 287 +++++++++
 airflow/models/baseoperator.py                     |  35 +-
 airflow/models/dag.py                              |  41 +-
 airflow/models/dagrun.py                           |  56 +-
 airflow/models/skipmixin.py                        |  85 ++-
 airflow/models/taskinstance.py                     | 292 +++++----
 airflow/models/taskreschedule.py                   |  26 +-
 .../providers/google/cloud/operators/bigquery.py   |   9 +-
 airflow/sensors/base.py                            |   2 +-
 airflow/sensors/smart_sensor.py                    |  12 +-
 airflow/sentry.py                                  |   8 +-
 airflow/ti_deps/dep_context.py                     |  22 +-
 airflow/ti_deps/deps/dagrun_exists_dep.py          |  28 +-
 airflow/ti_deps/deps/dagrun_id_dep.py              |  11 +-
 airflow/ti_deps/deps/not_previously_skipped_dep.py |   2 +-
 airflow/ti_deps/deps/runnable_exec_date_dep.py     |  13 +-
 airflow/ti_deps/deps/trigger_rule_dep.py           |   2 +-
 airflow/utils/callback_requests.py                 |   7 +-
 airflow/www/auth.py                                |   2 +
 airflow/www/decorators.py                          |   1 +
 airflow/www/utils.py                               |  21 +-
 airflow/www/views.py                               | 191 +++---
 docs/apache-airflow/concepts/scheduler.rst         |  11 -
 docs/apache-airflow/logging-monitoring/metrics.rst |   1 -
 docs/apache-airflow/migrations-ref.rst             |   4 +-
 kubernetes_tests/test_kubernetes_pod_operator.py   |   6 +-
 .../test_kubernetes_pod_operator_backcompat.py     |   6 +-
 tests/api/common/experimental/test_delete_dag.py   |  36 +-
 tests/api/common/experimental/test_mark_tasks.py   |  41 +-
 tests/api_connexion/conftest.py                    |   8 +
 tests/api_connexion/endpoints/test_dag_endpoint.py |   2 +
 .../endpoints/test_event_log_endpoint.py           | 194 +++---
 tests/api_connexion/endpoints/test_log_endpoint.py |  54 +-
 .../endpoints/test_task_instance_endpoint.py       | 105 +---
 tests/api_connexion/schemas/test_dag_run_schema.py |   9 +-
 .../api_connexion/schemas/test_event_log_schema.py |  74 +--
 .../schemas/test_task_instance_schema.py           |  60 +-
 tests/cli/commands/test_dag_command.py             |   4 +-
 tests/cli/commands/test_task_command.py            | 195 +-----
 tests/conftest.py                                  |  93 ++-
 tests/core/test_core.py                            |  79 +--
 tests/core/test_sentry.py                          |  55 +-
 tests/dag_processing/test_manager.py               |  34 +-
 tests/dag_processing/test_processor.py             |  89 +--
 tests/executors/test_base_executor.py              |  87 ++-
 tests/executors/test_celery_executor.py            |  21 +-
 tests/executors/test_kubernetes_executor.py        |  34 +-
 tests/jobs/test_backfill_job.py                    |  73 +--
 tests/jobs/test_local_task_job.py                  |  58 +-
 tests/jobs/test_scheduler_job.py                   | 487 ++++-----------
 tests/jobs/test_triggerer_job.py                   |  22 +-
 tests/lineage/test_lineage.py                      |  40 +-
 tests/models/test_baseoperator.py                  |  10 +
 tests/models/test_cleartasks.py                    | 190 +++---
 tests/models/test_dag.py                           |  96 +--
 tests/models/test_dagrun.py                        |  48 +-
 tests/models/test_renderedtifields.py              |  64 +-
 tests/models/test_skipmixin.py                     |   3 +-
 tests/models/test_taskinstance.py                  | 685 ++++++++++-----------
 tests/models/test_trigger.py                       |  27 +-
 tests/operators/test_latest_only_operator.py       |   3 +-
 tests/operators/test_python.py                     | 115 +---
 tests/operators/test_subdag_operator.py            |  88 +--
 .../amazon/aws/log/test_cloudwatch_task_handler.py |  10 +-
 .../amazon/aws/log/test_s3_task_handler.py         |  10 +-
 .../providers/amazon/aws/operators/test_athena.py  |  12 +-
 .../amazon/aws/operators/test_datasync.py          |  60 +-
 .../aws/operators/test_dms_describe_tasks.py       |   9 +-
 .../amazon/aws/operators/test_emr_add_steps.py     |  61 +-
 .../aws/operators/test_emr_create_job_flow.py      |  12 +-
 tests/providers/amazon/aws/sensors/test_s3_key.py  |   7 +-
 .../amazon/aws/transfers/test_mongo_to_s3.py       |   7 +-
 .../amazon/aws/transfers/test_s3_to_sftp.py        |  13 +-
 .../amazon/aws/transfers/test_sftp_to_s3.py        |   6 +-
 .../providers/apache/druid/operators/test_druid.py |  98 ++-
 tests/providers/apache/hive/operators/test_hive.py |   6 +-
 .../apache/kylin/operators/test_kylin_cube.py      |   5 +-
 .../apache/spark/operators/test_spark_submit.py    |   5 +-
 .../kubernetes/operators/test_kubernetes_pod.py    |   8 +-
 .../elasticsearch/log/test_es_task_handler.py      | 156 ++---
 .../google/cloud/log/test_gcs_task_handler.py      |  75 ++-
 .../cloud/log/test_stackdriver_task_handler.py     | 165 ++---
 .../google/cloud/operators/test_bigquery.py        | 150 +++--
 .../google/cloud/operators/test_cloud_build.py     |  48 +-
 .../test_cloud_storage_transfer_service.py         | 208 +++----
 .../google/cloud/operators/test_compute.py         |  78 ++-
 .../google/cloud/operators/test_dataproc.py        | 424 ++++++-------
 .../google/cloud/operators/test_mlengine.py        |  44 +-
 tests/providers/http/sensors/test_http.py          |  56 +-
 .../microsoft/azure/log/test_wasb_task_handler.py  |  51 +-
 .../microsoft/azure/operators/test_adx.py          |  24 +-
 tests/providers/qubole/operators/test_qubole.py    |  87 ++-
 tests/providers/sftp/operators/test_sftp.py        | 417 ++++++-------
 tests/providers/ssh/operators/test_ssh.py          | 124 ++--
 tests/sensors/test_base.py                         | 179 +++---
 tests/sensors/test_external_task_sensor.py         | 189 +++---
 tests/sensors/test_smart_sensor_operator.py        |  11 +-
 tests/serialization/test_dag_serialization.py      |  28 +-
 tests/test_utils/mock_executor.py                  |   5 +-
 tests/test_utils/mock_operators.py                 |  12 +-
 tests/ti_deps/deps/test_dagrun_exists_dep.py       |   3 +-
 tests/ti_deps/deps/test_dagrun_id_dep.py           |   7 -
 .../deps/test_not_previously_skipped_dep.py        | 177 +++---
 tests/ti_deps/deps/test_ready_to_reschedule_dep.py |   4 +-
 tests/ti_deps/deps/test_runnable_exec_date_dep.py  |  70 ++-
 tests/ti_deps/deps/test_trigger_rule_dep.py        | 160 +++--
 tests/utils/log/test_log_reader.py                 | 108 ++--
 tests/utils/test_dot_renderer.py                   |  87 ++-
 tests/utils/test_helpers.py                        |  40 +-
 tests/utils/test_log_handlers.py                   |  61 +-
 .../test_task_handler_with_custom_formatter.py     |  14 +-
 tests/www/views/test_views_dagrun.py               |  16 +-
 tests/www/views/test_views_extra_links.py          |  34 +-
 tests/www/views/test_views_log.py                  |  75 ++-
 tests/www/views/test_views_rendered.py             | 109 ++--
 132 files changed, 4284 insertions(+), 4513 deletions(-)

diff --git a/UPDATING.md b/UPDATING.md
index e5a08cd..7942ba4 100644
--- a/UPDATING.md
+++ b/UPDATING.md
@@ -228,6 +228,16 @@ Now that the DAG parser syncs DAG permissions there is no longer a need for manu
 
 In addition, the `/refresh` and `/refresh_all` webserver endpoints have also been removed.
 
+### TaskInstances now *require* a DagRun
+
+Under normal operation every TaskInstance row in the database would have a DagRun row too, but it was possible to manually delete the DagRun, and Airflow would still schedule the TaskInstances.
+
+In Airflow 2.2 we have changed this: there is now a database-level foreign key constraint ensuring that every TaskInstance has a DagRun row.
+
+Before updating to this 2.2 release, if you have any "dangling" TaskInstance rows you will have to manually resolve the inconsistencies (add back the missing DagRun rows, or delete the TaskInstances).
+
+As part of this change, the `clean_tis_without_dagrun_interval` config option under the `[scheduler]` section has been removed and no longer has any effect.
+
 ## Airflow 2.1.3
 
 No breaking changes.
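
A minimal sketch of how such dangling rows could be located before upgrading (not part of this commit; it assumes a pre-2.2 schema where TaskInstance still has an execution_date column, and uses Airflow's session helper):

    from airflow.models import DagRun, TaskInstance
    from airflow.utils.session import create_session

    with create_session() as session:
        dangling = (
            session.query(TaskInstance)
            .outerjoin(
                DagRun,
                (TaskInstance.dag_id == DagRun.dag_id)
                & (TaskInstance.execution_date == DagRun.execution_date),
            )
            .filter(DagRun.id.is_(None))  # TaskInstances with no matching DagRun
            .all()
        )
        for ti in dangling:
            print(ti.dag_id, ti.task_id, ti.execution_date)
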
diff --git a/airflow/api/common/experimental/mark_tasks.py b/airflow/api/common/experimental/mark_tasks.py
index 08b7363..945a9cc 100644
--- a/airflow/api/common/experimental/mark_tasks.py
+++ b/airflow/api/common/experimental/mark_tasks.py
@@ -21,6 +21,7 @@ import datetime
 from typing import Iterable
 
 from sqlalchemy import or_
+from sqlalchemy.orm import contains_eager
 
 from airflow.models.baseoperator import BaseOperator
 from airflow.models.dagrun import DagRun
@@ -148,12 +149,14 @@ def get_all_dag_task_query(dag, session, state, task_ids, confirmed_dates):
     """Get all tasks of the main dag that will be affected by a state change"""
     qry_dag = (
         session.query(TaskInstance)
+        .join(TaskInstance.dag_run)
         .filter(
             TaskInstance.dag_id == dag.dag_id,
-            TaskInstance.execution_date.in_(confirmed_dates),
+            DagRun.execution_date.in_(confirmed_dates),
             TaskInstance.task_id.in_(task_ids),
         )
         .filter(or_(TaskInstance.state.is_(None), TaskInstance.state != state))
+        .options(contains_eager(TaskInstance.dag_run))
     )
     return qry_dag
 
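
In the query above, contains_eager() tells SQLAlchemy to populate TaskInstance.dag_run from the DagRun columns already produced by the explicit join, instead of emitting a separate SELECT per row. A toy sketch of the same pattern (session and confirmed_dates are assumed to be in scope):

    from sqlalchemy.orm import contains_eager

    from airflow.models import DagRun, TaskInstance

    tis = (
        session.query(TaskInstance)
        .join(TaskInstance.dag_run)
        .filter(DagRun.execution_date.in_(confirmed_dates))
        .options(contains_eager(TaskInstance.dag_run))
        .all()
    )
    # Accessing ti.dag_run afterwards issues no additional query.
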
diff --git a/airflow/api_connexion/endpoints/log_endpoint.py b/airflow/api_connexion/endpoints/log_endpoint.py
index 5fcff8e..ddd6cfc 100644
--- a/airflow/api_connexion/endpoints/log_endpoint.py
+++ b/airflow/api_connexion/endpoints/log_endpoint.py
@@ -18,12 +18,13 @@
 from flask import Response, current_app, request
 from itsdangerous.exc import BadSignature
 from itsdangerous.url_safe import URLSafeSerializer
+from sqlalchemy.orm import eagerload
 
 from airflow.api_connexion import security
 from airflow.api_connexion.exceptions import BadRequest, NotFound
 from airflow.api_connexion.schemas.log_schema import LogResponseObject, logs_schema
 from airflow.exceptions import TaskNotFound
-from airflow.models import DagRun
+from airflow.models import TaskInstance
 from airflow.security import permissions
 from airflow.utils.log.log_reader import TaskLogReader
 from airflow.utils.session import provide_session
@@ -60,15 +61,16 @@ def get_log(session, dag_id, dag_run_id, task_id, task_try_number, full_content=
     if not task_log_reader.supports_read:
         raise BadRequest("Task log handler does not support read logs.")
 
-    query = session.query(DagRun).filter(DagRun.dag_id == dag_id)
-    dag_run = query.filter(DagRun.run_id == dag_run_id).first()
-    if not dag_run:
-        raise NotFound("DAG Run not found")
-
-    ti = dag_run.get_task_instance(task_id, session)
+    ti = (
+        session.query(TaskInstance)
+        .filter(TaskInstance.task_id == task_id, TaskInstance.run_id == dag_run_id)
+        .join(TaskInstance.dag_run)
+        .options(eagerload(TaskInstance.dag_run))
+        .one_or_none()
+    )
     if ti is None:
         metadata['end_of_log'] = True
-        raise BadRequest(detail="Task instance did not exist in the DB")
+        raise NotFound(title="TaskInstance not found")
 
     dag = current_app.dag_bag.get_dag(dag_id)
     if dag:
diff --git a/airflow/api_connexion/endpoints/task_instance_endpoint.py b/airflow/api_connexion/endpoints/task_instance_endpoint.py
index e2a6ce9..361d29e 100644
--- a/airflow/api_connexion/endpoints/task_instance_endpoint.py
+++ b/airflow/api_connexion/endpoints/task_instance_endpoint.py
@@ -19,6 +19,7 @@ from typing import Any, List, Optional, Tuple
 from flask import current_app, request
 from marshmallow import ValidationError
 from sqlalchemy import and_, func
+from sqlalchemy.orm import eagerload
 
 from airflow.api_connexion import security
 from airflow.api_connexion.exceptions import BadRequest, NotFound
@@ -54,15 +55,14 @@ def get_task_instance(dag_id: str, dag_run_id: str, task_id: str, session=None):
     """Get task instance"""
     query = (
         session.query(TI)
-        .filter(TI.dag_id == dag_id)
-        .join(DR, and_(TI.dag_id == DR.dag_id, TI.execution_date == DR.execution_date))
-        .filter(DR.run_id == dag_run_id)
-        .filter(TI.task_id == task_id)
+        .filter(TI.dag_id == dag_id, DR.run_id == dag_run_id, TI.task_id == task_id)
+        .join(TI.dag_run)
+        .options(eagerload(TI.dag_run))
         .outerjoin(
             SlaMiss,
             and_(
                 SlaMiss.dag_id == TI.dag_id,
-                SlaMiss.execution_date == TI.execution_date,
+                SlaMiss.execution_date == DR.execution_date,
                 SlaMiss.task_id == TI.task_id,
             ),
         )
@@ -127,13 +127,12 @@ def get_task_instances(
     session=None,
 ):
     """Get list of task instances."""
-    base_query = session.query(TI)
+    base_query = session.query(TI).join(TI.dag_run).options(eagerload(TI.dag_run))
 
     if dag_id != "~":
         base_query = base_query.filter(TI.dag_id == dag_id)
     if dag_run_id != "~":
-        base_query = base_query.join(DR, and_(TI.dag_id == DR.dag_id, TI.execution_date == DR.execution_date))
-        base_query = base_query.filter(DR.run_id == dag_run_id)
+        base_query = base_query.filter(TI.run_id == dag_run_id)
     base_query = _apply_range_filter(
         base_query,
         key=DR.execution_date,
@@ -156,7 +155,7 @@ def get_task_instances(
         and_(
             SlaMiss.dag_id == TI.dag_id,
             SlaMiss.task_id == TI.task_id,
-            SlaMiss.execution_date == TI.execution_date,
+            SlaMiss.execution_date == DR.execution_date,
         ),
         isouter=True,
     )
@@ -183,12 +182,12 @@ def get_task_instances_batch(session=None):
         data = task_instance_batch_form.load(body)
     except ValidationError as err:
         raise BadRequest(detail=str(err.messages))
-    base_query = session.query(TI)
+    base_query = session.query(TI).join(TI.dag_run).options(eagerload(TI.dag_run))
 
     base_query = _apply_array_filter(base_query, key=TI.dag_id, values=data["dag_ids"])
     base_query = _apply_range_filter(
         base_query,
-        key=TI.execution_date,
+        key=DR.execution_date,
         value_range=(data["execution_date_gte"], data["execution_date_lte"]),
     )
     base_query = _apply_range_filter(
@@ -214,7 +213,7 @@ def get_task_instances_batch(session=None):
         and_(
             SlaMiss.dag_id == TI.dag_id,
             SlaMiss.task_id == TI.task_id,
-            SlaMiss.execution_date == TI.execution_date,
+            SlaMiss.execution_date == DR.execution_date,
         ),
         isouter=True,
     )
@@ -254,9 +253,7 @@ def post_clear_task_instances(dag_id: str, session=None):
         clear_task_instances(
             task_instances.all(), session, dag=dag, dag_run_state=State.RUNNING if reset_dag_runs else False
         )
-    task_instances = task_instances.join(
-        DR, and_(DR.dag_id == TI.dag_id, DR.execution_date == TI.execution_date)
-    ).add_column(DR.run_id)
+    task_instances = task_instances.join(TI.dag_run).options(eagerload(TI.dag_run))
     return task_instance_reference_collection_schema.dump(
         TaskInstanceReferenceCollection(task_instances=task_instances.all())
     )
@@ -303,14 +300,6 @@ def post_set_task_instances_state(dag_id, session):
         future=data["include_future"],
         past=data["include_past"],
         commit=not data["dry_run"],
+        session=session,
     )
-    execution_dates = {ti.execution_date for ti in tis}
-    execution_date_to_run_id_map = dict(
-        session.query(DR.execution_date, DR.run_id).filter(
-            DR.dag_id == dag_id, DR.execution_date.in_(execution_dates)
-        )
-    )
-    tis_with_run_id = [(ti, execution_date_to_run_id_map.get(ti.execution_date)) for ti in tis]
-    return task_instance_reference_collection_schema.dump(
-        TaskInstanceReferenceCollection(task_instances=tis_with_run_id)
-    )
+    return task_instance_reference_collection_schema.dump(TaskInstanceReferenceCollection(task_instances=tis))
diff --git a/airflow/api_connexion/schemas/task_instance_schema.py b/airflow/api_connexion/schemas/task_instance_schema.py
index 95fc475..89ae9a6 100644
--- a/airflow/api_connexion/schemas/task_instance_schema.py
+++ b/airflow/api_connexion/schemas/task_instance_schema.py
@@ -134,18 +134,10 @@ class TaskInstanceReferenceSchema(Schema):
     """Schema for the task instance reference schema"""
 
     task_id = fields.Str()
-    dag_run_id = fields.Str()
+    run_id = fields.Str(data_key="dag_run_id")
     dag_id = fields.Str()
     execution_date = fields.DateTime()
 
-    def get_attribute(self, obj, attr, default):
-        """Overwritten marshmallow function"""
-        task_instance_attr = ['task_id', 'execution_date', 'dag_id']
-        if attr in task_instance_attr:
-            obj = obj[0]  # As object is a tuple of task_instance and dag_run_id
-            return get_value(obj, attr, default)
-        return obj[1]
-
 
 class TaskInstanceReferenceCollection(NamedTuple):
     """List of objects with metadata about taskinstance and dag_run_id"""
diff --git a/airflow/cli/commands/dag_command.py b/airflow/cli/commands/dag_command.py
index dfd8c3c..fffcf68 100644
--- a/airflow/cli/commands/dag_command.py
+++ b/airflow/cli/commands/dag_command.py
@@ -89,9 +89,11 @@ def dag_backfill(args, dag=None):
 
     if args.dry_run:
         print(f"Dry run of DAG {args.dag_id} on {args.start_date}")
+        dr = DagRun(dag.dag_id, execution_date=args.start_date)
         for task in dag.tasks:
             print(f"Task {task.task_id}")
-            ti = TaskInstance(task, args.start_date)
+            ti = TaskInstance(task, run_id=None)
+            ti.dag_run = dr
             ti.dry_run()
     else:
         if args.reset_dagruns:
diff --git a/airflow/cli/commands/kubernetes_command.py b/airflow/cli/commands/kubernetes_command.py
index 2660dae..d7481e4 100644
--- a/airflow/cli/commands/kubernetes_command.py
+++ b/airflow/cli/commands/kubernetes_command.py
@@ -26,7 +26,7 @@ from airflow.executors.kubernetes_executor import KubeConfig, create_pod_id
 from airflow.kubernetes import pod_generator
 from airflow.kubernetes.kube_client import get_kube_client
 from airflow.kubernetes.pod_generator import PodGenerator
-from airflow.models import TaskInstance
+from airflow.models import DagRun, TaskInstance
 from airflow.settings import pod_mutation_hook
 from airflow.utils import cli as cli_utils, yaml
 from airflow.utils.cli import get_dag
@@ -38,9 +38,11 @@ def generate_pod_yaml(args):
     execution_date = args.execution_date
     dag = get_dag(subdir=args.subdir, dag_id=args.dag_id)
     yaml_output_path = args.output_path
+    dr = DagRun(dag.dag_id, execution_date=execution_date)
     kube_config = KubeConfig()
     for task in dag.tasks:
-        ti = TaskInstance(task, execution_date)
+        ti = TaskInstance(task, None)
+        ti.dag_run = dr
         pod = PodGenerator.construct_pod(
             dag_id=args.dag_id,
             task_id=ti.task_id,
diff --git a/airflow/cli/commands/task_command.py b/airflow/cli/commands/task_command.py
index 4dd0283..91a9374 100644
--- a/airflow/cli/commands/task_command.py
+++ b/airflow/cli/commands/task_command.py
@@ -21,15 +21,16 @@ import json
 import logging
 import os
 import textwrap
-from contextlib import contextmanager, redirect_stderr, redirect_stdout
+from contextlib import contextmanager, redirect_stderr, redirect_stdout, suppress
 from typing import List
 
 from pendulum.parsing.exceptions import ParserError
+from sqlalchemy.orm.exc import NoResultFound
 
 from airflow import settings
 from airflow.cli.simple_table import AirflowConsole
 from airflow.configuration import conf
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, DagRunNotFound
 from airflow.executors.executor_loader import ExecutorLoader
 from airflow.jobs.local_task_job import LocalTaskJob
 from airflow.models import DagPickle, TaskInstance
@@ -51,18 +52,43 @@ from airflow.utils.net import get_hostname
 from airflow.utils.session import create_session, provide_session
 
 
-def _get_ti(task, exec_date_or_run_id):
+def _get_dag_run(dag, exec_date_or_run_id, create_if_necssary, session):
+    dag_run = dag.get_dagrun(run_id=exec_date_or_run_id, session=session)
+    if dag_run:
+        return dag_run
+
+    execution_date = None
+    with suppress(ParserError, TypeError):
+        execution_date = timezone.parse(exec_date_or_run_id)
+
+    if create_if_necssary and not execution_date:
+        return DagRun(dag_id=dag.dag_id, run_id=exec_date_or_run_id)
+    try:
+        return (
+            session.query(DagRun)
+            .filter(
+                DagRun.dag_id == dag.dag_id,
+                DagRun.execution_date == execution_date,
+            )
+            .one()
+        )
+    except NoResultFound:
+        if create_if_necssary:
+            return DagRun(dag.dag_id, execution_date=execution_date)
+        raise DagRunNotFound(
+            f"DagRun for {dag.dag_id} with run_id or execution_date of {exec_date_or_run_id!r} not found"
+        ) from None
+
+
+@provide_session
+def _get_ti(task, exec_date_or_run_id, create_if_necssary=False, session=None):
     """Get the task instance through DagRun.run_id, if that fails, get the TI the old way"""
-    dag_run = task.dag.get_dagrun(run_id=exec_date_or_run_id)
-    if not dag_run:
-        try:
-            execution_date = timezone.parse(exec_date_or_run_id)
-            ti = TaskInstance(task, execution_date)
-            ti.refresh_from_db()
-            return ti
-        except (ParserError, TypeError):
-            raise AirflowException(f"DagRun with run_id: {exec_date_or_run_id} not found")
+    dag_run = _get_dag_run(task.dag, exec_date_or_run_id, create_if_necssary, session)
+
     ti = dag_run.get_task_instance(task.task_id)
+    if not ti and create_if_necssary:
+        ti = TaskInstance(task, run_id=None)
+        ti.dag_run = dag_run
     ti.refresh_from_task(task)
     return ti
 
@@ -75,11 +101,6 @@ def _run_task_by_selected_method(args, dag: DAG, ti: TaskInstance) -> None:
     - as raw task
     - by executor
     """
-    if args.local and args.raw:
-        raise AirflowException(
-            "Option --raw and --local are mutually exclusive. "
-            "Please remove one option to execute the command."
-        )
     if args.local:
         _run_task_by_local_task_job(args, ti)
     elif args.raw:
@@ -155,17 +176,6 @@ RAW_TASK_UNSUPPORTED_OPTION = [
 
 def _run_raw_task(args, ti: TaskInstance) -> None:
     """Runs the main task handling code"""
-    unsupported_options = [o for o in RAW_TASK_UNSUPPORTED_OPTION if getattr(args, o)]
-
-    if unsupported_options:
-        raise AirflowException(
-            "Option --raw does not work with some of the other options on this command. You "
-            "can't use --raw option and the following options: {}. You provided the option {}. "
-            "Delete it to execute the command".format(
-                ", ".join(f"--{o}" for o in RAW_TASK_UNSUPPORTED_OPTION),
-                ", ".join(f"--{o}" for o in unsupported_options),
-            )
-        )
     ti._run_raw_task(
         mark_success=args.mark_success,
         job_id=args.job_id,
@@ -213,6 +223,27 @@ def _capture_task_logs(ti):
 def task_run(args, dag=None):
     """Runs a single task instance"""
     # Load custom airflow config
+
+    if args.local and args.raw:
+        raise AirflowException(
+            "Option --raw and --local are mutually exclusive. "
+            "Please remove one option to execute the command."
+        )
+
+    if args.raw:
+        unsupported_options = [o for o in RAW_TASK_UNSUPPORTED_OPTION if getattr(args, o)]
+
+        if unsupported_options:
+            raise AirflowException(
+                "Option --raw does not work with some of the other options on this command. You "
+                "can't use --raw option and the following options: {}. You provided the option {}. "
+                "Delete it to execute the command".format(
+                    ", ".join(f"--{o}" for o in RAW_TASK_UNSUPPORTED_OPTION),
+                    ", ".join(f"--{o}" for o in unsupported_options),
+                )
+            )
+    if dag and args.pickle:
+        raise AirflowException("You cannot use the --pickle option when using DAG.cli() method.")
     if args.cfg_path:
         with open(args.cfg_path) as conf_file:
             conf_dict = json.load(conf_file)
@@ -231,9 +262,7 @@ def task_run(args, dag=None):
     # processing hundreds of simultaneous tasks.
     settings.configure_orm(disable_connection_pool=True)
 
-    if dag and args.pickle:
-        raise AirflowException("You cannot use the --pickle option when using DAG.cli() method.")
-    elif args.pickle:
+    if args.pickle:
         print(f'Loading pickle id: {args.pickle}')
         dag = get_dag_by_pickle(args.pickle)
     elif not dag:
@@ -359,14 +388,17 @@ def task_states_for_dag_run(args, session=None):
             raise AirflowException(f"Error parsing the supplied execution_date. Error: {str(err)}")
 
     if dag_run is None:
-        raise AirflowException("DagRun does not exist.")
-    tis = dag_run.get_task_instances()
+        raise DagRunNotFound(
+            f"DagRun for {args.dag_id} with run_id or execution_date of {args.execution_date_or_run_id!r} "
+            "not found"
+        )
+
     AirflowConsole().print_as(
-        data=tis,
+        data=dag_run.task_instances,
         output=args.output,
         mapper=lambda ti: {
             "dag_id": ti.dag_id,
-            "execution_date": ti.execution_date.isoformat(),
+            "execution_date": dag_run.execution_date.isoformat(),
             "task_id": ti.task_id,
             "state": ti.state,
             "start_date": ti.start_date.isoformat() if ti.start_date else "",
@@ -405,7 +437,7 @@ def task_test(args, dag=None):
     if args.task_params:
         passed_in_params = json.loads(args.task_params)
         task.params.update(passed_in_params)
-    ti = _get_ti(task, args.execution_date_or_run_id)
+    ti = _get_ti(task, args.execution_date_or_run_id, create_if_necssary=True)
 
     try:
         if args.dry_run:
@@ -431,7 +463,7 @@ def task_render(args):
     """Renders and displays templated fields for a given task"""
     dag = get_dag(args.subdir, args.dag_id)
     task = dag.get_task(task_id=args.task_id)
-    ti = _get_ti(task, args.execution_date_or_run_id)
+    ti = _get_ti(task, args.execution_date_or_run_id, create_if_necssary=True)
     ti.render_templates()
     for attr in task.__class__.template_fields:
         print(
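
For illustration, the new _get_ti helper accepts either kind of identifier (the dag and task ids below are hypothetical, and create_if_necssary is spelled as in this change):

    from airflow.cli.commands.task_command import _get_ti
    from airflow.models import DagBag

    dag = DagBag().get_dag("example_dag")
    task = dag.get_task("example_task")

    # Look up by run_id (assuming such a run exists in the database):
    ti = _get_ti(task, "manual__2021-09-07T00:00:00+00:00")

    # Or by execution date; for `tasks test` / `tasks render` an unsaved
    # TaskInstance (and DagRun) is created when no matching run exists:
    ti = _get_ti(task, "2021-09-07T00:00:00+00:00", create_if_necssary=True)
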
diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml
index 95e1bcd..da35382 100644
--- a/airflow/config_templates/config.yml
+++ b/airflow/config_templates/config.yml
@@ -1693,14 +1693,6 @@
       type: string
       example: ~
       default: "5"
-    - name: clean_tis_without_dagrun_interval
-      description: |
-        How often (in seconds) to check and tidy up 'running' TaskInstancess
-        that no longer have a matching DagRun
-      version_added: 2.0.0
-      type: float
-      example: ~
-      default: "15.0"
     - name: scheduler_heartbeat_sec
       description: |
         The scheduler constantly tries to trigger new tasks (look at the
diff --git a/airflow/config_templates/default_airflow.cfg b/airflow/config_templates/default_airflow.cfg
index 9bc1252..e36a6eb 100644
--- a/airflow/config_templates/default_airflow.cfg
+++ b/airflow/config_templates/default_airflow.cfg
@@ -845,10 +845,6 @@ tls_key =
 # listen (in seconds).
 job_heartbeat_sec = 5
 
-# How often (in seconds) to check and tidy up 'running' TaskInstancess
-# that no longer have a matching DagRun
-clean_tis_without_dagrun_interval = 15.0
-
 # The scheduler constantly tries to trigger new tasks (look at the
 # scheduler section in the docs for more information). This defines
 # how often the scheduler should run (in seconds).
diff --git a/airflow/dag_processing/processor.py b/airflow/dag_processing/processor.py
index 55efd8e..ce6b955 100644
--- a/airflow/dag_processing/processor.py
+++ b/airflow/dag_processing/processor.py
@@ -24,7 +24,7 @@ import threading
 from contextlib import redirect_stderr, redirect_stdout, suppress
 from datetime import timedelta
 from multiprocessing.connection import Connection as MultiprocessingConnection
-from typing import List, Optional, Set, Tuple
+from typing import Iterator, List, Optional, Set, Tuple
 
 from setproctitle import setproctitle
 from sqlalchemy import func, or_
@@ -49,6 +49,7 @@ from airflow.utils.mixins import MultiprocessingStartMethodMixin
 from airflow.utils.session import provide_session
 from airflow.utils.state import State
 
+DR = models.DagRun
 TI = models.TaskInstance
 
 
@@ -378,7 +379,8 @@ class DagFileProcessor(LoggingMixin):
             return
 
         qry = (
-            session.query(TI.task_id, func.max(TI.execution_date).label('max_ti'))
+            session.query(TI.task_id, func.max(DR.execution_date).label('max_ti'))
+            .join(TI.dag_run)
             .with_hint(TI, 'USE INDEX (PRIMARY)', dialect_name='mysql')
             .filter(TI.dag_id == dag.dag_id)
             .filter(or_(TI.state == State.SUCCESS, TI.state == State.SKIPPED))
@@ -387,14 +389,14 @@ class DagFileProcessor(LoggingMixin):
             .subquery('sq')
         )
 
-        max_tis: List[TI] = (
+        max_tis: Iterator[TI] = (
             session.query(TI)
+            .join(TI.dag_run)
             .filter(
                 TI.dag_id == dag.dag_id,
                 TI.task_id == qry.c.task_id,
-                TI.execution_date == qry.c.max_ti,
+                DR.execution_date == qry.c.max_ti,
             )
-            .all()
         )
 
         ts = timezone.utcnow()
@@ -558,7 +560,7 @@ class DagFileProcessor(LoggingMixin):
     @provide_session
     def _execute_dag_callbacks(self, dagbag: DagBag, request: DagCallbackRequest, session: Session):
         dag = dagbag.dags[request.dag_id]
-        dag_run = dag.get_dagrun(execution_date=request.execution_date, session=session)
+        dag_run = dag.get_dagrun(run_id=request.run_id, session=session)
         dag.handle_callback(
             dagrun=dag_run, success=not request.is_failure_callback, reason=request.msg, session=session
         )
@@ -570,7 +572,7 @@ class DagFileProcessor(LoggingMixin):
             if simple_ti.task_id in dag.task_ids:
                 task = dag.get_task(simple_ti.task_id)
                 if request.is_failure_callback:
-                    ti = TI(task, simple_ti.execution_date)
+                    ti = TI(task, run_id=simple_ti.run_id)
                     # TODO: Use simple_ti to improve performance here in the future
                     ti.refresh_from_db()
                     ti.handle_failure_with_callback(error=request.msg, test_mode=self.UNIT_TEST_MODE)
diff --git a/airflow/executors/kubernetes_executor.py b/airflow/executors/kubernetes_executor.py
index 9165687..993787c 100644
--- a/airflow/executors/kubernetes_executor.py
+++ b/airflow/executors/kubernetes_executor.py
@@ -151,7 +151,8 @@ class KubernetesJobWatcher(multiprocessing.Process, LoggingMixin):
             task_instance_related_annotations = {
                 'dag_id': annotations['dag_id'],
                 'task_id': annotations['task_id'],
-                'execution_date': annotations['execution_date'],
+                'execution_date': annotations.get('execution_date'),
+                'run_id': annotations.get('run_id'),
                 'try_number': annotations['try_number'],
             }
 
@@ -291,7 +292,7 @@ class AirflowKubernetesScheduler(LoggingMixin):
         """
         self.log.info('Kubernetes job is %s', str(next_job))
         key, command, kube_executor_config, pod_template_file = next_job
-        dag_id, task_id, execution_date, try_number = key
+        dag_id, task_id, run_id, try_number = key
 
         if command[0:3] != ["airflow", "tasks", "run"]:
             raise ValueError('The command must start with ["airflow", "tasks", "run"].')
@@ -311,7 +312,8 @@ class AirflowKubernetesScheduler(LoggingMixin):
             task_id=task_id,
             kube_image=self.kube_config.kube_image,
             try_number=try_number,
-            date=execution_date,
+            date=None,
+            run_id=run_id,
             args=command,
             pod_override_object=kube_executor_config,
             base_worker_pod=base_worker_pod,
@@ -453,27 +455,34 @@ class KubernetesExecutor(BaseExecutor, LoggingMixin):
         for task in queued_tasks:
 
             self.log.debug("Checking task %s", task)
-            dict_string = "dag_id={},task_id={},execution_date={},airflow-worker={}".format(
+            dict_string = "dag_id={},task_id={},airflow-worker={}".format(
                 pod_generator.make_safe_label_value(task.dag_id),
                 pod_generator.make_safe_label_value(task.task_id),
-                pod_generator.datetime_to_label_safe_datestring(task.execution_date),
                 pod_generator.make_safe_label_value(str(self.scheduler_job_id)),
             )
 
             kwargs = dict(label_selector=dict_string)
             if self.kube_config.kube_client_request_args:
-                for key, value in self.kube_config.kube_client_request_args.items():
-                    kwargs[key] = value
+                kwargs.update(**self.kube_config.kube_client_request_args)
+
+            # Try run_id first
+            kwargs['label_selector'] += ',run_id=' + pod_generator.make_safe_label_value(task.run_id)
             pod_list = self.kube_client.list_namespaced_pod(self.kube_config.kube_namespace, **kwargs)
-            if not pod_list.items:
-                self.log.info(
-                    'TaskInstance: %s found in queued state but was not launched, rescheduling', task
-                )
-                session.query(TaskInstance).filter(
-                    TaskInstance.dag_id == task.dag_id,
-                    TaskInstance.task_id == task.task_id,
-                    TaskInstance.execution_date == task.execution_date,
-                ).update({TaskInstance.state: State.NONE})
+            if pod_list.items:
+                continue
+            # Fallback to old style of using execution_date
+            kwargs['label_selector'] = dict_string + ',execution_date={}'.format(
+                pod_generator.datetime_to_label_safe_datestring(task.execution_date)
+            )
+            pod_list = self.kube_client.list_namespaced_pod(self.kube_config.kube_namespace, **kwargs)
+            if pod_list.items:
+                continue
+            self.log.info('TaskInstance: %s found in queued state but was not launched, rescheduling', task)
+            session.query(TaskInstance).filter(
+                TaskInstance.dag_id == task.dag_id,
+                TaskInstance.task_id == task.task_id,
+                TaskInstance.run_id == task.run_id,
+            ).update({TaskInstance.state: State.NONE})
 
     def start(self) -> None:
         """Starts the executor"""
diff --git a/airflow/jobs/backfill_job.py b/airflow/jobs/backfill_job.py
index 76d1100..e105eb5 100644
--- a/airflow/jobs/backfill_job.py
+++ b/airflow/jobs/backfill_job.py
@@ -22,7 +22,7 @@ from collections import OrderedDict
 from typing import Optional, Set
 
 import pendulum
-from sqlalchemy import and_
+from sqlalchemy.orm import eagerload
 from sqlalchemy.orm.session import Session, make_transient
 from tabulate import tabulate
 
@@ -335,7 +335,6 @@ class BackfillJob(BaseJob):
 
         # explicitly mark as backfill and running
         run.state = State.RUNNING
-        run.run_id = run.generate_run_id(DagRunType.BACKFILL_JOB, run_date)
         run.run_type = DagRunType.BACKFILL_JOB
         run.verify_integrity(session=session)
         return run
@@ -434,15 +433,12 @@ class BackfillJob(BaseJob):
             # determined deadlocked while they are actually
             # waiting for their upstream to finish
             @provide_session
-            def _per_task_process(key, ti, session=None):
+            def _per_task_process(key, ti: TaskInstance, session=None):
                 ti.refresh_from_db(lock_for_update=True, session=session)
 
                 task = self.dag.get_task(ti.task_id, include_subdags=True)
                 ti.task = task
 
-                ignore_depends_on_past = self.ignore_first_depends_on_past and ti.execution_date == (
-                    start_date or ti.start_date
-                )
                 self.log.debug("Task instance to run %s state %s", ti, ti.state)
 
                 # The task was already marked successful or skipped by a
@@ -487,6 +483,12 @@ class BackfillJob(BaseJob):
                             ti_status.running.pop(key)
                         return
 
+                if self.ignore_first_depends_on_past:
+                    dagrun = ti.get_dagrun(session=session)
+                    ignore_depends_on_past = dagrun.execution_date == (start_date or ti.start_date)
+                else:
+                    ignore_depends_on_past = False
+
                 backfill_context = DepContext(
                     deps=BACKFILL_QUEUED_DEPS,
                     ignore_depends_on_past=ignore_depends_on_past,
@@ -580,6 +582,7 @@ class BackfillJob(BaseJob):
                         num_running_task_instances_in_dag = DAG.get_num_task_instances(
                             self.dag_id,
                             states=self.STATES_COUNT_AS_RUNNING,
+                            session=session,
                         )
 
                         if num_running_task_instances_in_dag >= self.dag.max_active_tasks:
@@ -592,6 +595,7 @@ class BackfillJob(BaseJob):
                                 dag_id=self.dag_id,
                                 task_ids=[task.task_id],
                                 states=self.STATES_COUNT_AS_RUNNING,
+                                session=session,
                             )
 
                             if num_running_task_instances_in_task >= task.max_active_tis_per_dag:
@@ -645,17 +649,15 @@ class BackfillJob(BaseJob):
             # Sorting by execution date first
             sorted_ti_keys = sorted(
                 set_ti_keys,
-                key=lambda ti_key: (ti_key.execution_date, ti_key.dag_id, ti_key.task_id, ti_key.try_number),
+                key=lambda ti_key: (ti_key.run_id, ti_key.dag_id, ti_key.task_id, ti_key.try_number),
             )
-            return tabulate(sorted_ti_keys, headers=["DAG ID", "Task ID", "Execution date", "Try number"])
+            return tabulate(sorted_ti_keys, headers=["DAG ID", "Task ID", "Run ID", "Try number"])
 
         def tabulate_tis_set(set_tis: Set[TaskInstance]) -> str:
             # Sorting by execution date first
-            sorted_tis = sorted(
-                set_tis, key=lambda ti: (ti.execution_date, ti.dag_id, ti.task_id, ti.try_number)
-            )
-            tis_values = ((ti.dag_id, ti.task_id, ti.execution_date, ti.try_number) for ti in sorted_tis)
-            return tabulate(tis_values, headers=["DAG ID", "Task ID", "Execution date", "Try number"])
+            sorted_tis = sorted(set_tis, key=lambda ti: (ti.run_id, ti.dag_id, ti.task_id, ti.try_number))
+            tis_values = ((ti.dag_id, ti.task_id, ti.run_id, ti.try_number) for ti in sorted_tis)
+            return tabulate(tis_values, headers=["DAG ID", "Task ID", "Run ID", "Try number"])
 
         err = ''
         if ti_status.failed:
@@ -861,17 +863,13 @@ class BackfillJob(BaseJob):
         # also consider running as the state might not have changed in the db yet
         running_tis = self.executor.running
 
+        # Can't use an update here since it doesn't support joins.
         resettable_states = [State.SCHEDULED, State.QUEUED]
         if filter_by_dag_run is None:
             resettable_tis = (
                 session.query(TaskInstance)
-                .join(
-                    DagRun,
-                    and_(
-                        TaskInstance.dag_id == DagRun.dag_id,
-                        TaskInstance.execution_date == DagRun.execution_date,
-                    ),
-                )
+                .join(TaskInstance.dag_run)
+                .options(eagerload(TaskInstance.dag_run))
                 .filter(
                     DagRun.state == State.RUNNING,
                     DagRun.run_type != DagRunType.BACKFILL_JOB,
@@ -880,12 +878,8 @@ class BackfillJob(BaseJob):
             ).all()
         else:
             resettable_tis = filter_by_dag_run.get_task_instances(state=resettable_states, session=session)
-        tis_to_reset = []
-        # Can't use an update here since it doesn't support joins
-        for ti in resettable_tis:
-            if ti.key not in queued_tis and ti.key not in running_tis:
-                tis_to_reset.append(ti)
 
+        tis_to_reset = [ti for ti in resettable_tis if ti.key not in queued_tis and ti.key not in running_tis]
         if not tis_to_reset:
             return 0
 
@@ -910,7 +904,7 @@ class BackfillJob(BaseJob):
         reset_tis = helpers.reduce_in_chunks(query, tis_to_reset, [], self.max_tis_per_query)
 
         task_instance_str = '\n\t'.join(repr(x) for x in reset_tis)
-        session.commit()
+        session.flush()
 
         self.log.info("Reset the following %s TaskInstances:\n\t%s", len(reset_tis), task_instance_str)
         return len(reset_tis)
diff --git a/airflow/jobs/local_task_job.py b/airflow/jobs/local_task_job.py
index 203a7a8..7878216 100644
--- a/airflow/jobs/local_task_job.py
+++ b/airflow/jobs/local_task_job.py
@@ -227,7 +227,7 @@ class LocalTaskJob(BaseJob):
             dag_run = with_row_locks(
                 session.query(DagRun).filter_by(
                     dag_id=self.dag_id,
-                    execution_date=self.task_instance.execution_date,
+                    run_id=self.task_instance.run_id,
                 ),
                 session=session,
             ).one()
diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py
index add07d9..d7239a8 100644
--- a/airflow/jobs/scheduler_job.py
+++ b/airflow/jobs/scheduler_job.py
@@ -30,7 +30,7 @@ from typing import DefaultDict, Dict, Iterable, List, Optional, Tuple
 
 from sqlalchemy import and_, func, not_, or_, tuple_
 from sqlalchemy.exc import OperationalError
-from sqlalchemy.orm import load_only, selectinload
+from sqlalchemy.orm import eagerload, load_only, selectinload
 from sqlalchemy.orm.session import Session, make_transient
 
 from airflow import models, settings
@@ -190,80 +190,6 @@ class SchedulerJob(BaseJob):
         )
 
     @provide_session
-    def _change_state_for_tis_without_dagrun(
-        self, old_states: List[TaskInstanceState], new_state: TaskInstanceState, session: Session = None
-    ) -> None:
-        """
-        For all DAG IDs in the DagBag, look for task instances in the
-        old_states and set them to new_state if the corresponding DagRun
-        does not exist or exists but is not in the running or queued state. This
-        normally should not happen, but it can if the state of DagRuns are
-        changed manually.
-
-        :param old_states: examine TaskInstances in this state
-        :type old_states: list[airflow.utils.state.State]
-        :param new_state: set TaskInstances to this state
-        :type new_state: airflow.utils.state.State
-        """
-        tis_changed = 0
-        query = (
-            session.query(models.TaskInstance)
-            .outerjoin(models.TaskInstance.dag_run)
-            .filter(models.TaskInstance.dag_id.in_(list(self.dagbag.dag_ids)))
-            .filter(models.TaskInstance.state.in_(old_states))
-            .filter(
-                or_(
-                    models.DagRun.state.notin_([State.RUNNING, State.QUEUED]),
-                    models.DagRun.state.is_(None),
-                )
-            )
-        )
-        # We need to do this for mysql as well because it can cause deadlocks
-        # as discussed in https://issues.apache.org/jira/browse/AIRFLOW-2516
-        if self.using_sqlite or self.using_mysql:
-            tis_to_change: List[TI] = with_row_locks(
-                query, of=TI, session=session, **skip_locked(session=session)
-            ).all()
-            for ti in tis_to_change:
-                ti.set_state(new_state, session=session)
-                tis_changed += 1
-        else:
-            subq = query.subquery()
-            current_time = timezone.utcnow()
-            ti_prop_update = {
-                models.TaskInstance.state: new_state,
-                models.TaskInstance.start_date: current_time,
-            }
-
-            # Only add end_date and duration if the new_state is 'success', 'failed' or 'skipped'
-            if new_state in State.finished:
-                ti_prop_update.update(
-                    {
-                        models.TaskInstance.end_date: current_time,
-                        models.TaskInstance.duration: 0,
-                    }
-                )
-
-            tis_changed = (
-                session.query(models.TaskInstance)
-                .filter(
-                    models.TaskInstance.dag_id == subq.c.dag_id,
-                    models.TaskInstance.task_id == subq.c.task_id,
-                    models.TaskInstance.execution_date == subq.c.execution_date,
-                )
-                .update(ti_prop_update, synchronize_session=False)
-            )
-
-        if tis_changed > 0:
-            session.flush()
-            self.log.warning(
-                "Set %s task instances to state=%s as their associated DagRun was not in RUNNING state",
-                tis_changed,
-                new_state,
-            )
-            Stats.gauge('scheduler.tasks.without_dagrun', tis_changed)
-
-    @provide_session
     def __get_concurrency_maps(
         self, states: List[TaskInstanceState], session: Session = None
     ) -> Tuple[DefaultDict[str, int], DefaultDict[Tuple[str, str], int]]:
@@ -320,14 +246,14 @@ class SchedulerJob(BaseJob):
         # and the dag is not paused
         query = (
             session.query(TI)
-            .outerjoin(TI.dag_run)
-            .filter(or_(DR.run_id.is_(None), DR.run_type != DagRunType.BACKFILL_JOB))
-            .filter(or_(DR.state.is_(None), DR.state != DagRunState.QUEUED))
+            .join(TI.dag_run)
+            .options(eagerload(TI.dag_run))
+            .filter(DR.run_type != DagRunType.BACKFILL_JOB, DR.state != DagRunState.QUEUED)
             .join(TI.dag_model)
             .filter(not_(DM.is_paused))
             .filter(TI.state == State.SCHEDULED)
             .options(selectinload('dag_model'))
-            .order_by(-TI.priority_weight, TI.execution_date)
+            .order_by(-TI.priority_weight, DR.execution_date)
         )
         starved_pools = [pool_name for pool_name, stats in pools.items() if stats['open'] <= 0]
         if starved_pools:
@@ -559,11 +485,10 @@ class SchedulerJob(BaseJob):
             ti_primary_key_to_try_number_map[ti_key.primary] = ti_key.try_number
 
             self.log.info(
-                "Executor reports execution of %s.%s execution_date=%s "
-                "exited with status %s for try_number %s",
+                "Executor reports execution of %s.%s run_id=%s exited with status %s for try_number %s",
                 ti_key.dag_id,
                 ti_key.task_id,
-                ti_key.execution_date,
+                ti_key.run_id,
                 state,
                 ti_key.try_number,
             )
@@ -710,11 +635,6 @@ class SchedulerJob(BaseJob):
             self._emit_pool_metrics,
         )
 
-        timers.call_regular_interval(
-            conf.getfloat('scheduler', 'clean_tis_without_dagrun_interval', fallback=15.0),
-            self._clean_tis_without_dagrun,
-        )
-
         for loop_count in itertools.count(start=1):
             with Stats.timer() as timer:
 
@@ -765,35 +685,6 @@ class SchedulerJob(BaseJob):
                 )
                 break
 
-    @provide_session
-    def _clean_tis_without_dagrun(self, session):
-        with prohibit_commit(session) as guard:
-            try:
-                self._change_state_for_tis_without_dagrun(
-                    old_states=[State.UP_FOR_RETRY], new_state=State.FAILED, session=session
-                )
-
-                self._change_state_for_tis_without_dagrun(
-                    old_states=[
-                        State.QUEUED,
-                        State.SCHEDULED,
-                        State.UP_FOR_RESCHEDULE,
-                        State.SENSING,
-                        State.DEFERRED,
-                    ],
-                    new_state=State.NONE,
-                    session=session,
-                )
-
-                guard.commit()
-            except OperationalError as e:
-                if is_lock_not_available_error(error=e):
-                    self.log.debug("Lock held by another Scheduler")
-                    session.rollback()
-                else:
-                    raise
-            guard.commit()
-
     def _do_scheduling(self, session) -> int:
         """
         This function is where the main scheduling decisions take places. It:
@@ -1052,7 +943,7 @@ class SchedulerJob(BaseJob):
             callback_to_execute = DagCallbackRequest(
                 full_filepath=dag.fileloc,
                 dag_id=dag.dag_id,
-                execution_date=dag_run.execution_date,
+                run_id=dag_run.run_id,
                 is_failure_callback=True,
                 msg='timed_out',
             )
@@ -1182,7 +1073,7 @@ class SchedulerJob(BaseJob):
                             DagRun.run_type != DagRunType.BACKFILL_JOB,
                             DagRun.state == State.RUNNING,
                         )
-                        .options(load_only(TI.dag_id, TI.task_id, TI.execution_date))
+                        .options(load_only(TI.dag_id, TI.task_id, TI.run_id))
                     )
 
                     # Lock these rows, so that another scheduler can't try and adopt these too
diff --git a/airflow/kubernetes/kubernetes_helper_functions.py b/airflow/kubernetes/kubernetes_helper_functions.py
index fd740ac..bc68daf 100644
--- a/airflow/kubernetes/kubernetes_helper_functions.py
+++ b/airflow/kubernetes/kubernetes_helper_functions.py
@@ -18,7 +18,7 @@
 import logging
 from typing import Dict, Optional
 
-from dateutil import parser
+import pendulum
 from slugify import slugify
 
 from airflow.models.taskinstance import TaskInstanceKey
@@ -62,6 +62,26 @@ def annotations_to_key(annotations: Dict[str, str]) -> Optional[TaskInstanceKey]
     dag_id = annotations['dag_id']
     task_id = annotations['task_id']
     try_number = int(annotations['try_number'])
-    execution_date = parser.parse(annotations['execution_date'])
+    run_id = annotations.get('run_id')
+    if not run_id and 'execution_date' in annotations:
+        # Compat: Look up the run_id from the TI table!
+        from airflow.models.dagrun import DagRun
+        from airflow.models.taskinstance import TaskInstance
+        from airflow.settings import Session
 
-    return TaskInstanceKey(dag_id, task_id, execution_date, try_number)
+        execution_date = pendulum.parse(annotations['execution_date'])
+        # Do _not_ use create-session, we don't want to expunge
+        session = Session()
+
+        run_id: str = (
+            session.query(TaskInstance.run_id)
+            .join(TaskInstance.dag_run)
+            .filter(
+                TaskInstance.dag_id == dag_id,
+                TaskInstance.task_id == task_id,
+                DagRun.execution_date == execution_date,
+            )
+            .scalar()
+        )
+
+    return TaskInstanceKey(dag_id, task_id, run_id, try_number)
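
As a quick illustration of the new key shape (annotation values invented), a pod whose annotations carry a run_id maps straight to a TaskInstanceKey without touching the database; only pods created before the upgrade, which carry execution_date instead, go through the compat lookup above:

    from airflow.kubernetes.kubernetes_helper_functions import annotations_to_key

    key = annotations_to_key(
        {
            "dag_id": "example_dag",
            "task_id": "example_task",
            "try_number": "1",
            "run_id": "manual__2021-09-07T00:00:00+00:00",
        }
    )
    # TaskInstanceKey(dag_id='example_dag', task_id='example_task',
    #                 run_id='manual__2021-09-07T00:00:00+00:00', try_number=1)
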
diff --git a/airflow/kubernetes/pod_generator.py b/airflow/kubernetes/pod_generator.py
index c7611f6..f997264 100644
--- a/airflow/kubernetes/pod_generator.py
+++ b/airflow/kubernetes/pod_generator.py
@@ -332,12 +332,13 @@ class PodGenerator:
         pod_id: str,
         try_number: int,
         kube_image: str,
-        date: datetime.datetime,
+        date: Optional[datetime.datetime],
         args: List[str],
         pod_override_object: Optional[k8s.V1Pod],
         base_worker_pod: k8s.V1Pod,
         namespace: str,
         scheduler_job_id: int,
+        run_id: Optional[str] = None,
     ) -> k8s.V1Pod:
         """
         Construct a pod by gathering and consolidating the configuration from 3 places:
@@ -352,25 +353,32 @@ class PodGenerator:
         except Exception:
             image = kube_image
 
+        annotations = {
+            'dag_id': dag_id,
+            'task_id': task_id,
+            'try_number': str(try_number),
+        }
+        labels = {
+            'airflow-worker': make_safe_label_value(str(scheduler_job_id)),
+            'dag_id': make_safe_label_value(dag_id),
+            'task_id': make_safe_label_value(task_id),
+            'try_number': str(try_number),
+            'airflow_version': airflow_version.replace('+', '-'),
+            'kubernetes_executor': 'True',
+        }
+        if date:
+            annotations['execution_date'] = date.isoformat()
+            labels['execution_date'] = datetime_to_label_safe_datestring(date)
+        if run_id:
+            annotations['run_id'] = run_id
+            labels['run_id'] = make_safe_label_value(run_id)
+
         dynamic_pod = k8s.V1Pod(
             metadata=k8s.V1ObjectMeta(
                 namespace=namespace,
-                annotations={
-                    'dag_id': dag_id,
-                    'task_id': task_id,
-                    'execution_date': date.isoformat(),
-                    'try_number': str(try_number),
-                },
+                annotations=annotations,
                 name=PodGenerator.make_unique_pod_id(pod_id),
-                labels={
-                    'airflow-worker': make_safe_label_value(str(scheduler_job_id)),
-                    'dag_id': make_safe_label_value(dag_id),
-                    'task_id': make_safe_label_value(task_id),
-                    'execution_date': datetime_to_label_safe_datestring(date),
-                    'try_number': str(try_number),
-                    'airflow_version': airflow_version.replace('+', '-'),
-                    'kubernetes_executor': 'True',
-                },
+                labels=labels,
             ),
             spec=k8s.V1PodSpec(
                 containers=[
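
The effect of the `construct_pod` change above, sketched as a plain-Python approximation (label-escaping helpers omitted, names illustrative): pod identifiers are now keyed primarily on `run_id`, and `execution_date` is only emitted when a date is still passed in, which keeps older readers of the annotations working during an upgrade.

    # Rough sketch of the new annotation selection logic.
    def identifying_annotations(dag_id, task_id, try_number, date=None, run_id=None):
        annotations = {'dag_id': dag_id, 'task_id': task_id, 'try_number': str(try_number)}
        if date:
            annotations['execution_date'] = date.isoformat()
        if run_id:
            annotations['run_id'] = run_id
        return annotations

    identifying_annotations('example_dag', 'example_task', 1, run_id='manual__2021-09-01T00:00:00+00:00')
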
diff --git a/airflow/migrations/versions/7b2661a43ba3_taskinstance_keyed_to_dagrun.py b/airflow/migrations/versions/7b2661a43ba3_taskinstance_keyed_to_dagrun.py
new file mode 100644
index 0000000..8c62101
--- /dev/null
+++ b/airflow/migrations/versions/7b2661a43ba3_taskinstance_keyed_to_dagrun.py
@@ -0,0 +1,287 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""TaskInstance keyed to DagRun
+
+Revision ID: 7b2661a43ba3
+Revises: 142555e44c17
+Create Date: 2021-07-15 15:26:12.710749
+
+"""
+
+from collections import defaultdict
+
+import sqlalchemy as sa
+from alembic import op
+from sqlalchemy.sql import and_, column, select, table
+
+from airflow.models.base import COLLATION_ARGS
+
+ID_LEN = 250
+
+# revision identifiers, used by Alembic.
+revision = '7b2661a43ba3'
+down_revision = '142555e44c17'
+branch_labels = None
+depends_on = None
+
+
+def _mssql_datetime():
+    from sqlalchemy.dialects import mssql
+
+    return mssql.DATETIME2(precision=6)
+
+
+# Just Enough Table to run the conditions for update.
+task_instance = table(
+    'task_instance',
+    column('task_id', sa.String),
+    column('dag_id', sa.String),
+    column('run_id', sa.String),
+    column('execution_date', sa.TIMESTAMP),
+)
+task_reschedule = table(
+    'task_reschedule',
+    column('task_id', sa.String),
+    column('dag_id', sa.String),
+    column('run_id', sa.String),
+    column('execution_date', sa.TIMESTAMP),
+)
+dag_run = table(
+    'dag_run',
+    column('dag_id', sa.String),
+    column('run_id', sa.String),
+    column('execution_date', sa.TIMESTAMP),
+)
+
+
+def get_table_constraints(conn, table_name):
+    """
+    Return the primary and unique constraints along with their column names.
+    Some tables, like `task_instance`, are missing an explicit primary key
+    constraint name (the name is auto-generated by the SQL server), so this
+    function helps retrieve any primary or unique constraint name.
+    :param conn: sql connection object
+    :param table_name: table name
+    :return: a dictionary mapping constraint type to constraint name to its list of column names
+    :rtype: defaultdict(defaultdict(list))
+    """
+    query = """SELECT tc.CONSTRAINT_NAME , tc.CONSTRAINT_TYPE, ccu.COLUMN_NAME
+     FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS tc
+     JOIN INFORMATION_SCHEMA.CONSTRAINT_COLUMN_USAGE AS ccu ON ccu.CONSTRAINT_NAME = tc.CONSTRAINT_NAME
+     WHERE tc.TABLE_NAME = '{table_name}' AND
+     (tc.CONSTRAINT_TYPE = 'PRIMARY KEY' or UPPER(tc.CONSTRAINT_TYPE) = 'UNIQUE')
+    """.format(
+        table_name=table_name
+    )
+    result = conn.execute(query).fetchall()
+    constraint_dict = defaultdict(lambda: defaultdict(list))
+    for constraint, constraint_type, col_name in result:
+        constraint_dict[constraint_type][constraint].append(col_name)
+    return constraint_dict
+
+
+def upgrade():
+    """Apply TaskInstance keyed to DagRun"""
+    conn = op.get_bind()
+    dialect_name = conn.dialect.name
+
+    run_id_col_type = sa.String(length=ID_LEN, **COLLATION_ARGS)
+
+    if dialect_name == 'sqlite':
+        naming_convention = {
+            "uq": "%(table_name)s_%(column_0_N_name)s_key",
+        }
+        with op.batch_alter_table('dag_run', naming_convention=naming_convention, recreate="always"):
+            # The naming_convention forces the previously unnamed UNIQUE constraints to have the right name --
+            # but we still need to enter the context manager to trigger it
+            pass
+    elif dialect_name == 'mysql':
+        with op.batch_alter_table('dag_run') as batch_op:
+            batch_op.alter_column('dag_id', existing_type=sa.String(length=ID_LEN), type_=run_id_col_type)
+            batch_op.alter_column('run_id', existing_type=sa.String(length=ID_LEN), type_=run_id_col_type)
+            batch_op.drop_constraint('dag_id', 'unique')
+            batch_op.drop_constraint('dag_id_2', 'unique')
+            batch_op.create_unique_constraint(
+                'dag_run_dag_id_execution_date_key', ['dag_id', 'execution_date']
+            )
+            batch_op.create_unique_constraint('dag_run_dag_id_run_id_key', ['dag_id', 'run_id'])
+    elif dialect_name == 'mssql':
+
+        # _Somehow_ mssql was missing these constraints entirely!
+        with op.batch_alter_table('dag_run') as batch_op:
+            batch_op.create_unique_constraint(
+                'dag_run_dag_id_execution_date_key', ['dag_id', 'execution_date']
+            )
+            batch_op.create_unique_constraint('dag_run_dag_id_run_id_key', ['dag_id', 'run_id'])
+
+    # First create column nullable
+    op.add_column('task_instance', sa.Column('run_id', type_=run_id_col_type, nullable=True))
+    op.add_column('task_reschedule', sa.Column('run_id', type_=run_id_col_type, nullable=True))
+
+    # Then update the new column by selecting the right value from DagRun
+    update_query = _multi_table_update(dialect_name, task_instance, task_instance.c.run_id)
+    op.execute(update_query)
+
+    #
+    # TaskReschedule has a FK to TaskInstance, so we have to update that before
+    # we can drop the TI.execution_date column
+
+    update_query = _multi_table_update(dialect_name, task_reschedule, task_reschedule.c.run_id)
+    op.execute(update_query)
+
+    with op.batch_alter_table('task_reschedule', schema=None) as batch_op:
+        batch_op.alter_column('run_id', existing_type=run_id_col_type, existing_nullable=True, nullable=False)
+
+        batch_op.drop_constraint('task_reschedule_dag_task_date_fkey', 'foreignkey')
+        if dialect_name == "mysql":
+            # Mysql creates an index and a constraint -- we have to drop both
+            batch_op.drop_index('task_reschedule_dag_task_date_fkey')
+        batch_op.drop_index('idx_task_reschedule_dag_task_date')
+
+    with op.batch_alter_table('task_instance', schema=None) as batch_op:
+        # Then make it non-nullable
+        batch_op.alter_column('run_id', existing_type=run_id_col_type, existing_nullable=True, nullable=False)
+
+        # TODO: Is this right for non-postgres?
+        if dialect_name == 'mssql':
+            constraints = get_table_constraints(conn, "task_instance")
+            pk, _ = constraints['PRIMARY KEY'].popitem()
+            batch_op.drop_constraint(pk, type_='primary')
+        elif dialect_name not in ('sqlite',):
+            batch_op.drop_constraint('task_instance_pkey', type_='primary')
+        batch_op.create_primary_key('task_instance_pkey', ['dag_id', 'task_id', 'run_id'])
+
+        batch_op.drop_index('ti_dag_date')
+        batch_op.drop_index('ti_state_lkp')
+        batch_op.drop_column('execution_date')
+        batch_op.create_foreign_key(
+            'task_instance_dag_run_fkey',
+            'dag_run',
+            ['dag_id', 'run_id'],
+            ['dag_id', 'run_id'],
+            ondelete='CASCADE',
+        )
+
+        batch_op.create_index('ti_dag_run', ['dag_id', 'run_id'])
+        batch_op.create_index('ti_state_lkp', ['dag_id', 'task_id', 'run_id', 'state'])
+
+    with op.batch_alter_table('task_reschedule', schema=None) as batch_op:
+        batch_op.drop_column('execution_date')
+        batch_op.create_index(
+            'idx_task_reschedule_dag_task_run',
+            ['dag_id', 'task_id', 'run_id'],
+            unique=False,
+        )
+        # _Now_ that TI has a unique constraint on these columns, we can re-create the FK from TaskReschedule
+        batch_op.create_foreign_key(
+            'task_reschedule_ti_fkey',
+            'task_instance',
+            ['dag_id', 'task_id', 'run_id'],
+            ['dag_id', 'task_id', 'run_id'],
+            ondelete='CASCADE',
+        )
+
+        # https://docs.microsoft.com/en-us/sql/relational-databases/errors-events/mssqlserver-1785-database-engine-error?view=sql-server-ver15
+        ondelete = 'CASCADE' if dialect_name != 'mssql' else 'NO ACTION'
+        batch_op.create_foreign_key(
+            'task_reschedule_dr_fkey',
+            'dag_run',
+            ['dag_id', 'run_id'],
+            ['dag_id', 'run_id'],
+            ondelete=ondelete,
+        )
+
+
+def downgrade():
+    """Unapply TaskInstance keyed to DagRun"""
+    dialect_name = op.get_bind().dialect.name
+
+    if dialect_name == "mssql":
+        col_type = _mssql_datetime()
+    else:
+        col_type = sa.TIMESTAMP(timezone=True)
+
+    op.add_column('task_instance', sa.Column('execution_date', col_type, nullable=True))
+    op.add_column('task_reschedule', sa.Column('execution_date', col_type, nullable=True))
+
+    update_query = _multi_table_update(dialect_name, task_instance, task_instance.c.execution_date)
+    op.execute(update_query)
+
+    update_query = _multi_table_update(dialect_name, task_reschedule, task_reschedule.c.execution_date)
+    op.execute(update_query)
+
+    with op.batch_alter_table('task_reschedule', schema=None) as batch_op:
+        batch_op.alter_column(
+            'execution_date', existing_type=col_type, existing_nullable=True, nullable=False
+        )
+
+        # Can't drop PK index while there is a FK referencing it
+        batch_op.drop_constraint('task_reschedule_ti_fkey')
+        batch_op.drop_constraint('task_reschedule_dr_fkey')
+        batch_op.drop_index('idx_task_reschedule_dag_task_run')
+
+    with op.batch_alter_table('task_instance', schema=None) as batch_op:
+        batch_op.alter_column(
+            'execution_date', existing_type=col_type, existing_nullable=True, nullable=False
+        )
+
+        batch_op.drop_constraint('task_instance_pkey', type_='primary')
+        batch_op.create_primary_key('task_instance_pkey', ['dag_id', 'task_id', 'execution_date'])
+
+        batch_op.drop_constraint('task_instance_dag_run_fkey', type_='foreignkey')
+        batch_op.drop_index('ti_dag_run')
+        batch_op.drop_index('ti_state_lkp')
+        batch_op.create_index('ti_state_lkp', ['dag_id', 'task_id', 'execution_date', 'state'])
+        batch_op.create_index('ti_dag_date', ['dag_id', 'execution_date'], unique=False)
+
+        batch_op.drop_column('run_id')
+
+    with op.batch_alter_table('task_reschedule', schema=None) as batch_op:
+        batch_op.drop_column('run_id')
+        batch_op.create_index(
+            'idx_task_reschedule_dag_task_date',
+            ['dag_id', 'task_id', 'execution_date'],
+            unique=False,
+        )
+        # Can only create FK once there is an index on these columns
+        batch_op.create_foreign_key(
+            'task_reschedule_dag_task_date_fkey',
+            'task_instance',
+            ['dag_id', 'task_id', 'execution_date'],
+            ['dag_id', 'task_id', 'execution_date'],
+            ondelete='CASCADE',
+        )
+
+
+def _multi_table_update(dialect_name, target, column):
+    condition = dag_run.c.dag_id == target.c.dag_id
+    if column == target.c.run_id:
+        condition = and_(condition, dag_run.c.execution_date == target.c.execution_date)
+    else:
+        condition = and_(condition, dag_run.c.run_id == target.c.run_id)
+
+    if dialect_name == "sqlite":
+        # Most SQLite versions don't support multi-table UPDATE (and SQLA doesn't know about it anyway), so we
+        # need to do a correlated subquery update instead
+        sub_q = select([dag_run.c[column.name]]).where(condition)
+
+        return target.update().values({column: sub_q})
+    else:
+        return target.update().where(condition).values({column: dag_run.c[column.name]})
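
For readers skimming the migration, this is roughly what `_multi_table_update` emits on the non-SQLite branch when back-filling `task_instance.run_id` (a sketch only, written against the lightweight table objects defined near the top of the file):

    # Equivalent to the else-branch of _multi_table_update for task_instance.run_id.
    update_query = (
        task_instance.update()
        .where(
            and_(
                dag_run.c.dag_id == task_instance.c.dag_id,
                dag_run.c.execution_date == task_instance.c.execution_date,
            )
        )
        .values({task_instance.c.run_id: dag_run.c.run_id})
    )
    # Renders roughly as:
    #   UPDATE task_instance SET run_id = dag_run.run_id
    #   FROM dag_run
    #   WHERE dag_run.dag_id = task_instance.dag_id
    #     AND dag_run.execution_date = task_instance.execution_date
    # On SQLite the same condition is pushed into a correlated subquery instead.
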
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 60d6621..96dcbd8 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -48,6 +48,7 @@ import attr
 import jinja2
 from dateutil.relativedelta import relativedelta
 from sqlalchemy.orm import Session
+from sqlalchemy.orm.exc import NoResultFound
 
 import airflow.templates
 from airflow.compat.functools import cached_property
@@ -1284,6 +1285,7 @@ class BaseOperator(Operator, LoggingMixin, TaskMixin, metaclass=BaseOperatorMeta
         dag: DAG = self._dag
         return list(map(lambda task_id: dag.task_dict[task_id], self.get_flat_relative_ids(upstream)))
 
+    @provide_session
     def run(
         self,
         start_date: Optional[datetime] = None,
@@ -1291,17 +1293,48 @@ class BaseOperator(Operator, LoggingMixin, TaskMixin, metaclass=BaseOperatorMeta
         ignore_first_depends_on_past: bool = True,
         ignore_ti_state: bool = False,
         mark_success: bool = False,
+        test_mode: bool = False,
+        session: Session = None,
     ) -> None:
         """Run a set of task instances for a date range."""
+        from airflow.models import DagRun
+        from airflow.utils.types import DagRunType
+
         start_date = start_date or self.start_date
         end_date = end_date or self.end_date or timezone.utcnow()
 
         for info in self.dag.iter_dagrun_infos_between(start_date, end_date, align=False):
             ignore_depends_on_past = info.logical_date == start_date and ignore_first_depends_on_past
-            TaskInstance(self, info.logical_date).run(
+            try:
+                dag_run = (
+                    session.query(DagRun)
+                    .filter(
+                        DagRun.dag_id == self.dag_id,
+                        DagRun.execution_date == info.logical_date,
+                    )
+                    .one()
+                )
+                ti = TaskInstance(self, run_id=dag_run.run_id)
+            except NoResultFound:
+                # This is _mostly_ only used in tests
+                dr = DagRun(
+                    dag_id=self.dag_id,
+                    run_id=DagRun.generate_run_id(DagRunType.MANUAL, info.logical_date),
+                    run_type=DagRunType.MANUAL,
+                    execution_date=info.logical_date,
+                    data_interval=info.data_interval,
+                )
+                ti = TaskInstance(self, run_id=None)
+                ti.dag_run = dr
+                session.add(dr)
+                session.flush()
+
+            ti.run(
                 mark_success=mark_success,
                 ignore_depends_on_past=ignore_depends_on_past,
                 ignore_ti_state=ignore_ti_state,
+                test_mode=test_mode,
+                session=session,
             )
 
     def dry_run(self) -> None:
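
In practice the `run()` change means an operator no longer executes against a bare date: a DagRun row is looked up for each logical date, or (mostly for tests) created on the fly, and the TaskInstance is keyed by its run_id. A hedged example of the call pattern, assuming an operator `op` attached to a `dag` whose schedule covers the date range:

    # Sketch: each logical date in the range is mapped onto a DagRun before the TI runs.
    op.run(
        start_date=dag.start_date,
        end_date=dag.start_date,
        ignore_ti_state=True,
        test_mode=True,  # new keyword in this change; skips persisting state to the DB
    )
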
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 21f3d4e..7743399 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -949,7 +949,7 @@ class DAG(LoggingMixin):
         callback = self.on_success_callback if success else self.on_failure_callback
         if callback:
             self.log.info('Executing dag callback function: %s', callback)
-            tis = dagrun.get_task_instances()
+            tis = dagrun.get_task_instances(session=session)
             ti = tis[-1]  # get first TaskInstance of DagRun
             ti.task = self.get_task(ti.task_id)
             context = ti.get_template_context(session=session)
@@ -1163,6 +1163,7 @@ class DAG(LoggingMixin):
                 task_ids=None,
                 start_date=start_date,
                 end_date=end_date,
+                run_id=None,
                 state=state,
                 include_subdags=False,
                 include_parentdag=False,
@@ -1171,7 +1172,8 @@ class DAG(LoggingMixin):
                 as_pk_tuple=False,
                 session=session,
             )
-            .order_by(TaskInstance.execution_date)
+            .join(TaskInstance.dag_run)
+            .order_by(DagRun.execution_date)
             .all()
         )
 
@@ -1182,6 +1184,7 @@ class DAG(LoggingMixin):
         task_ids,
         start_date: Optional[datetime],
         end_date: Optional[datetime],
+        run_id: None,
         state: Union[str, List[str]],
         include_subdags: bool,
         include_parentdag: bool,
@@ -1203,6 +1206,7 @@ class DAG(LoggingMixin):
         task_ids,
         start_date: Optional[datetime],
         end_date: Optional[datetime],
+        run_id: Optional[str],
         state: Union[str, List[str]],
         include_subdags: bool,
         include_parentdag: bool,
@@ -1223,6 +1227,7 @@ class DAG(LoggingMixin):
         task_ids,
         start_date: Optional[datetime],
         end_date: Optional[datetime],
+        run_id: Optional[str],
         state: Union[str, List[str]],
         include_subdags: bool,
         include_parentdag: bool,
@@ -1247,9 +1252,10 @@ class DAG(LoggingMixin):
 
         # Do we want full objects, or just the primary columns?
         if as_pk_tuple:
-            tis = session.query(TI.dag_id, TI.task_id, TI.execution_date)
+            tis = session.query(TI.dag_id, TI.task_id, TI.run_id)
         else:
             tis = session.query(TaskInstance)
+        tis = tis.join(TaskInstance.dag_run)
 
         if include_subdags:
             # Crafting the right filter for dag_id and task_ids combo
@@ -1261,15 +1267,17 @@ class DAG(LoggingMixin):
             tis = tis.filter(or_(*conditions))
         else:
             tis = tis.filter(TaskInstance.dag_id == self.dag_id, TaskInstance.task_id.in_(self.task_ids))
+        if run_id:
+            tis = tis.filter(TaskInstance.run_id == run_id)
         if start_date:
-            tis = tis.filter(TaskInstance.execution_date >= start_date)
+            tis = tis.filter(DagRun.execution_date >= start_date)
         if task_ids:
             tis = tis.filter(TaskInstance.task_id.in_(task_ids))
 
         # This allows allow_trigger_in_future config to take affect, rather than mandating exec_date <= UTC
         if end_date or not self.allow_future_exec_dates:
             end_date = end_date or timezone.utcnow()
-            tis = tis.filter(TaskInstance.execution_date <= end_date)
+            tis = tis.filter(DagRun.execution_date <= end_date)
 
         if state:
             if isinstance(state, str):
@@ -1301,6 +1309,7 @@ class DAG(LoggingMixin):
                     task_ids=task_ids,
                     start_date=start_date,
                     end_date=end_date,
+                    run_id=None,
                     state=state,
                     include_subdags=include_subdags,
                     include_parentdag=False,
@@ -1353,10 +1362,14 @@ class DAG(LoggingMixin):
                         )
                     )
                 ti.render_templates()
-                external_tis = session.query(TI).filter(
-                    TI.dag_id == task.external_dag_id,
-                    TI.task_id == task.external_task_id,
-                    TI.execution_date == pendulum.parse(task.execution_date),
+                external_tis = (
+                    session.query(TI)
+                    .join(TI.dag_run)
+                    .filter(
+                        TI.dag_id == task.external_dag_id,
+                        TI.task_id == task.external_task_id,
+                        DagRun.execution_date == pendulum.parse(task.execution_date),
+                    )
                 )
 
                 for tii in external_tis:
@@ -1373,8 +1386,9 @@ class DAG(LoggingMixin):
                     result.update(
                         downstream._get_task_instances(
                             task_ids=None,
-                            start_date=tii.execution_date,
-                            end_date=tii.execution_date,
+                            run_id=tii.run_id,
+                            start_date=None,
+                            end_date=None,
                             state=state,
                             include_subdags=include_subdags,
                             include_dependent_dags=include_dependent_dags,
@@ -1408,7 +1422,7 @@ class DAG(LoggingMixin):
             return result
         elif result:
             # We've been asked for objects, lets combine it all back in to a result set
-            tis = tis.with_entities(TI.dag_id, TI.task_id, TI.execution_date)
+            tis = tis.with_entities(TI.dag_id, TI.task_id, TI.run_id)
 
             tis = session.query(TI).filter(TI.filter_for_tis(result))
         elif exclude_task_ids:
@@ -1667,6 +1681,7 @@ class DAG(LoggingMixin):
             task_ids=task_ids,
             start_date=start_date,
             end_date=end_date,
+            run_id=None,
             state=state,
             include_subdags=include_subdags,
             include_parentdag=include_parentdag,
@@ -2267,7 +2282,7 @@ class DAG(LoggingMixin):
                         orm_dag.tags.append(dag_tag_orm)
                         session.add(dag_tag_orm)
 
-        DagCode.bulk_sync_to_db(filelocs)
+        DagCode.bulk_sync_to_db(filelocs, session=session)
 
         # Issue SQL/finish "Unit of Work", but let @provide_session commit (or if passed a session, let caller
         # decide when to commit
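
Because TaskInstance no longer carries an `execution_date` column, the date bounds in `_get_task_instances` above are applied through a join to DagRun. A minimal sketch of the resulting query shape (session, models and the date variables are assumed; the dag_id is illustrative):

    # Date-bounded TI lookup after this change: filter on DagRun.execution_date via the join.
    tis = (
        session.query(TaskInstance)
        .join(TaskInstance.dag_run)
        .filter(
            TaskInstance.dag_id == 'example_dag',
            DagRun.execution_date >= start_date,
            DagRun.execution_date <= end_date,
        )
        .order_by(DagRun.execution_date)
    )
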
diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index 7979434..e8f5f98 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -15,20 +15,21 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import warnings
 from datetime import datetime
 from typing import TYPE_CHECKING, Any, Iterable, List, NamedTuple, Optional, Tuple, Union
 
 from sqlalchemy import Boolean, Column, Index, Integer, PickleType, String, UniqueConstraint, and_, func, or_
 from sqlalchemy.exc import IntegrityError
 from sqlalchemy.ext.declarative import declared_attr
-from sqlalchemy.orm import backref, relationship, synonym
+from sqlalchemy.orm import joinedload, relationship, synonym
 from sqlalchemy.orm.session import Session
 from sqlalchemy.sql import expression
 
 from airflow import settings
 from airflow.configuration import conf as airflow_conf
 from airflow.exceptions import AirflowException, TaskNotFound
-from airflow.models.base import ID_LEN, Base
+from airflow.models.base import COLLATION_ARGS, ID_LEN, Base
 from airflow.models.taskinstance import TaskInstance as TI
 from airflow.stats import Stats
 from airflow.ti_deps.dep_context import DepContext
@@ -65,13 +66,13 @@ class DagRun(Base, LoggingMixin):
     __NO_VALUE = object()
 
     id = Column(Integer, primary_key=True)
-    dag_id = Column(String(ID_LEN))
+    dag_id = Column(String(ID_LEN, **COLLATION_ARGS))
     queued_at = Column(UtcDateTime)
     execution_date = Column(UtcDateTime, default=timezone.utcnow)
     start_date = Column(UtcDateTime)
     end_date = Column(UtcDateTime)
     _state = Column('state', String(50), default=State.QUEUED)
-    run_id = Column(String(ID_LEN))
+    run_id = Column(String(ID_LEN, **COLLATION_ARGS))
     creating_job_id = Column(Integer)
     external_trigger = Column(Boolean, default=True)
     run_type = Column(String(50), nullable=False)
@@ -87,17 +88,12 @@ class DagRun(Base, LoggingMixin):
 
     __table_args__ = (
         Index('dag_id_state', dag_id, _state),
-        UniqueConstraint('dag_id', 'execution_date'),
-        UniqueConstraint('dag_id', 'run_id'),
+        UniqueConstraint('dag_id', 'execution_date', name='dag_run_dag_id_execution_date_key'),
+        UniqueConstraint('dag_id', 'run_id', name='dag_run_dag_id_run_id_key'),
         Index('idx_last_scheduling_decision', last_scheduling_decision),
     )
 
-    task_instances = relationship(
-        TI,
-        primaryjoin=and_(TI.dag_id == dag_id, TI.execution_date == execution_date),
-        foreign_keys=(dag_id, execution_date),
-        backref=backref('dag_run', uselist=False),
-    )
+    task_instances = relationship(TI, back_populates="dag_run")
 
     DEFAULT_DAGRUNS_TO_EXAMINE = airflow_conf.getint(
         'scheduler',
@@ -303,9 +299,13 @@ class DagRun(Base, LoggingMixin):
         self, state: Optional[Iterable[TaskInstanceState]] = None, session=None
     ) -> Iterable[TI]:
         """Returns the task instances for this dag run"""
-        tis = session.query(TI).filter(
-            TI.dag_id == self.dag_id,
-            TI.execution_date == self.execution_date,
+        tis = (
+            session.query(TI)
+            .options(joinedload(TI.dag_run))
+            .filter(
+                TI.dag_id == self.dag_id,
+                TI.run_id == self.run_id,
+            )
         )
 
         if state:
@@ -338,8 +338,8 @@ class DagRun(Base, LoggingMixin):
         """
         return (
             session.query(TI)
-            .filter(TI.dag_id == self.dag_id, TI.execution_date == self.execution_date, TI.task_id == task_id)
-            .first()
+            .filter(TI.dag_id == self.dag_id, TI.run_id == self.run_id, TI.task_id == task_id)
+            .one_or_none()
         )
 
     def get_dag(self) -> "DAG":
@@ -436,7 +436,7 @@ class DagRun(Base, LoggingMixin):
                 callback = callback_requests.DagCallbackRequest(
                     full_filepath=dag.fileloc,
                     dag_id=self.dag_id,
-                    execution_date=self.execution_date,
+                    run_id=self.run_id,
                     is_failure_callback=True,
                     msg='task_failure',
                 )
@@ -451,7 +451,7 @@ class DagRun(Base, LoggingMixin):
                 callback = callback_requests.DagCallbackRequest(
                     full_filepath=dag.fileloc,
                     dag_id=self.dag_id,
-                    execution_date=self.execution_date,
+                    run_id=self.run_id,
                     is_failure_callback=False,
                     msg='success',
                 )
@@ -472,7 +472,7 @@ class DagRun(Base, LoggingMixin):
                 callback = callback_requests.DagCallbackRequest(
                     full_filepath=dag.fileloc,
                     dag_id=self.dag_id,
-                    execution_date=self.execution_date,
+                    run_id=self.run_id,
                     is_failure_callback=True,
                     msg='all_tasks_deadlocked',
                 )
@@ -675,7 +675,7 @@ class DagRun(Base, LoggingMixin):
 
             if task.task_id not in task_ids:
                 Stats.incr(f"task_instance_created-{task.task_type}", 1, 1)
-                ti = TI(task, self.execution_date)
+                ti = TI(task, execution_date=None, run_id=self.run_id)
                 task_instance_mutation_hook(ti)
                 session.add(ti)
 
@@ -683,9 +683,7 @@ class DagRun(Base, LoggingMixin):
             session.flush()
         except IntegrityError as err:
             self.log.info(str(err))
-            self.log.info(
-                'Hit IntegrityError while creating the TIs for ' f'{dag.dag_id} - {self.execution_date}.'
-            )
+            self.log.info('Hit IntegrityError while creating the TIs for %s - %s', dag.dag_id, self.run_id)
             self.log.info('Doing session rollback.')
             # TODO[HA]: We probably need to savepoint this so we can keep the transaction alive.
             session.rollback()
@@ -695,6 +693,7 @@ class DagRun(Base, LoggingMixin):
         """
         Get a single DAG Run
 
+        :meta private:
         :param session: Sqlalchemy ORM Session
         :type session: Session
         :param dag_id: DAG ID
@@ -705,6 +704,11 @@ class DagRun(Base, LoggingMixin):
             if one exists. None otherwise.
         :rtype: airflow.models.DagRun
         """
+        warnings.warn(
+            "This method is deprecated. Please use SQLAlchemy directly",
+            DeprecationWarning,
+            stacklevel=2,
+        )
         return (
             session.query(DagRun)
             .filter(
@@ -770,7 +774,7 @@ class DagRun(Base, LoggingMixin):
                 session.query(TI)
                 .filter(
                     TI.dag_id == self.dag_id,
-                    TI.execution_date == self.execution_date,
+                    TI.run_id == self.run_id,
                     TI.task_id.in_(schedulable_ti_ids),
                 )
                 .update({TI.state: State.SCHEDULED}, synchronize_session=False)
@@ -782,7 +786,7 @@ class DagRun(Base, LoggingMixin):
                 session.query(TI)
                 .filter(
                     TI.dag_id == self.dag_id,
-                    TI.execution_date == self.execution_date,
+                    TI.run_id == self.run_id,
                     TI.task_id.in_(dummy_ti_ids),
                 )
                 .update(
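
With `task_instances = relationship(TI, back_populates="dag_run")` now backed by a real foreign key, the TIs of a run can be reached through the ORM relationship instead of an ad-hoc (dag_id, execution_date) join. A small sketch (identifiers illustrative, session assumed):

    # Navigate from a DagRun to its TaskInstances via the new relationship.
    dag_run = (
        session.query(DagRun)
        .filter_by(dag_id='example_dag', run_id='scheduled__2021-09-01T00:00:00+00:00')
        .one()
    )
    for ti in dag_run.task_instances:
        print(ti.task_id, ti.state)
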
diff --git a/airflow/models/skipmixin.py b/airflow/models/skipmixin.py
index 489da52..5cd50a3 100644
--- a/airflow/models/skipmixin.py
+++ b/airflow/models/skipmixin.py
@@ -16,7 +16,8 @@
 # specific language governing permissions and limitations
 # under the License.
 
-from typing import Iterable, Union
+import warnings
+from typing import TYPE_CHECKING, Iterable, Union
 
 from airflow.models.taskinstance import TaskInstance
 from airflow.utils import timezone
@@ -24,6 +25,12 @@ from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.session import create_session, provide_session
 from airflow.utils.state import State
 
+if TYPE_CHECKING:
+    from sqlalchemy.orm import Session
+
+    from airflow.models import DagRun
+    from airflow.models.baseoperator import BaseOperator
+
 # The key used by SkipMixin to store XCom data.
 XCOM_SKIPMIXIN_KEY = "skipmixin_key"
 
@@ -37,44 +44,31 @@ XCOM_SKIPMIXIN_FOLLOWED = "followed"
 class SkipMixin(LoggingMixin):
     """A Mixin to skip Tasks Instances"""
 
-    def _set_state_to_skipped(self, dag_run, execution_date, tasks, session):
+    def _set_state_to_skipped(self, dag_run: "DagRun", tasks: "Iterable[BaseOperator]", session: "Session"):
         """Used internally to set state of task instances to skipped from the same dag run."""
         task_ids = [d.task_id for d in tasks]
         now = timezone.utcnow()
 
-        if dag_run:
-            session.query(TaskInstance).filter(
-                TaskInstance.dag_id == dag_run.dag_id,
-                TaskInstance.execution_date == dag_run.execution_date,
-                TaskInstance.task_id.in_(task_ids),
-            ).update(
-                {
-                    TaskInstance.state: State.SKIPPED,
-                    TaskInstance.start_date: now,
-                    TaskInstance.end_date: now,
-                },
-                synchronize_session=False,
-            )
-        else:
-            if execution_date is None:
-                raise ValueError("Execution date is None and no dag run")
-
-            self.log.warning("No DAG RUN present this should not happen")
-            # this is defensive against dag runs that are not complete
-            for task in tasks:
-                ti = TaskInstance(task, execution_date=execution_date)
-                ti.state = State.SKIPPED
-                ti.start_date = now
-                ti.end_date = now
-                session.merge(ti)
+        session.query(TaskInstance).filter(
+            TaskInstance.dag_id == dag_run.dag_id,
+            TaskInstance.run_id == dag_run.run_id,
+            TaskInstance.task_id.in_(task_ids),
+        ).update(
+            {
+                TaskInstance.state: State.SKIPPED,
+                TaskInstance.start_date: now,
+                TaskInstance.end_date: now,
+            },
+            synchronize_session=False,
+        )
 
     @provide_session
     def skip(
         self,
-        dag_run,
-        execution_date,
-        tasks,
-        session=None,
+        dag_run: "DagRun",
+        execution_date: "timezone.DateTime",
+        tasks: "Iterable[BaseOperator]",
+        session: "Session" = None,
     ):
         """
         Sets tasks instances to skipped from the same dag run.
@@ -91,7 +85,32 @@ class SkipMixin(LoggingMixin):
         if not tasks:
             return
 
-        self._set_state_to_skipped(dag_run, execution_date, tasks, session)
+        if execution_date and not dag_run:
+            from airflow.models.dagrun import DagRun
+
+            warnings.warn(
+                "Passing an execution_date to `skip()` is deprecated in favour of passing a dag_run",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+
+            dag_run = (
+                session.query(DagRun)
+                .filter(
+                    DagRun.dag_id == tasks[0].dag_id,
+                    DagRun.execution_date == execution_date,
+                )
+                .one()
+            )
+        elif execution_date and dag_run and execution_date != dag_run.execution_date:
+            raise ValueError(
+                "execution_date has a different value to  dag_run.execution_date -- please only pass dag_run"
+            )
+
+        if dag_run is None:
+            raise ValueError("dag_run is required")
+
+        self._set_state_to_skipped(dag_run, tasks, session)
         session.commit()
 
         # SkipMixin may not necessarily have a task_id attribute. Only store to XCom if one is available.
@@ -154,7 +173,7 @@ class SkipMixin(LoggingMixin):
 
             self.log.info("Skipping tasks %s", [t.task_id for t in skip_tasks])
             with create_session() as session:
-                self._set_state_to_skipped(dag_run, ti.execution_date, skip_tasks, session=session)
+                self._set_state_to_skipped(dag_run, skip_tasks, session=session)
                 # For some reason, session.commit() needs to happen before xcom_push.
                 # Otherwise the session is not committed.
                 session.commit()
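
After this change, callers of `skip()` should hand over the `dag_run` itself; an `execution_date` alone is still accepted for backwards compatibility but costs a lookup and emits a DeprecationWarning. A hedged usage sketch from inside an operator that mixes in SkipMixin (`tasks_to_skip` is an illustrative list of downstream operators):

    # Preferred: pass the dag_run pulled from the context; execution_date stays None.
    self.skip(context['dag_run'], None, tasks_to_skip)

    # Deprecated but still working: only an execution_date -- the dag_run is looked up.
    self.skip(None, context['execution_date'], tasks_to_skip)
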
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 56546e9..e7aaea1 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -23,7 +23,7 @@ import os
 import pickle
 import signal
 import warnings
-from collections import defaultdict, namedtuple
+from collections import defaultdict
 from datetime import datetime, timedelta
 from functools import partial
 from tempfile import NamedTemporaryFile
@@ -47,12 +47,13 @@ from sqlalchemy import (
     func,
     inspect,
     or_,
+    tuple_,
 )
+from sqlalchemy.ext.associationproxy import association_proxy
 from sqlalchemy.orm import reconstructor, relationship
-from sqlalchemy.orm.attributes import NO_VALUE
+from sqlalchemy.orm.attributes import NO_VALUE, set_committed_value
 from sqlalchemy.orm.session import Session
 from sqlalchemy.sql.elements import BooleanClauseList
-from sqlalchemy.sql.expression import tuple_
 from sqlalchemy.sql.sqltypes import BigInteger
 
 from airflow import settings
@@ -67,6 +68,7 @@ from airflow.exceptions import (
     AirflowSkipException,
     AirflowSmartSensorException,
     AirflowTaskTimeout,
+    DagRunNotFound,
     TaskDeferralError,
     TaskDeferred,
 )
@@ -90,7 +92,7 @@ from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.net import get_hostname
 from airflow.utils.operator_helpers import context_to_airflow_vars
 from airflow.utils.platform import getuser
-from airflow.utils.session import provide_session
+from airflow.utils.session import create_session, provide_session
 from airflow.utils.sqlalchemy import ExtendedJSON, UtcDateTime
 from airflow.utils.state import DagRunState, State
 from airflow.utils.timeout import timeout
@@ -111,7 +113,7 @@ log = logging.getLogger(__name__)
 
 
 if TYPE_CHECKING:
-    from airflow.models.dag import DAG, DagModel
+    from airflow.models.dag import DAG, DagModel, DagRun
 
 
 @contextlib.contextmanager
@@ -202,14 +204,14 @@ def clear_task_instances(
             ti.external_executor_id = None
             session.merge(ti)
 
-        task_id_by_key[ti.dag_id][ti.execution_date][ti.try_number].add(ti.task_id)
+        task_id_by_key[ti.dag_id][ti.run_id][ti.try_number].add(ti.task_id)
 
     if task_id_by_key:
         # Clear all reschedules related to the ti to clear
 
         # This is an optimization for the common case where all tis are for a small number
-        # of dag_id, execution_date and try_number. Use a nested dict of dag_id,
-        # execution_date, try_number and task_id to construct the where clause in a
+        # of dag_id, run_id and try_number. Use a nested dict of dag_id,
+        # run_id, try_number and task_id to construct the where clause in a
         # hierarchical manner. This speeds up the delete statement by more than 40x for
         # large number of tis (50k+).
         conditions = or_(
@@ -217,16 +219,16 @@ def clear_task_instances(
                 TR.dag_id == dag_id,
                 or_(
                     and_(
-                        TR.execution_date == execution_date,
+                        TR.run_id == run_id,
                         or_(
                             and_(TR.try_number == try_number, TR.task_id.in_(task_ids))
                             for try_number, task_ids in task_tries.items()
                         ),
                     )
-                    for execution_date, task_tries in dates.items()
+                    for run_id, task_tries in run_ids.items()
                 ),
             )
-            for dag_id, dates in task_id_by_key.items()
+            for dag_id, run_ids in task_id_by_key.items()
         )
 
         delete_qry = TR.__table__.delete().where(conditions)
@@ -251,16 +253,16 @@ def clear_task_instances(
     if dag_run_state is not False and tis:
         from airflow.models.dagrun import DagRun  # Avoid circular import
 
-        dates_by_dag_id = defaultdict(set)
+        run_ids_by_dag_id = defaultdict(set)
         for instance in tis:
-            dates_by_dag_id[instance.dag_id].add(instance.execution_date)
+            run_ids_by_dag_id[instance.dag_id].add(instance.run_id)
 
         drs = (
             session.query(DagRun)
             .filter(
                 or_(
-                    and_(DagRun.dag_id == dag_id, DagRun.execution_date.in_(dates))
-                    for dag_id, dates in dates_by_dag_id.items()
+                    and_(DagRun.dag_id == dag_id, DagRun.run_id.in_(run_ids))
+                    for dag_id, run_ids in run_ids_by_dag_id.items()
                 )
             )
             .all()
@@ -277,22 +279,22 @@ class TaskInstanceKey(NamedTuple):
 
     dag_id: str
     task_id: str
-    execution_date: datetime
+    run_id: str
     try_number: int = 1
 
     @property
-    def primary(self) -> Tuple[str, str, datetime]:
+    def primary(self) -> Tuple[str, str, str]:
         """Return task instance primary key part of the key"""
-        return self.dag_id, self.task_id, self.execution_date
+        return self.dag_id, self.task_id, self.run_id
 
     @property
     def reduced(self) -> 'TaskInstanceKey':
         """Remake the key by subtracting 1 from try number to match in memory information"""
-        return TaskInstanceKey(self.dag_id, self.task_id, self.execution_date, max(1, self.try_number - 1))
+        return TaskInstanceKey(self.dag_id, self.task_id, self.run_id, max(1, self.try_number - 1))
 
     def with_try_number(self, try_number: int) -> 'TaskInstanceKey':
         """Returns TaskInstanceKey with provided ``try_number``"""
-        return TaskInstanceKey(self.dag_id, self.task_id, self.execution_date, try_number)
+        return TaskInstanceKey(self.dag_id, self.task_id, self.run_id, try_number)
 
     @property
     def key(self) -> "TaskInstanceKey":
@@ -321,7 +323,7 @@ class TaskInstance(Base, LoggingMixin):
 
     task_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True)
     dag_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True)
-    execution_date = Column(UtcDateTime, primary_key=True)
+    run_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True)
     start_date = Column(UtcDateTime)
     end_date = Column(UtcDateTime)
     duration = Column(Float)
@@ -359,9 +361,9 @@ class TaskInstance(Base, LoggingMixin):
 
     __table_args__ = (
         Index('ti_dag_state', dag_id, state),
-        Index('ti_dag_date', dag_id, execution_date),
+        Index('ti_dag_run', dag_id, run_id),
         Index('ti_state', state),
-        Index('ti_state_lkp', dag_id, task_id, execution_date, state),
+        Index('ti_state_lkp', dag_id, task_id, run_id, state),
         Index('ti_pool', pool, state, priority_weight),
         Index('ti_job_id', job_id),
         Index('ti_trigger_id', trigger_id),
@@ -371,6 +373,12 @@ class TaskInstance(Base, LoggingMixin):
             name='task_instance_trigger_id_fkey',
             ondelete='CASCADE',
         ),
+        ForeignKeyConstraint(
+            [dag_id, run_id],
+            ["dag_run.dag_id", "dag_run.run_id"],
+            name='task_instance_dag_run_fkey',
+            ondelete="CASCADE",
+        ),
     )
 
     dag_model = relationship(
@@ -379,6 +387,7 @@ class TaskInstance(Base, LoggingMixin):
         foreign_keys=dag_id,
         uselist=False,
         innerjoin=True,
+        viewonly=True,
     )
 
     trigger = relationship(
@@ -389,27 +398,52 @@ class TaskInstance(Base, LoggingMixin):
         innerjoin=True,
     )
 
-    def __init__(self, task, execution_date: datetime, state: Optional[str] = None):
+    dag_run = relationship("DagRun", back_populates="task_instances")
+
+    execution_date = association_proxy("dag_run", "execution_date")
+
+    def __init__(
+        self, task, execution_date: Optional[datetime] = None, run_id: str = None, state: Optional[str] = None
+    ):
         super().__init__()
         self.dag_id = task.dag_id
         self.task_id = task.task_id
         self.refresh_from_task(task)
         self._log = logging.getLogger("airflow.task")
 
-        # make sure we have a localized execution_date stored in UTC
-        if execution_date and not timezone.is_localized(execution_date):
-            self.log.warning(
-                "execution date %s has no timezone information. Using default from dag or system",
-                execution_date,
-            )
-            if self.task.has_dag():
-                execution_date = timezone.make_aware(execution_date, self.task.dag.timezone)
-            else:
-                execution_date = timezone.make_aware(execution_date)
+        if run_id is None and execution_date is not None:
+            from airflow.models.dagrun import DagRun  # Avoid circular import
 
-            execution_date = timezone.convert_to_utc(execution_date)
+            warnings.warn(
+                "Passing an execution_date to `TaskInstance()` is deprecated in favour of passing a run_id",
+                DeprecationWarning,
+                # Stack level is 4 because SQLA adds some wrappers around the constructor
+                stacklevel=4,
+            )
+            # make sure we have a localized execution_date stored in UTC
+            if execution_date and not timezone.is_localized(execution_date):
+                self.log.warning(
+                    "execution date %s has no timezone information. Using default from dag or system",
+                    execution_date,
+                )
+                if self.task.has_dag():
+                    execution_date = timezone.make_aware(execution_date, self.task.dag.timezone)
+                else:
+                    execution_date = timezone.make_aware(execution_date)
+
+                execution_date = timezone.convert_to_utc(execution_date)
+            with create_session() as session:
+                run_id = (
+                    session.query(DagRun.run_id)
+                    .filter_by(dag_id=self.dag_id, execution_date=execution_date)
+                    .scalar()
+                )
+                if run_id is None:
+                    raise DagRunNotFound(
+                        f"DagRun for {self.dag_id!r} with date {execution_date} not found"
+                    ) from None
 
-        self.execution_date = execution_date
+        self.run_id = run_id
 
         self.try_number = 0
         self.unixname = getuser()
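
The constructor now keys the instance by `run_id`; passing an `execution_date` still works but emits a DeprecationWarning and resolves the run_id with a DagRun query (raising DagRunNotFound if no run exists). A minimal sketch of both forms, assuming `task` belongs to a DAG that already has a run for the date and that pendulum is imported:

    # Preferred after this change: key the TI by run_id.
    ti = TaskInstance(task, run_id='manual__2021-09-01T00:00:00+00:00')

    # Deprecated: the run_id is looked up from the DagRun for that execution_date.
    ti = TaskInstance(task, execution_date=pendulum.datetime(2021, 9, 1, tz='UTC'))
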
@@ -466,22 +500,6 @@ class TaskInstance(Base, LoggingMixin):
         """Setting Next Try Number"""
         return self._try_number + 1
 
-    @property
-    def run_id(self):
-        """Fetches the run_id from the associated DagRun"""
-        # TODO: Remove this once run_id is added as a column in TaskInstance
-
-        # IF we have pre-loaded it, just use that
-        info = inspect(self)
-        if info.attrs.dag_run.loaded_value is not NO_VALUE:
-            return self.dag_un.run_id
-        # _Don't_ use provide/create_session here, as we do not want to commit on this session (as this is
-        # called from the scheduler critical section)!
-        dag_run = self.get_dagrun(session=settings.Session())
-
-        if dag_run:
-            return dag_run.run_id
-
     def command_as_list(
         self,
         mark_success=False,
@@ -525,7 +543,6 @@ class TaskInstance(Base, LoggingMixin):
             self.dag_id,
             self.task_id,
             run_id=self.run_id,
-            execution_date=self.execution_date,
             mark_success=mark_success,
             ignore_all_deps=ignore_all_deps,
             ignore_task_deps=ignore_task_deps,
@@ -545,7 +562,6 @@ class TaskInstance(Base, LoggingMixin):
         dag_id: str,
         task_id: str,
         run_id: str = None,
-        execution_date: datetime = None,
         mark_success: bool = False,
         ignore_all_deps: bool = False,
         ignore_depends_on_past: bool = False,
@@ -562,8 +578,6 @@ class TaskInstance(Base, LoggingMixin):
         """
         Generates the shell command required to execute this task instance.
 
-        One of run_id or execution_date must be passed
-
         :param dag_id: DAG ID
         :type dag_id: str
         :param task_id: Task ID
@@ -601,13 +615,7 @@ class TaskInstance(Base, LoggingMixin):
         :return: shell command that can be used to run the task instance
         :rtype: list[str]
         """
-        cmd = ["airflow", "tasks", "run", dag_id, task_id]
-        if run_id:
-            cmd.append(run_id)
-        elif execution_date:
-            cmd.append(execution_date.isoformat())
-        else:
-            raise ValueError("One of run_id and execution_date must be provided")
+        cmd = ["airflow", "tasks", "run", dag_id, task_id, run_id]
         if mark_success:
             cmd.extend(["--mark-success"])
         if pickle_id:
@@ -671,7 +679,7 @@ class TaskInstance(Base, LoggingMixin):
             .filter(
                 TaskInstance.dag_id == self.dag_id,
                 TaskInstance.task_id == self.task_id,
-                TaskInstance.execution_date == self.execution_date,
+                TaskInstance.run_id == self.run_id,
             )
             .all()
         )
@@ -711,11 +719,11 @@ class TaskInstance(Base, LoggingMixin):
         qry = session.query(TaskInstance).filter(
             TaskInstance.dag_id == self.dag_id,
             TaskInstance.task_id == self.task_id,
-            TaskInstance.execution_date == self.execution_date,
+            TaskInstance.run_id == self.run_id,
         )
 
         if lock_for_update:
-            ti = qry.with_for_update().first()
+            ti: Optional[TaskInstance] = qry.with_for_update().first()
         else:
             ti = qry.first()
         if ti:
@@ -788,7 +796,7 @@ class TaskInstance(Base, LoggingMixin):
     @property
     def key(self) -> TaskInstanceKey:
         """Returns a tuple that identifies the task instance uniquely"""
-        return TaskInstanceKey(self.dag_id, self.task_id, self.execution_date, self.try_number)
+        return TaskInstanceKey(self.dag_id, self.task_id, self.run_id, self.try_number)
 
     @provide_session
     def set_state(self, state: str, session=None):
@@ -839,7 +847,7 @@ class TaskInstance(Base, LoggingMixin):
         ti = session.query(func.count(TaskInstance.task_id)).filter(
             TaskInstance.dag_id == self.dag_id,
             TaskInstance.task_id.in_(task.downstream_task_ids),
-            TaskInstance.execution_date == self.execution_date,
+            TaskInstance.run_id == self.run_id,
             TaskInstance.state.in_([State.SKIPPED, State.SUCCESS]),
         )
         count = ti[0][0]
@@ -1043,7 +1051,7 @@ class TaskInstance(Base, LoggingMixin):
                     yield dep_status
 
     def __repr__(self):
-        return f"<TaskInstance: {self.dag_id}.{self.task_id} {self.execution_date} [{self.state}]>"
+        return f"<TaskInstance: {self.dag_id}.{self.task_id} {self.run_id} [{self.state}]>"
 
     def next_retry_datetime(self):
         """
@@ -1093,13 +1101,16 @@ class TaskInstance(Base, LoggingMixin):
         :param session: SQLAlchemy ORM Session
         :return: DagRun
         """
+        info = inspect(self)
+        if info.attrs.dag_run.loaded_value is not NO_VALUE:
+            return self.dag_run
+
         from airflow.models.dagrun import DagRun  # Avoid circular import
 
-        dr = (
-            session.query(DagRun)
-            .filter(DagRun.dag_id == self.dag_id, DagRun.execution_date == self.execution_date)
-            .first()
-        )
+        dr = session.query(DagRun).filter(DagRun.dag_id == self.dag_id, DagRun.run_id == self.run_id).one()
+
+        # Record it in the instance for next time. This means that `self.execution_date` will work correctly
+        set_committed_value(self, 'dag_run', dr)
 
         return dr
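
`get_dagrun` now memoises the loaded run on the instance via `set_committed_value`, so the `execution_date` association proxy (which reads through `self.dag_run`) works without issuing further queries. A small sketch of the access pattern (an existing `ti` and `session` are assumed):

    # First access loads and caches the DagRun; later attribute access is free.
    dag_run = ti.get_dagrun(session=session)
    assert ti.execution_date == dag_run.execution_date  # proxied through ti.dag_run
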
 
@@ -1287,20 +1298,22 @@ class TaskInstance(Base, LoggingMixin):
         self.job_id = job_id
         self.hostname = get_hostname()
         self.pid = os.getpid()
-        session.merge(self)
-        session.commit()
+        if not test_mode:
+            session.merge(self)
+            session.commit()
         actual_start_date = timezone.utcnow()
         Stats.incr(f'ti.start.{task.dag_id}.{task.task_id}')
         try:
             if not mark_success:
                 context = self.get_template_context()
                 self._prepare_and_execute_task_with_callbacks(context, task)
-            self.refresh_from_db(lock_for_update=True)
+            if not test_mode:
+                self.refresh_from_db(lock_for_update=True, session=session)
             self.state = State.SUCCESS
         except TaskDeferred as defer:
             # The task has signalled it wants to defer execution based on
             # a trigger.
-            self._defer_task(defer=defer)
+            self._defer_task(defer=defer, session=session)
             self.log.info(
                 'Pausing task as DEFERRED. dag_id=%s, task_id=%s, execution_date=%s, start_date=%s',
                 self.dag_id,
@@ -1311,7 +1324,7 @@ class TaskInstance(Base, LoggingMixin):
             if not test_mode:
                 session.add(Log(self.state, self))
                 session.merge(self)
-            session.commit()
+                session.commit()
             return
         except AirflowSmartSensorException as e:
             self.log.info(e)
@@ -1321,29 +1334,31 @@ class TaskInstance(Base, LoggingMixin):
             # log only if exception has any arguments to prevent log flooding
             if e.args:
                 self.log.info(e)
-            self.refresh_from_db(lock_for_update=True)
+            if not test_mode:
+                self.refresh_from_db(lock_for_update=True, session=session)
             self.state = State.SKIPPED
         except AirflowRescheduleException as reschedule_exception:
-            self.refresh_from_db()
-            self._handle_reschedule(actual_start_date, reschedule_exception, test_mode)
+            self._handle_reschedule(actual_start_date, reschedule_exception, test_mode, session=session)
+            session.commit()
             return
         except (AirflowFailException, AirflowSensorTimeout) as e:
             # If AirflowFailException is raised, task should not retry.
             # If a sensor in reschedule mode reaches timeout, task should not retry.
-            self.refresh_from_db()
-            self.handle_failure(e, test_mode, force_fail=True, error_file=error_file)
+            self.handle_failure(e, test_mode, force_fail=True, error_file=error_file, session=session)
+            session.commit()
             raise
         except AirflowException as e:
-            self.refresh_from_db()
             # for case when task is marked as success/failed externally
             # current behavior doesn't hit the success callback
             if self.state in {State.SUCCESS, State.FAILED}:
                 return
             else:
-                self.handle_failure(e, test_mode, error_file=error_file)
+                self.handle_failure(e, test_mode, error_file=error_file, session=session)
+                session.commit()
                 raise
         except (Exception, KeyboardInterrupt) as e:
-            self.handle_failure(e, test_mode, error_file=error_file)
+            self.handle_failure(e, test_mode, error_file=error_file, session=session)
+            session.commit()
             raise
         finally:
             Stats.incr(f'ti.finish.{task.dag_id}.{task.task_id}.{self.state}')
@@ -1356,7 +1371,7 @@ class TaskInstance(Base, LoggingMixin):
             session.add(Log(self.state, self))
             session.merge(self)
 
-        session.commit()
+            session.commit()
 
     def _prepare_and_execute_task_with_callbacks(self, context, task):
         """Prepare Task for Execution"""
@@ -1613,6 +1628,7 @@ class TaskInstance(Base, LoggingMixin):
         # Don't record reschedule request in test mode
         if test_mode:
             return
+        self.refresh_from_db(session)
 
         self.end_date = timezone.utcnow()
         self.set_duration()
@@ -1621,7 +1637,7 @@ class TaskInstance(Base, LoggingMixin):
         session.add(
             TaskReschedule(
                 self.task,
-                self.execution_date,
+                self.run_id,
                 self._try_number,
                 actual_start_date,
                 self.end_date,
@@ -1662,6 +1678,8 @@ class TaskInstance(Base, LoggingMixin):
             # can send its runtime errors for access by failure callback
             if error_file:
                 set_error_file(error_file, error)
+        if not test_mode:
+            self.refresh_from_db(session)
 
         task = self.task
         self.end_date = timezone.utcnow()
@@ -1671,8 +1689,8 @@ class TaskInstance(Base, LoggingMixin):
         if not test_mode:
             session.add(Log(State.FAILED, self))
 
-        # Log failure duration
-        session.add(TaskFail(task, self.execution_date, self.start_date, self.end_date))
+            # Log failure duration
+            session.add(TaskFail(task, self.execution_date, self.start_date, self.end_date))
 
         # Set state correctly and figure out how to log it and decide whether
         # to email
@@ -1702,7 +1720,7 @@ class TaskInstance(Base, LoggingMixin):
 
         if not test_mode:
             session.merge(self)
-        session.commit()
+            session.flush()
 
     @provide_session
     def handle_failure_with_callback(
@@ -1724,30 +1742,18 @@ class TaskInstance(Base, LoggingMixin):
 
         return self.task.retries and self.try_number <= self.max_tries
 
-    @provide_session
-    def get_template_context(self, session=None) -> Context:
+    def get_template_context(self, session: Session = None) -> Context:
         """Return TI Context"""
+        # Do not use provide_session here -- it expunges everything on exit!
+        if not session:
+            session = settings.Session()
         task = self.task
         from airflow import macros
 
         integrate_macros_plugins()
 
-        dag_run = self.get_dagrun()
-
-        # FIXME: Many tests don't create a DagRun. We should fix the tests.
-        if dag_run is None:
-            FakeDagRun = namedtuple(
-                "FakeDagRun",
-                # A minimal set of attributes to keep things working.
-                "conf data_interval_start data_interval_end external_trigger run_id",
-            )
-            dag_run = FakeDagRun(
-                conf=None,
-                data_interval_start=None,
-                data_interval_end=None,
-                external_trigger=False,
-                run_id="",
-            )
+        # Ensure that the dag_run is loaded -- otherwise `self.execution_date` may not work
+        dag_run = self.get_dagrun(session)
 
         params = {}  # type: Dict[str, Any]
         with contextlib.suppress(AttributeError):
@@ -1985,7 +1991,7 @@ class TaskInstance(Base, LoggingMixin):
                 replacement='prev_data_interval_start_success',
             ),
             'prev_start_date_success': lazy_object_proxy.Proxy(get_prev_start_date_success),
-            'run_id': dag_run.run_id,
+            'run_id': self.run_id,
             'task': task,
             'task_instance': self,
             'task_instance_key_str': f"{task.dag_id}__{task.task_id}__{ds_nodash}",
@@ -2005,11 +2011,12 @@ class TaskInstance(Base, LoggingMixin):
             'yesterday_ds_nodash': deprecated_proxy(get_yesterday_ds_nodash, key='yesterday_ds_nodash'),
         }
 
-    def get_rendered_template_fields(self):
+    @provide_session
+    def get_rendered_template_fields(self, session=None):
         """Fetch rendered template fields from DB"""
         from airflow.models.renderedtifields import RenderedTaskInstanceFields
 
-        rendered_task_instance_fields = RenderedTaskInstanceFields.get_templated_fields(self)
+        rendered_task_instance_fields = RenderedTaskInstanceFields.get_templated_fields(self, session=session)
         if rendered_task_instance_fields:
             for field_name, rendered_value in rendered_task_instance_fields.items():
                 setattr(self.task, field_name, rendered_value)
@@ -2184,10 +2191,11 @@ class TaskInstance(Base, LoggingMixin):
         :param session: Sqlalchemy ORM Session
         :type session: Session
         """
-        if execution_date and execution_date < self.execution_date:
+        self_execution_date = self.get_dagrun(session).execution_date
+        if execution_date and execution_date < self_execution_date:
             raise ValueError(
                 'execution_date can not be in the past (current '
-                'execution_date is {}; received {})'.format(self.execution_date, execution_date)
+                'execution_date is {}; received {})'.format(self_execution_date, execution_date)
             )
 
         XCom.set(
@@ -2195,7 +2203,7 @@ class TaskInstance(Base, LoggingMixin):
             value=value,
             task_id=self.task_id,
             dag_id=self.dag_id,
-            execution_date=execution_date or self.execution_date,
+            execution_date=execution_date or self_execution_date,
             session=session,
         )
 
@@ -2242,8 +2250,10 @@ class TaskInstance(Base, LoggingMixin):
         if dag_id is None:
             dag_id = self.dag_id
 
+        execution_date = self.get_dagrun(session).execution_date
+
         query = XCom.get_many(
-            execution_date=self.execution_date,
+            execution_date=execution_date,
             key=key,
             dag_ids=dag_id,
             task_ids=task_ids,
@@ -2299,20 +2309,20 @@ class TaskInstance(Base, LoggingMixin):
         first = tis[0]
 
         dag_id = first.dag_id
-        execution_date = first.execution_date
+        run_id = first.run_id
         first_task_id = first.task_id
-        # Common path optimisations: when all TIs are for the same dag_id and execution_date, or same dag_id
+        # Common path optimisations: when all TIs are for the same dag_id and run_id, or same dag_id
         # and task_id -- this can be over 150x for huge numbers of TIs (20k+)
-        if all(t.dag_id == dag_id and t.execution_date == execution_date for t in tis):
+        if all(t.dag_id == dag_id and t.run_id == run_id for t in tis):
             return and_(
                 TaskInstance.dag_id == dag_id,
-                TaskInstance.execution_date == execution_date,
+                TaskInstance.run_id == run_id,
                 TaskInstance.task_id.in_(t.task_id for t in tis),
             )
         if all(t.dag_id == dag_id and t.task_id == first_task_id for t in tis):
             return and_(
                 TaskInstance.dag_id == dag_id,
-                TaskInstance.execution_date.in_(t.execution_date for t in tis),
+                TaskInstance.run_id.in_(t.run_id for t in tis),
                 TaskInstance.task_id == first_task_id,
             )
 
@@ -2321,12 +2331,12 @@ class TaskInstance(Base, LoggingMixin):
                 and_(
                     TaskInstance.dag_id == ti.dag_id,
                     TaskInstance.task_id == ti.task_id,
-                    TaskInstance.execution_date == ti.execution_date,
+                    TaskInstance.run_id == ti.run_id,
                 )
                 for ti in tis
             )
         else:
-            return tuple_(TaskInstance.dag_id, TaskInstance.task_id, TaskInstance.execution_date).in_(
+            return tuple_(TaskInstance.dag_id, TaskInstance.task_id, TaskInstance.run_id).in_(
                 [ti.key.primary for ti in tis]
             )
 
@@ -2346,7 +2356,7 @@ class SimpleTaskInstance:
     def __init__(self, ti: TaskInstance):
         self._dag_id: str = ti.dag_id
         self._task_id: str = ti.task_id
-        self._execution_date: datetime = ti.execution_date
+        self._run_id: str = ti.run_id
         self._start_date: datetime = ti.start_date
         self._end_date: datetime = ti.end_date
         self._try_number: int = ti.try_number
@@ -2371,8 +2381,8 @@ class SimpleTaskInstance:
         return self._task_id
 
     @property
-    def execution_date(self) -> datetime:
-        return self._execution_date
+    def run_id(self) -> str:
+        return self._run_id
 
     @property
     def start_date(self) -> datetime:
@@ -2410,36 +2420,10 @@ class SimpleTaskInstance:
     def executor_config(self):
         return self._executor_config
 
-    @provide_session
-    def construct_task_instance(self, session=None, lock_for_update=False) -> TaskInstance:
-        """
-        Construct a TaskInstance from the database based on the primary key
-
-        :param session: DB session.
-        :param lock_for_update: if True, indicates that the database should
-            lock the TaskInstance (issuing a FOR UPDATE clause) until the
-            session is committed.
-        :return: the task instance constructed
-        """
-        qry = session.query(TaskInstance).filter(
-            TaskInstance.dag_id == self._dag_id,
-            TaskInstance.task_id == self._task_id,
-            TaskInstance.execution_date == self._execution_date,
-        )
-
-        if lock_for_update:
-            ti = qry.with_for_update().first()
-        else:
-            ti = qry.first()
-        return ti
-
 
 STATICA_HACK = True
 globals()['kcah_acitats'[::-1].upper()] = False
 if STATICA_HACK:  # pragma: no cover
+    from airflow.jobs.base_job import BaseJob
 
-    from airflow.jobs.base_job import BaseJob
-    from airflow.models.dagrun import DagRun
-
-    TaskInstance.dag_run = relationship(DagRun)
     TaskInstance.queued_by_job = relationship(BaseJob)
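
Not part of this commit, but as a rough sketch of what the new primary key means in practice: a task instance is now addressed by (dag_id, task_id, run_id), and execution_date is reached through the DagRun relationship. The identifiers below are placeholders.

    from airflow.models import TaskInstance
    from airflow.utils.session import create_session

    # Placeholder ids -- substitute an existing DAG, task and run from your metadata DB.
    with create_session() as session:
        ti = (
            session.query(TaskInstance)
            .filter(
                TaskInstance.dag_id == "example_dag",
                TaskInstance.task_id == "example_task",
                TaskInstance.run_id == "scheduled__2021-09-01T00:00:00+00:00",
            )
            .one_or_none()
        )
        if ti is not None:
            # execution_date is no longer a TaskInstance column; it comes from the DagRun row.
            print(ti.run_id, ti.dag_run.execution_date)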
diff --git a/airflow/models/taskreschedule.py b/airflow/models/taskreschedule.py
index 293021c..55ef754 100644
--- a/airflow/models/taskreschedule.py
+++ b/airflow/models/taskreschedule.py
@@ -17,6 +17,8 @@
 # under the License.
 """TaskReschedule tracks rescheduled task instances."""
 from sqlalchemy import Column, ForeignKeyConstraint, Index, Integer, String, asc, desc
+from sqlalchemy.ext.associationproxy import association_proxy
+from sqlalchemy.orm import relationship
 
 from airflow.models.base import COLLATION_ARGS, ID_LEN, Base
 from airflow.utils.session import provide_session
@@ -31,7 +33,7 @@ class TaskReschedule(Base):
     id = Column(Integer, primary_key=True)
     task_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False)
     dag_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False)
-    execution_date = Column(UtcDateTime, nullable=False)
+    run_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False)
     try_number = Column(Integer, nullable=False)
     start_date = Column(UtcDateTime, nullable=False)
     end_date = Column(UtcDateTime, nullable=False)
@@ -39,19 +41,27 @@ class TaskReschedule(Base):
     reschedule_date = Column(UtcDateTime, nullable=False)
 
     __table_args__ = (
-        Index('idx_task_reschedule_dag_task_date', dag_id, task_id, execution_date, unique=False),
+        Index('idx_task_reschedule_dag_task_run', dag_id, task_id, run_id, unique=False),
         ForeignKeyConstraint(
-            [task_id, dag_id, execution_date],
-            ['task_instance.task_id', 'task_instance.dag_id', 'task_instance.execution_date'],
-            name='task_reschedule_dag_task_date_fkey',
+            [dag_id, task_id, run_id],
+            ['task_instance.dag_id', 'task_instance.task_id', 'task_instance.run_id'],
+            name='task_reschedule_ti_fkey',
+            ondelete='CASCADE',
+        ),
+        ForeignKeyConstraint(
+            [dag_id, run_id],
+            ['dag_run.dag_id', 'dag_run.run_id'],
+            name='task_reschedule_dr_fkey',
             ondelete='CASCADE',
         ),
     )
+    dag_run = relationship("DagRun")
+    execution_date = association_proxy("dag_run", "execution_date")
 
-    def __init__(self, task, execution_date, try_number, start_date, end_date, reschedule_date):
+    def __init__(self, task, run_id, try_number, start_date, end_date, reschedule_date):
         self.dag_id = task.dag_id
         self.task_id = task.task_id
-        self.execution_date = execution_date
+        self.run_id = run_id
         self.try_number = try_number
         self.start_date = start_date
         self.end_date = end_date
@@ -81,7 +91,7 @@ class TaskReschedule(Base):
         qry = session.query(TR).filter(
             TR.dag_id == task_instance.dag_id,
             TR.task_id == task_instance.task_id,
-            TR.execution_date == task_instance.execution_date,
+            TR.run_id == task_instance.run_id,
             TR.try_number == try_number,
         )
         if descending:
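
A hedged sketch of the new TaskReschedule constructor and the execution_date association proxy. It assumes an existing TaskInstance ti (whose task and DagRun are already in the database) and an open ORM session; those names are not defined here.

    import datetime

    from airflow.models.taskreschedule import TaskReschedule
    from airflow.utils import timezone

    now = timezone.utcnow()
    reschedule = TaskReschedule(
        task=ti.task,                 # supplies dag_id and task_id
        run_id=ti.run_id,             # replaces the old execution_date argument
        try_number=ti.try_number,
        start_date=now,
        end_date=now,
        reschedule_date=now + datetime.timedelta(seconds=60),
    )
    session.add(reschedule)
    session.flush()
    # execution_date is still readable, proxied from the related DagRun.
    print(reschedule.execution_date)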
diff --git a/airflow/providers/google/cloud/operators/bigquery.py b/airflow/providers/google/cloud/operators/bigquery.py
index e89d6a8..05c1342 100644
--- a/airflow/providers/google/cloud/operators/bigquery.py
+++ b/airflow/providers/google/cloud/operators/bigquery.py
@@ -34,6 +34,7 @@ from google.cloud.bigquery import TableReference
 from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator, BaseOperatorLink
 from airflow.models.taskinstance import TaskInstance
+from airflow.models.xcom import XCom
 from airflow.operators.sql import SQLCheckOperator, SQLIntervalCheckOperator, SQLValueCheckOperator
 from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook, BigQueryJob
 from airflow.providers.google.cloud.hooks.gcs import GCSHook, _parse_gcs_url
@@ -60,8 +61,12 @@ class BigQueryConsoleLink(BaseOperatorLink):
     name = 'BigQuery Console'
 
     def get_link(self, operator, dttm):
-        ti = TaskInstance(task=operator, execution_date=dttm)
-        job_id = ti.xcom_pull(task_ids=operator.task_id, key='job_id')
+        job_id = XCom.get_one(
+            dag_id=operator.dag.dag_id,
+            task_id=operator.task_id,
+            execution_date=dttm,
+            key='job_id',
+        )
         return BIGQUERY_JOB_DETAILS_LINK_FMT.format(job_id=job_id) if job_id else ''
 
 
diff --git a/airflow/sensors/base.py b/airflow/sensors/base.py
index f0fd0e0..0019c41 100644
--- a/airflow/sensors/base.py
+++ b/airflow/sensors/base.py
@@ -125,7 +125,7 @@ class BaseSensorOperator(BaseOperator, SkipMixin):
                 "The mode must be one of {valid_modes},"
                 "'{d}.{t}'; received '{m}'.".format(
                     valid_modes=self.valid_modes,
-                    d=self.dag.dag_id if self.dag else "",
+                    d=self.dag.dag_id if self.has_dag() else "",
                     t=self.task_id,
                     m=self.mode,
                 )
diff --git a/airflow/sensors/smart_sensor.py b/airflow/sensors/smart_sensor.py
index ec6acef..1e8c827 100644
--- a/airflow/sensors/smart_sensor.py
+++ b/airflow/sensors/smart_sensor.py
@@ -15,8 +15,6 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
-
 import datetime
 import json
 import logging
@@ -27,7 +25,7 @@ from time import sleep
 from sqlalchemy import and_, or_, tuple_
 
 from airflow.exceptions import AirflowException, AirflowTaskTimeout
-from airflow.models import BaseOperator, SensorInstance, SkipMixin, TaskInstance
+from airflow.models import BaseOperator, DagRun, SensorInstance, SkipMixin, TaskInstance
 from airflow.settings import LOGGING_CLASS_PATH
 from airflow.stats import Stats
 from airflow.utils import helpers, timezone
@@ -390,6 +388,7 @@ class SmartSensorOperator(BaseOperator, SkipMixin):
         :param sensor_works: Smart sensor internal object for a sensor task.
         :param session: The sqlalchemy session.
         """
+        DR = DagRun
         TI = TaskInstance
 
         def update_ti_hostname_with_count(count, sensor_works):
@@ -399,18 +398,17 @@ class SmartSensorOperator(BaseOperator, SkipMixin):
                     and_(
                         TI.dag_id == ti_key.dag_id,
                         TI.task_id == ti_key.task_id,
-                        TI.execution_date == ti_key.execution_date,
+                        DR.execution_date == ti_key.execution_date,
                     )
                     for ti_key in sensor_works
                 )
             else:
                 ti_keys = [(x.dag_id, x.task_id, x.execution_date) for x in sensor_works]
                 ti_filter = or_(
-                    tuple_(TI.dag_id, TI.task_id, TI.execution_date) == ti_key for ti_key in ti_keys
+                    tuple_(TI.dag_id, TI.task_id, DR.execution_date) == ti_key for ti_key in ti_keys
                 )
-            tis = session.query(TI).filter(ti_filter).all()
 
-            for ti in tis:
+            for ti in session.query(TI).join(TI.dag_run).filter(ti_filter):
                 ti.hostname = self.hostname
             session.commit()
 
diff --git a/airflow/sentry.py b/airflow/sentry.py
index 51fe26f..340b660 100644
--- a/airflow/sentry.py
+++ b/airflow/sentry.py
@@ -130,13 +130,9 @@ if conf.getboolean("sentry", 'sentry_on', fallback=False):
             """Function to add breadcrumbs inside of a task_instance."""
             if session is None:
                 return
-            execution_date = task_instance.execution_date
-            task = task_instance.task
-            dag = task.dag
-            task_instances = dag.get_task_instances(
+            dr = task_instance.get_dagrun(session)
+            task_instances = dr.get_task_instances(
                 state={State.SUCCESS, State.FAILED},
-                end_date=execution_date,
-                start_date=execution_date,
                 session=session,
             )
 
diff --git a/airflow/ti_deps/dep_context.py b/airflow/ti_deps/dep_context.py
index b8b1f3a..6c747c7 100644
--- a/airflow/ti_deps/dep_context.py
+++ b/airflow/ti_deps/dep_context.py
@@ -16,11 +16,16 @@
 # specific language governing permissions and limitations
 # under the License.
 
-import pendulum
+from typing import TYPE_CHECKING, List
+
 from sqlalchemy.orm.session import Session
 
 from airflow.utils.state import State
 
+if TYPE_CHECKING:
+    from airflow.models.dagrun import DagRun
+    from airflow.models.taskinstance import TaskInstance
+
 
 class DepContext:
     """
@@ -85,23 +90,16 @@ class DepContext:
         self.ignore_ti_state = ignore_ti_state
         self.finished_tasks = finished_tasks
 
-    def ensure_finished_tasks(self, dag, execution_date: pendulum.DateTime, session: Session):
+    def ensure_finished_tasks(self, dag_run: "DagRun", session: Session) -> "List[TaskInstance]":
         """
         This method makes sure finished_tasks is populated if it's currently None.
         This is for the strange feature of running tasks without dag_run.
 
-        :param dag: The DAG for which to find finished tasks
-        :type dag: airflow.models.DAG
-        :param execution_date: The execution_date to look for
-        :param session: Database session to use
+        :param dag_run: The DagRun for which to find finished tasks
+        :type dag_run: airflow.models.DagRun
         :return: A list of all the finished tasks of this DAG and execution_date
         :rtype: list[airflow.models.TaskInstance]
         """
         if self.finished_tasks is None:
-            self.finished_tasks = dag.get_task_instances(
-                start_date=execution_date,
-                end_date=execution_date,
-                state=State.finished,
-                session=session,
-            )
+            self.finished_tasks = dag_run.get_task_instances(state=State.finished, session=session)
         return self.finished_tasks
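
To show the narrowed signature in use (an illustrative sketch only, assuming an existing DagRun dag_run and an open session), the finished task instances now come straight from the run rather than from a DAG plus execution_date lookup:

    from airflow.ti_deps.dep_context import DepContext

    dep_context = DepContext()
    # dag_run and session are assumed to exist; DepContext caches the result after the first call.
    finished_tasks = dep_context.ensure_finished_tasks(dag_run, session)
    finished_task_ids = {t.task_id for t in finished_tasks}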
diff --git a/airflow/ti_deps/deps/dagrun_exists_dep.py b/airflow/ti_deps/deps/dagrun_exists_dep.py
index 6e00b83..0aa21e8 100644
--- a/airflow/ti_deps/deps/dagrun_exists_dep.py
+++ b/airflow/ti_deps/deps/dagrun_exists_dep.py
@@ -29,27 +29,9 @@ class DagrunRunningDep(BaseTIDep):
 
     @provide_session
     def _get_dep_statuses(self, ti, session, dep_context):
-        dag = ti.task.dag
-        dagrun = ti.get_dagrun(session)
-        if not dagrun:
-            # The import is needed here to avoid a circular dependency
-            from airflow.models.dagrun import DagRun
-
-            running_dagruns = DagRun.find(
-                dag_id=dag.dag_id, state=State.RUNNING, external_trigger=False, session=session
+        dr = ti.get_dagrun(session)
+        if dr.state != State.RUNNING:
+            yield self._failing_status(
+                reason="Task instance's dagrun was not in the 'running' state but in "
+                "the state '{}'.".format(dr.state)
             )
-
-            if len(running_dagruns) >= dag.max_active_runs:
-                reason = (
-                    "The maximum number of active dag runs ({}) for this task "
-                    "instance's DAG '{}' has been reached.".format(dag.max_active_runs, ti.dag_id)
-                )
-            else:
-                reason = "Unknown reason"
-            yield self._failing_status(reason=f"Task instance's dagrun did not exist: {reason}.")
-        else:
-            if dagrun.state != State.RUNNING:
-                yield self._failing_status(
-                    reason="Task instance's dagrun was not in the 'running' state but in "
-                    "the state '{}'.".format(dagrun.state)
-                )
diff --git a/airflow/ti_deps/deps/dagrun_id_dep.py b/airflow/ti_deps/deps/dagrun_id_dep.py
index 186ab7c..a609514 100644
--- a/airflow/ti_deps/deps/dagrun_id_dep.py
+++ b/airflow/ti_deps/deps/dagrun_id_dep.py
@@ -32,7 +32,7 @@ class DagrunIdDep(BaseTIDep):
     @provide_session
     def _get_dep_statuses(self, ti, session, dep_context=None):
         """
-        Determines if the DagRun ID is valid for scheduling from scheduler.
+        Determines if the DagRun is valid for scheduling from the scheduler.
 
         :param ti: the task instance to get the dependency status for
         :type ti: airflow.models.TaskInstance
@@ -44,12 +44,7 @@ class DagrunIdDep(BaseTIDep):
         """
         dagrun = ti.get_dagrun(session)
 
-        if not dagrun or not dagrun.run_id or dagrun.run_type != DagRunType.BACKFILL_JOB:
-            yield self._passing_status(
-                reason=f"Task's DagRun doesn't exist or run_id is either NULL "
-                f"or run_type is not {DagRunType.BACKFILL_JOB}"
-            )
-        else:
+        if dagrun.run_type == DagRunType.BACKFILL_JOB:
             yield self._failing_status(
-                reason=f"Task's DagRun run_id is not NULL " f"and run type is {DagRunType.BACKFILL_JOB}"
+                reason=f"Task's DagRun run_type is {dagrun.run_type} and cannot be run by the scheduler"
             )
diff --git a/airflow/ti_deps/deps/not_previously_skipped_dep.py b/airflow/ti_deps/deps/not_previously_skipped_dep.py
index 3d1bde9..e9df0ed 100644
--- a/airflow/ti_deps/deps/not_previously_skipped_dep.py
+++ b/airflow/ti_deps/deps/not_previously_skipped_dep.py
@@ -39,7 +39,7 @@ class NotPreviouslySkippedDep(BaseTIDep):
 
         upstream = ti.task.get_direct_relatives(upstream=True)
 
-        finished_tasks = dep_context.ensure_finished_tasks(ti.task.dag, ti.execution_date, session)
+        finished_tasks = dep_context.ensure_finished_tasks(ti.get_dagrun(session), session)
 
         finished_task_ids = {t.task_id for t in finished_tasks}
 
diff --git a/airflow/ti_deps/deps/runnable_exec_date_dep.py b/airflow/ti_deps/deps/runnable_exec_date_dep.py
index 3986ef1..0607c11 100644
--- a/airflow/ti_deps/deps/runnable_exec_date_dep.py
+++ b/airflow/ti_deps/deps/runnable_exec_date_dep.py
@@ -33,20 +33,21 @@ class RunnableExecDateDep(BaseTIDep):
 
         # don't consider runs that are executed in the future unless
         # specified by config and schedule_interval is None
-        if ti.execution_date > cur_date and not ti.task.dag.allow_future_exec_dates:
+        logical_date = ti.get_dagrun(session).execution_date
+        if logical_date > cur_date and not ti.task.dag.allow_future_exec_dates:
             yield self._failing_status(
                 reason="Execution date {} is in the future (the current "
-                "date is {}).".format(ti.execution_date.isoformat(), cur_date.isoformat())
+                "date is {}).".format(logical_date.isoformat(), cur_date.isoformat())
             )
 
-        if ti.task.end_date and ti.execution_date > ti.task.end_date:
+        if ti.task.end_date and logical_date > ti.task.end_date:
             yield self._failing_status(
                 reason="The execution date is {} but this is after the task's end date "
-                "{}.".format(ti.execution_date.isoformat(), ti.task.end_date.isoformat())
+                "{}.".format(logical_date.isoformat(), ti.task.end_date.isoformat())
             )
 
-        if ti.task.dag and ti.task.dag.end_date and ti.execution_date > ti.task.dag.end_date:
+        if ti.task.dag and ti.task.dag.end_date and logical_date > ti.task.dag.end_date:
             yield self._failing_status(
                 reason="The execution date is {} but this is after the task's DAG's "
-                "end date {}.".format(ti.execution_date.isoformat(), ti.task.dag.end_date.isoformat())
+                "end date {}.".format(logical_date.isoformat(), ti.task.dag.end_date.isoformat())
             )
diff --git a/airflow/ti_deps/deps/trigger_rule_dep.py b/airflow/ti_deps/deps/trigger_rule_dep.py
index f04cf31..5d72410 100644
--- a/airflow/ti_deps/deps/trigger_rule_dep.py
+++ b/airflow/ti_deps/deps/trigger_rule_dep.py
@@ -66,7 +66,7 @@ class TriggerRuleDep(BaseTIDep):
             return
         # see if the task name is in the task upstream for our task
         successes, skipped, failed, upstream_failed, done = self._get_states_count_upstream_ti(
-            ti=ti, finished_tasks=dep_context.ensure_finished_tasks(ti.task.dag, ti.execution_date, session)
+            ti=ti, finished_tasks=dep_context.ensure_finished_tasks(ti.get_dagrun(session), session)
         )
 
         yield from self._evaluate_trigger_rule(
diff --git a/airflow/utils/callback_requests.py b/airflow/utils/callback_requests.py
index 89ffe52..8ed587b 100644
--- a/airflow/utils/callback_requests.py
+++ b/airflow/utils/callback_requests.py
@@ -15,7 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 
-from datetime import datetime
 from typing import Optional
 
 from airflow.models.taskinstance import SimpleTaskInstance
@@ -71,7 +70,7 @@ class DagCallbackRequest(CallbackRequest):
 
     :param full_filepath: File Path to use to run the callback
     :param dag_id: DAG ID
-    :param execution_date: Execution Date for the DagRun
+    :param run_id: Run ID for the DagRun
     :param is_failure_callback: Flag to determine whether it is a Failure Callback or Success Callback
     :param msg: Additional Message that can be used for logging
     """
@@ -80,13 +79,13 @@ class DagCallbackRequest(CallbackRequest):
         self,
         full_filepath: str,
         dag_id: str,
-        execution_date: datetime,
+        run_id: str,
         is_failure_callback: Optional[bool] = True,
         msg: Optional[str] = None,
     ):
         super().__init__(full_filepath=full_filepath, msg=msg)
         self.dag_id = dag_id
-        self.execution_date = execution_date
+        self.run_id = run_id
         self.is_failure_callback = is_failure_callback
 
 
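For illustration, a sketch of building the callback request with the new field; every value below is a placeholder rather than anything taken from this change:

    from airflow.utils.callback_requests import DagCallbackRequest

    request = DagCallbackRequest(
        full_filepath="/files/dags/example_dag.py",          # placeholder path
        dag_id="example_dag",
        run_id="scheduled__2021-09-01T00:00:00+00:00",       # replaces execution_date
        is_failure_callback=True,
        msg="task failure",
    )
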
diff --git a/airflow/www/auth.py b/airflow/www/auth.py
index 0b1d9ed..cd43928 100644
--- a/airflow/www/auth.py
+++ b/airflow/www/auth.py
@@ -32,6 +32,8 @@ def has_access(permissions: Optional[Sequence[Tuple[str, str]]] = None) -> Calla
     def requires_access_decorator(func: T):
         @wraps(func)
         def decorated(*args, **kwargs):
+            __tracebackhide__ = True  # Hide from pytest traceback.
+
             appbuilder = current_app.appbuilder
             if not g.user.is_anonymous and not appbuilder.sm.current_user_has_permissions():
                 return (
diff --git a/airflow/www/decorators.py b/airflow/www/decorators.py
index 8500f0d..f6f2ed0 100644
--- a/airflow/www/decorators.py
+++ b/airflow/www/decorators.py
@@ -39,6 +39,7 @@ def action_logging(f: T) -> T:
 
     @functools.wraps(f)
     def wrapper(*args, **kwargs):
+        __tracebackhide__ = True  # Hide from pytest traceback.
 
         with create_session() as session:
             if g.user.is_anonymous:
diff --git a/airflow/www/utils.py b/airflow/www/utils.py
index db86783..cc17295 100644
--- a/airflow/www/utils.py
+++ b/airflow/www/utils.py
@@ -18,6 +18,7 @@
 import json
 import textwrap
 import time
+from typing import Any
 from urllib.parse import urlencode
 
 import markdown
@@ -29,6 +30,7 @@ from flask_appbuilder.models.sqla import filters as fab_sqlafilters
 from flask_appbuilder.models.sqla.interface import SQLAInterface
 from pygments import highlight, lexers
 from pygments.formatters import HtmlFormatter
+from sqlalchemy.ext.associationproxy import AssociationProxy
 
 from airflow.models import errors
 from airflow.utils import timezone
@@ -225,8 +227,8 @@ def task_instance_link(attr):
     """Generates a URL to the Graph view for a TaskInstance."""
     dag_id = attr.get('dag_id')
     task_id = attr.get('task_id')
-    execution_date = attr.get('execution_date')
-    url = url_for('Airflow.task', dag_id=dag_id, task_id=task_id, execution_date=execution_date.isoformat())
+    execution_date = attr.get('dag_run.execution_date') or attr.get('execution_date') or timezone.utcnow()
+    url = url_for('Airflow.task', dag_id=dag_id, task_id=task_id, execution_date=execution_date.isoformat())
     url_root = url_for(
         'Airflow.graph', dag_id=dag_id, root=task_id, execution_date=execution_date.isoformat()
     )
@@ -311,7 +313,7 @@ def dag_run_link(attr):
     """Generates a URL to the Graph view for a DagRun."""
     dag_id = attr.get('dag_id')
     run_id = attr.get('run_id')
-    execution_date = attr.get('execution_date')
+    execution_date = attr.get('dag_run.execution_date') or attr.get('execution_date')
     url = url_for('Airflow.graph', dag_id=dag_id, run_id=run_id, execution_date=execution_date)
     return Markup('<a href="{url}">{run_id}</a>').format(url=url, run_id=run_id)
 
@@ -456,6 +458,13 @@ class CustomSQLAInterface(SQLAInterface):
                 self.list_columns = {k.lstrip('_'): v for k, v in self.list_columns.items()}
 
         clean_column_names()
+        # Support for AssociationProxy in search and list columns
+        for desc in self.obj.__mapper__.all_orm_descriptors:
+            if not isinstance(desc, AssociationProxy):
+                continue
+            proxy_instance = getattr(self.obj, desc.value_attr)
+            self.list_columns[desc.value_attr] = proxy_instance.remote_attr.prop.columns[0]
+            self.list_properties[desc.value_attr] = proxy_instance.remote_attr.prop
 
     def is_utcdatetime(self, col_name):
         """Check if the datetime is a UTC one."""
@@ -483,6 +492,12 @@ class CustomSQLAInterface(SQLAInterface):
             )
         return False
 
+    def get_col_default(self, col_name: str) -> Any:
+        if col_name not in self.list_columns:
+            # Handle AssociationProxy etc, or anything that isn't a "real" column
+            return None
+        return super().get_col_default(col_name)
+
     filter_converter_class = AirflowFilterConverter
 
 
diff --git a/airflow/www/views.py b/airflow/www/views.py
index 4790c30..7b4d42a 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -18,7 +18,6 @@
 #
 import collections
 import copy
-import itertools
 import json
 import logging
 import math
@@ -807,7 +806,7 @@ class Airflow(AirflowBaseView):
             filter_dag_ids = allowed_dag_ids
 
         running_dag_run_query_result = (
-            session.query(DagRun.dag_id, DagRun.execution_date)
+            session.query(DagRun.dag_id, DagRun.run_id)
             .join(DagModel, DagModel.dag_id == DagRun.dag_id)
             .filter(DagRun.state == State.RUNNING, DagModel.is_active)
         )
@@ -823,7 +822,7 @@ class Airflow(AirflowBaseView):
             running_dag_run_query_result,
             and_(
                 running_dag_run_query_result.c.dag_id == TaskInstance.dag_id,
-                running_dag_run_query_result.c.execution_date == TaskInstance.execution_date,
+                running_dag_run_query_result.c.run_id == TaskInstance.run_id,
             ),
         )
 
@@ -841,14 +840,16 @@ class Airflow(AirflowBaseView):
 
             # Select all task_instances from active dag_runs.
             # If no dag_run is active, return task instances from most recent dag_run.
-            last_task_instance_query_result = session.query(
-                TaskInstance.dag_id.label('dag_id'), TaskInstance.state.label('state')
-            ).join(
-                last_dag_run,
-                and_(
-                    last_dag_run.c.dag_id == TaskInstance.dag_id,
-                    last_dag_run.c.execution_date == TaskInstance.execution_date,
-                ),
+            last_task_instance_query_result = (
+                session.query(TaskInstance.dag_id.label('dag_id'), TaskInstance.state.label('state'))
+                .join(TaskInstance.dag_run)
+                .join(
+                    last_dag_run,
+                    and_(
+                        last_dag_run.c.dag_id == TaskInstance.dag_id,
+                        last_dag_run.c.execution_date == DagRun.execution_date,
+                    ),
+                )
             )
 
             final_task_instance_query_result = union_all(
@@ -1036,7 +1037,8 @@ class Airflow(AirflowBaseView):
         ]
     )
     @action_logging
-    def rendered_templates(self):
+    @provide_session
+    def rendered_templates(self, session):
         """Get rendered Dag."""
         dag_id = request.args.get('dag_id')
         task_id = request.args.get('task_id')
@@ -1047,11 +1049,13 @@ class Airflow(AirflowBaseView):
 
         logging.info("Retrieving rendered templates.")
         dag = current_app.dag_bag.get_dag(dag_id)
+        dag_run = dag.get_dagrun(execution_date=dttm, session=session)
 
         task = copy.copy(dag.get_task(task_id))
-        ti = models.TaskInstance(task=task, execution_date=dttm)
+        ti = dag_run.get_task_instance(task_id=task.task_id, session=session)
+        ti.refresh_from_task(task)
         try:
-            ti.get_rendered_template_fields()
+            ti.get_rendered_template_fields(session=session)
         except AirflowException as e:
             msg = "Error rendering template: " + escape(e)
             if e.__cause__:
@@ -1125,7 +1129,8 @@ class Airflow(AirflowBaseView):
         logging.info("Retrieving rendered templates.")
         dag = current_app.dag_bag.get_dag(dag_id)
         task = dag.get_task(task_id)
-        ti = models.TaskInstance(task=task, execution_date=dttm)
+        dag_run = dag.get_dagrun(execution_date=dttm)
+        ti = dag_run.get_task_instance(task_id=task.task_id)
 
         pod_spec = None
         try:
@@ -1342,7 +1347,8 @@ class Airflow(AirflowBaseView):
         ]
     )
     @action_logging
-    def task(self):
+    @provide_session
+    def task(self, session):
         """Retrieve task."""
         dag_id = request.args.get('dag_id')
         task_id = request.args.get('task_id')
@@ -1357,30 +1363,52 @@ class Airflow(AirflowBaseView):
             return redirect(url_for('Airflow.index'))
         task = copy.copy(dag.get_task(task_id))
         task.resolve_template_files()
-        ti = TaskInstance(task=task, execution_date=dttm)
-        ti.refresh_from_db()
-
-        ti_attrs = []
-        for attr_name in dir(ti):
-            if not attr_name.startswith('_'):
-                attr = getattr(ti, attr_name)
-                if type(attr) != type(self.task):  # noqa
-                    ti_attrs.append((attr_name, str(attr)))
 
-        task_attrs = []
-        for attr_name in dir(task):
-            if not attr_name.startswith('_'):
-                attr = getattr(task, attr_name)
+        ti = (
+            session.query(TaskInstance)
+            .options(
+                # HACK: Eager-load relationships. This is needed because
+                # multiple properties mis-use provide_session() that destroys
+                # the session object ti is bounded to.
+                joinedload(TaskInstance.queued_by_job, innerjoin=False),
+                joinedload(TaskInstance.trigger, innerjoin=False),
+            )
+            .join(TaskInstance.dag_run)
+            .filter(
+                DagRun.execution_date == dttm,
+                TaskInstance.dag_id == dag_id,
+                TaskInstance.task_id == task_id,
+            )
+            .one()
+        )
+        ti.refresh_from_task(task)
 
-                if type(attr) != type(self.task) and attr_name not in wwwutils.get_attr_renderer():  # noqa
-                    task_attrs.append((attr_name, str(attr)))
+        ti_attrs = [
+            (attr_name, attr)
+            for attr_name, attr in (
+                (attr_name, getattr(ti, attr_name)) for attr_name in dir(ti) if not attr_name.startswith("_")
+            )
+            if not callable(attr)
+        ]
+        ti_attrs = sorted(ti_attrs)
+
+        attr_renderers = wwwutils.get_attr_renderer()
+        task_attrs = [
+            (attr_name, attr)
+            for attr_name, attr in (
+                (attr_name, getattr(task, attr_name))
+                for attr_name in dir(task)
+                if not attr_name.startswith("_") and attr_name not in attr_renderers
+            )
+            if not callable(attr)
+        ]
 
         # Color coding the special attributes that are code
-        special_attrs_rendered = {}
-        for attr_name in wwwutils.get_attr_renderer():
-            if getattr(task, attr_name, None) is not None:
-                source = getattr(task, attr_name)
-                special_attrs_rendered[attr_name] = wwwutils.get_attr_renderer()[attr_name](source)
+        special_attrs_rendered = {
+            attr_name: renderer(getattr(task, attr_name))
+            for attr_name, renderer in attr_renderers.items()
+            if hasattr(task, attr_name)
+        }
 
         no_failed_deps_result = [
             (
@@ -1514,8 +1542,9 @@ class Airflow(AirflowBaseView):
             flash("Only works with the Celery or Kubernetes executors, sorry", "error")
             return redirect(origin)
 
-        ti = models.TaskInstance(task=task, execution_date=execution_date)
-        ti.refresh_from_db()
+        dag_run = dag.get_dagrun(execution_date=execution_date)
+        ti = dag_run.get_task_instance(task_id=task.task_id)
+        ti.refresh_from_task(task)
 
         # Make sure the task instance can be run
         dep_context = DepContext(
@@ -2089,14 +2118,20 @@ class Airflow(AirflowBaseView):
             State.SUCCESS,
         )
 
-    def _get_tree_data(self, dag_runs: Iterable[DagRun], dag: DAG, base_date: DateTime):
+    def _get_tree_data(
+        self,
+        dag_runs: Iterable[DagRun],
+        dag: DAG,
+        base_date: DateTime,
+        session: settings.Session,
+    ):
         """Returns formatted dag_runs for Tree view"""
         dates = sorted(dag_runs.keys())
         min_date = min(dag_runs, default=None)
 
         task_instances = {
             (ti.task_id, ti.execution_date): ti
-            for ti in dag.get_task_instances(start_date=min_date, end_date=base_date)
+            for ti in dag.get_task_instances(start_date=min_date, end_date=base_date, session=session)
         }
 
         expanded = set()
@@ -2239,7 +2274,7 @@ class Airflow(AirflowBaseView):
         else:
             external_log_name = None
 
-        data = self._get_tree_data(dag_runs, dag, base_date)
+        data = self._get_tree_data(dag_runs, dag, base_date, session=session)
 
         # avoid spaces to reduce payload size
         data = htmlsafe_json_dumps(data, separators=(',', ':'))
@@ -2766,23 +2801,22 @@ class Airflow(AirflowBaseView):
         form = DateTimeWithNumRunsWithDagRunsForm(data=dt_nr_dr_data)
         form.execution_date.choices = dt_nr_dr_data['dr_choices']
 
-        tis = [ti for ti in dag.get_task_instances(dttm, dttm) if ti.start_date and ti.state]
-        tis = sorted(tis, key=lambda ti: ti.start_date)
-        ti_fails = list(
-            itertools.chain(
-                *(
-                    (
-                        session.query(TaskFail)
-                        .filter(
-                            TaskFail.dag_id == ti.dag_id,
-                            TaskFail.task_id == ti.task_id,
-                            TaskFail.execution_date == ti.execution_date,
-                        )
-                        .all()
-                    )
-                    for ti in tis
-                )
+        tis = (
+            session.query(TaskInstance)
+            .join(TaskInstance.dag_run)
+            .filter(
+                DagRun.execution_date == dttm,
+                TaskInstance.dag_id == dag_id,
+                TaskInstance.start_date.isnot(None),
+                TaskInstance.state.isnot(None),
             )
+            .order_by(TaskInstance.start_date)
+        )
+
+        ti_fails = (
+            session.query(TaskFail)
+            .join(DagRun, DagRun.execution_date == TaskFail.execution_date)
+            .filter(DagRun.execution_date == dttm, TaskFail.dag_id == dag_id)
         )
 
         tasks = []
@@ -2818,10 +2852,11 @@ class Airflow(AirflowBaseView):
             task_dict['extraLinks'] = task.extra_links
             tasks.append(task_dict)
 
+        task_names = [ti.task_id for ti in tis]
         data = {
-            'taskNames': [ti.task_id for ti in tis],
+            'taskNames': task_names,
             'tasks': tasks,
-            'height': len(tis) * 25 + 25,
+            'height': len(task_names) * 25 + 25,
         }
 
         session.commit()
@@ -2959,9 +2994,8 @@ class Airflow(AirflowBaseView):
                 .limit(num_runs)
                 .all()
             )
-        dag_runs = {dr.execution_date: alchemy_to_dict(dr) for dr in dag_runs}
-
-        tree_data = self._get_tree_data(dag_runs, dag, base_date)
+            dag_runs = {dr.execution_date: alchemy_to_dict(dr) for dr in dag_runs}
+            tree_data = self._get_tree_data(dag_runs, dag, base_date, session=session)
 
         # avoid spaces to reduce payload size
         return htmlsafe_json_dumps(tree_data, separators=(',', ':'))
@@ -3949,8 +3983,9 @@ class TaskRescheduleModelView(AirflowModelView):
     list_columns = [
         'id',
         'dag_id',
+        'run_id',
+        'dag_run.execution_date',
         'task_id',
-        'execution_date',
         'try_number',
         'start_date',
         'end_date',
@@ -3958,7 +3993,19 @@ class TaskRescheduleModelView(AirflowModelView):
         'reschedule_date',
     ]
 
-    search_columns = ['dag_id', 'task_id', 'execution_date', 'start_date', 'end_date', 'reschedule_date']
+    label_columns = {
+        'dag_run.execution_date': 'Execution Date',
+    }
+
+    search_columns = [
+        'dag_id',
+        'task_id',
+        'run_id',
+        'execution_date',
+        'start_date',
+        'end_date',
+        'reschedule_date',
+    ]
 
     base_order = ('id', 'desc')
 
@@ -3977,7 +4024,7 @@ class TaskRescheduleModelView(AirflowModelView):
         'task_id': wwwutils.task_instance_link,
         'start_date': wwwutils.datetime_f('start_date'),
         'end_date': wwwutils.datetime_f('end_date'),
-        'execution_date': wwwutils.datetime_f('execution_date'),
+        'dag_run.execution_date': wwwutils.datetime_f('dag_run.execution_date'),
         'reschedule_date': wwwutils.datetime_f('reschedule_date'),
         'duration': duration_f,
     }
@@ -4013,7 +4060,8 @@ class TaskInstanceModelView(AirflowModelView):
         'state',
         'dag_id',
         'task_id',
-        'execution_date',
+        'run_id',
+        'dag_run.execution_date',
         'operator',
         'start_date',
         'end_date',
@@ -4035,10 +4083,15 @@ class TaskInstanceModelView(AirflowModelView):
         item for item in list_columns if item not in ['try_number', 'log_url', 'external_executor_id']
     ]
 
+    label_columns = {
+        'dag_run.execution_date': 'Execution Date',
+    }
+
     search_columns = [
         'state',
         'dag_id',
         'task_id',
+        'run_id',
         'execution_date',
         'hostname',
         'queue',
@@ -4051,9 +4104,6 @@ class TaskInstanceModelView(AirflowModelView):
 
     edit_columns = [
         'state',
-        'dag_id',
-        'task_id',
-        'execution_date',
         'start_date',
         'end_date',
     ]
@@ -4084,9 +4134,10 @@ class TaskInstanceModelView(AirflowModelView):
     formatters_columns = {
         'log_url': log_url_formatter,
         'task_id': wwwutils.task_instance_link,
+        'run_id': wwwutils.dag_run_link,
         'hostname': wwwutils.nobr_f('hostname'),
         'state': wwwutils.state_f,
-        'execution_date': wwwutils.datetime_f('execution_date'),
+        'dag_run.execution_date': wwwutils.datetime_f('dag_run.execution_date'),
         'start_date': wwwutils.datetime_f('start_date'),
         'end_date': wwwutils.datetime_f('end_date'),
         'queued_dttm': wwwutils.datetime_f('queued_dttm'),
diff --git a/docs/apache-airflow/concepts/scheduler.rst b/docs/apache-airflow/concepts/scheduler.rst
index 0a1079e..3a54b02 100644
--- a/docs/apache-airflow/concepts/scheduler.rst
+++ b/docs/apache-airflow/concepts/scheduler.rst
@@ -174,17 +174,6 @@ The following config settings can be used to control aspects of the Scheduler HA
   this, so this should be set to match the same period as your statsd roll-up
   period.
 
-- :ref:`config:scheduler__clean_tis_without_dagrun_interval`
-
-  How often should each scheduler run a check to "clean up" TaskInstance rows
-  that are found to no longer have a matching DagRun row.
-
-  In normal operation the scheduler won't do this, it is only possible to do
-  this by deleting rows via the UI, or directly in the DB. You can set this
-  lower if this check is not important to you -- tasks will be left in what
-  ever state they are until the cleanup happens, at which point they will be
-  set to failed.
-
 - :ref:`config:scheduler__orphaned_tasks_check_interval`
 
   How often (in seconds) should the scheduler check for orphaned tasks or dead
diff --git a/docs/apache-airflow/logging-monitoring/metrics.rst b/docs/apache-airflow/logging-monitoring/metrics.rst
index 8e5d46b..c1bc249 100644
--- a/docs/apache-airflow/logging-monitoring/metrics.rst
+++ b/docs/apache-airflow/logging-monitoring/metrics.rst
@@ -123,7 +123,6 @@ Name                                                Description
 ``dag_processing.total_parse_time``                 Seconds taken to scan and import all DAG files once
 ``dag_processing.last_run.seconds_ago.<dag_file>``  Seconds since ``<dag_file>`` was last processed
 ``dag_processing.processor_timeouts``               Number of file processors that have been killed due to taking too long
-``scheduler.tasks.without_dagrun``                  Number of tasks without DagRuns or with DagRuns not in Running state
 ``scheduler.tasks.running``                         Number of tasks running in executor
 ``scheduler.tasks.starving``                        Number of tasks that cannot be scheduled because of no open slot in pool
 ``scheduler.tasks.executable``                      Number of tasks that are ready for execution (set to queued)
diff --git a/docs/apache-airflow/migrations-ref.rst b/docs/apache-airflow/migrations-ref.rst
index 1f387b6..052d61a 100644
--- a/docs/apache-airflow/migrations-ref.rst
+++ b/docs/apache-airflow/migrations-ref.rst
@@ -23,7 +23,9 @@ Here's the list of all the Database Migrations that are executed via when you ru
 +--------------------------------+------------------+-----------------+---------------------------------------------------------------------------------------+
 | Revision ID                    | Revises ID       | Airflow Version | Description                                                                           |
 +--------------------------------+------------------+-----------------+---------------------------------------------------------------------------------------+
-| ``142555e44c17`` (head)        | ``54bebd308c5f`` |                 | Add ``data_interval_start`` and ``data_interval_end`` to ``DagRun``                   |
+| ``7b2661a43ba3`` (head)        | ``142555e44c17`` |                 | Change TaskInstance and TaskReschedule tables from execution_date to run_id.          |
++--------------------------------+------------------+-----------------+---------------------------------------------------------------------------------------+
+| ``142555e44c17``               | ``54bebd308c5f`` |                 | Add ``data_interval_start`` and ``data_interval_end`` to ``DagRun``                   |
 +--------------------------------+------------------+-----------------+---------------------------------------------------------------------------------------+
 | ``54bebd308c5f``               | ``30867afad44a`` |                 | Adds ``trigger`` table and deferrable operator columns to task instance               |
 +--------------------------------+------------------+-----------------+---------------------------------------------------------------------------------------+
diff --git a/kubernetes_tests/test_kubernetes_pod_operator.py b/kubernetes_tests/test_kubernetes_pod_operator.py
index 604e5f2..52d7f3a 100644
--- a/kubernetes_tests/test_kubernetes_pod_operator.py
+++ b/kubernetes_tests/test_kubernetes_pod_operator.py
@@ -34,7 +34,7 @@ from kubernetes.client.rest import ApiException
 from airflow.exceptions import AirflowException
 from airflow.kubernetes import kube_client
 from airflow.kubernetes.secret import Secret
-from airflow.models import DAG, TaskInstance
+from airflow.models import DAG, DagRun, TaskInstance
 from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import KubernetesPodOperator
 from airflow.providers.cncf.kubernetes.utils.pod_launcher import PodLauncher
 from airflow.providers.cncf.kubernetes.utils.xcom_sidecar import PodDefaults
@@ -47,7 +47,9 @@ def create_context(task):
     dag = DAG(dag_id="dag")
     tzinfo = pendulum.timezone("Europe/Amsterdam")
     execution_date = timezone.datetime(2016, 1, 1, 1, 0, 0, tzinfo=tzinfo)
-    task_instance = TaskInstance(task=task, execution_date=execution_date)
+    dag_run = DagRun(dag_id=dag.dag_id, execution_date=execution_date)
+    task_instance = TaskInstance(task=task)
+    task_instance.dag_run = dag_run
     task_instance.xcom_push = mock.Mock()
     return {
         "dag": dag,
diff --git a/kubernetes_tests/test_kubernetes_pod_operator_backcompat.py b/kubernetes_tests/test_kubernetes_pod_operator_backcompat.py
index 0b0eabe..6834008 100644
--- a/kubernetes_tests/test_kubernetes_pod_operator_backcompat.py
+++ b/kubernetes_tests/test_kubernetes_pod_operator_backcompat.py
@@ -34,7 +34,7 @@ from airflow.kubernetes.pod_runtime_info_env import PodRuntimeInfoEnv
 from airflow.kubernetes.secret import Secret
 from airflow.kubernetes.volume import Volume
 from airflow.kubernetes.volume_mount import VolumeMount
-from airflow.models import DAG, TaskInstance
+from airflow.models import DAG, DagRun, TaskInstance
 from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import KubernetesPodOperator
 from airflow.providers.cncf.kubernetes.utils.pod_launcher import PodLauncher
 from airflow.providers.cncf.kubernetes.utils.xcom_sidecar import PodDefaults
@@ -50,7 +50,9 @@ def create_context(task):
     dag = DAG(dag_id="dag")
     tzinfo = pendulum.timezone("Europe/Amsterdam")
     execution_date = timezone.datetime(2016, 1, 1, 1, 0, 0, tzinfo=tzinfo)
-    task_instance = TaskInstance(task=task, execution_date=execution_date)
+    dag_run = DagRun(dag_id=dag.dag_id, execution_date=execution_date)
+    task_instance = TaskInstance(task=task)
+    task_instance.dag_run = dag_run
     task_instance.xcom_push = mock.Mock()
     return {
         "dag": dag,
diff --git a/tests/api/common/experimental/test_delete_dag.py b/tests/api/common/experimental/test_delete_dag.py
index 58bcd37..5984cd2 100644
--- a/tests/api/common/experimental/test_delete_dag.py
+++ b/tests/api/common/experimental/test_delete_dag.py
@@ -16,15 +16,13 @@
 # specific language governing permissions and limitations
 # under the License.
 
-import unittest
 
 import pytest
 
-from airflow import models, settings
+from airflow import models
 from airflow.api.common.experimental.delete_dag import delete_dag
 from airflow.exceptions import AirflowException, DagNotFound
 from airflow.operators.dummy import DummyOperator
-from airflow.utils import timezone
 from airflow.utils.dates import days_ago
 from airflow.utils.session import create_session
 from airflow.utils.state import State
@@ -40,15 +38,7 @@ TR = models.taskreschedule.TaskReschedule
 IE = models.ImportError
 
 
-class TestDeleteDAGCatchError(unittest.TestCase):
-    def setUp(self):
-        self.dagbag = models.DagBag(include_examples=True)
-        self.dag_id = 'example_bash_operator'
-        self.dag = self.dagbag.dags[self.dag_id]
-
-    def tearDown(self):
-        self.dag.clear()
-
+class TestDeleteDAGCatchError:
     def test_delete_dag_non_existent_dag(self):
         with pytest.raises(DagNotFound):
             delete_dag("non-existent DAG")
@@ -63,21 +53,17 @@ class TestDeleteDAGErrorsOnRunningTI:
         clear_db_dags()
         clear_db_runs()
 
-    def test_delete_dag_running_taskinstances(self, create_dummy_dag):
+    def test_delete_dag_running_taskinstances(self, session, create_task_instance):
         dag_id = 'test-dag'
-        _, task = create_dummy_dag(dag_id)
+        ti = create_task_instance(dag_id=dag_id, session=session)
 
-        ti = TI(task, execution_date=timezone.utcnow())
-        ti.refresh_from_db()
-        session = settings.Session()
         ti.state = State.RUNNING
-        session.merge(ti)
         session.commit()
         with pytest.raises(AirflowException):
             delete_dag(dag_id)
 
 
-class TestDeleteDAGSuccessfulDelete(unittest.TestCase):
+class TestDeleteDAGSuccessfulDelete:
     dag_file_path = "/usr/local/airflow/dags/test_dag_8.py"
     key = "test_dag_id"
 
@@ -94,8 +80,10 @@ class TestDeleteDAGSuccessfulDelete(unittest.TestCase):
         test_date = days_ago(1)
         with create_session() as session:
             session.add(DM(dag_id=self.key, fileloc=self.dag_file_path, is_subdag=for_sub_dag))
-            session.add(DR(dag_id=self.key, run_type=DagRunType.MANUAL))
-            session.add(TI(task=task, execution_date=test_date, state=State.SUCCESS))
+            dr = DR(dag_id=self.key, run_type=DagRunType.MANUAL, run_id="test", execution_date=test_date)
+            ti = TI(task=task, state=State.SUCCESS)
+            ti.dag_run = dr
+            session.add_all((dr, ti))
             # flush to ensure task instance if written before
             # task reschedule because of FK constraint
             session.flush()
@@ -111,8 +99,8 @@ class TestDeleteDAGSuccessfulDelete(unittest.TestCase):
             session.add(TF(task=task, execution_date=test_date, start_date=test_date, end_date=test_date))
             session.add(
                 TR(
-                    task=task,
-                    execution_date=test_date,
+                    task=ti.task,
+                    run_id=ti.run_id,
                     start_date=test_date,
                     end_date=test_date,
                     try_number=1,
@@ -127,7 +115,7 @@ class TestDeleteDAGSuccessfulDelete(unittest.TestCase):
                 )
             )
 
-    def tearDown(self):
+    def teardown_method(self):
         with create_session() as session:
             session.query(TR).filter(TR.dag_id == self.key).delete()
             session.query(TF).filter(TF.dag_id == self.key).delete()
diff --git a/tests/api/common/experimental/test_mark_tasks.py b/tests/api/common/experimental/test_mark_tasks.py
index 49008d3..e43ac4a 100644
--- a/tests/api/common/experimental/test_mark_tasks.py
+++ b/tests/api/common/experimental/test_mark_tasks.py
@@ -20,6 +20,7 @@ import unittest
 from datetime import timedelta
 
 import pytest
+from sqlalchemy.orm import eagerload
 
 from airflow import models
 from airflow.api.common.experimental.mark_tasks import (
@@ -40,11 +41,20 @@ from tests.test_utils.db import clear_db_runs
 DEV_NULL = "/dev/null"
 
 
-class TestMarkTasks(unittest.TestCase):
+@pytest.fixture(scope="module")
+def dagbag():
+    from airflow.models.dagbag import DagBag
+
+    # Ensure the DAGs we are looking at from the DB are up-to-date
+    non_serialized_dagbag = DagBag(read_dags_from_db=False, include_examples=False)
+    non_serialized_dagbag.sync_to_db()
+    return DagBag(read_dags_from_db=True)
+
+
+class TestMarkTasks:
+    @pytest.fixture(scope="class", autouse=True, name="create_dags")
     @classmethod
-    def setUpClass(cls):
-        models.DagBag(include_examples=True, read_dags_from_db=False).sync_to_db()
-        dagbag = models.DagBag(include_examples=False, read_dags_from_db=True)
+    def create_dags(cls, dagbag):
         cls.dag1 = dagbag.get_dag('miscellaneous_test_dag')
         cls.dag2 = dagbag.get_dag('example_subdag_operator')
         cls.dag3 = dagbag.get_dag('example_trigger_target_dag')
@@ -56,7 +66,9 @@ class TestMarkTasks(unittest.TestCase):
             start_date3 + timedelta(days=2),
         ]
 
-    def setUp(self):
+    @pytest.fixture(autouse=True)
+    def setup(self):
         clear_db_runs()
         drs = _create_dagruns(
             self.dag1, self.execution_dates, state=State.RUNNING, run_type=DagRunType.SCHEDULED
@@ -77,24 +89,35 @@ class TestMarkTasks(unittest.TestCase):
         for dr in drs:
             dr.dag = self.dag3
 
-    def tearDown(self):
+        yield
+
         clear_db_runs()
 
     @staticmethod
     def snapshot_state(dag, execution_dates):
         TI = models.TaskInstance
+        DR = models.DagRun
         with create_session() as session:
             return (
                 session.query(TI)
-                .filter(TI.dag_id == dag.dag_id, TI.execution_date.in_(execution_dates))
+                .join(TI.dag_run)
+                .options(eagerload(TI.dag_run))
+                .filter(TI.dag_id == dag.dag_id, DR.execution_date.in_(execution_dates))
                 .all()
             )
 
     @provide_session
     def verify_state(self, dag, task_ids, execution_dates, state, old_tis, session=None):
         TI = models.TaskInstance
-
-        tis = session.query(TI).filter(TI.dag_id == dag.dag_id, TI.execution_date.in_(execution_dates)).all()
+        DR = models.DagRun
+
+        tis = (
+            session.query(TI)
+            .join(TI.dag_run)
+            .options(eagerload(TI.dag_run))
+            .filter(TI.dag_id == dag.dag_id, DR.execution_date.in_(execution_dates))
+            .all()
+        )
 
         assert len(tis) > 0
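
Since execution_date now lives on DagRun, the snapshot_state/verify_state queries above filter TaskInstances by joining through the new dag_run relationship. A condensed sketch of that lookup pattern, assuming an open SQLAlchemy session:

    from sqlalchemy.orm import eagerload

    from airflow.models import DagRun, TaskInstance

    def tis_for_dates(session, dag_id, execution_dates):
        # Join through TaskInstance.dag_run and filter on the run's execution_date.
        return (
            session.query(TaskInstance)
            .join(TaskInstance.dag_run)
            .options(eagerload(TaskInstance.dag_run))
            .filter(TaskInstance.dag_id == dag_id, DagRun.execution_date.in_(execution_dates))
            .all()
        )
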
 
diff --git a/tests/api_connexion/conftest.py b/tests/api_connexion/conftest.py
index ead538b..cc92733 100644
--- a/tests/api_connexion/conftest.py
+++ b/tests/api_connexion/conftest.py
@@ -40,3 +40,11 @@ def session():
 
     with create_session() as session:
         yield session
+
+
+@pytest.fixture(scope="session")
+def dagbag():
+    from airflow.models import DagBag
+
+    DagBag(include_examples=True, read_dags_from_db=False).sync_to_db()
+    return DagBag(include_examples=True, read_dags_from_db=True)
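
The conftest fixture above syncs the example DAGs to the database once per test session and hands tests the DB-backed DagBag. A hypothetical test using it (the test itself is illustrative, not part of the commit):

    def test_example_dag_is_in_db(dagbag):
        # `dagbag` is the session-scoped fixture defined above.
        assert dagbag.get_dag("example_python_operator") is not None
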
diff --git a/tests/api_connexion/endpoints/test_dag_endpoint.py b/tests/api_connexion/endpoints/test_dag_endpoint.py
index 9c5f593..17a2553 100644
--- a/tests/api_connexion/endpoints/test_dag_endpoint.py
+++ b/tests/api_connexion/endpoints/test_dag_endpoint.py
@@ -122,6 +122,7 @@ class TestDagEndpoint:
                 fileloc=f"/tmp/dag_{num}.py",
                 schedule_interval="2 2 * * *",
                 is_active=True,
+                is_paused=False,
             )
             session.add(dag_model)
 
@@ -162,6 +163,7 @@ class TestGetDag(TestDagEndpoint):
             dag_id="TEST_DAG_1",
             fileloc="/tmp/dag_1.py",
             schedule_interval=None,
+            is_paused=False,
         )
         session.add(dag_model)
         session.commit()
diff --git a/tests/api_connexion/endpoints/test_event_log_endpoint.py b/tests/api_connexion/endpoints/test_event_log_endpoint.py
index f1025e6..65daae8 100644
--- a/tests/api_connexion/endpoints/test_event_log_endpoint.py
+++ b/tests/api_connexion/endpoints/test_event_log_endpoint.py
@@ -16,15 +16,11 @@
 # under the License.
 
 import pytest
-from parameterized import parameterized
 
-from airflow import DAG
 from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP
-from airflow.models import Log, TaskInstance
-from airflow.operators.dummy import DummyOperator
+from airflow.models import Log
 from airflow.security import permissions
 from airflow.utils import timezone
-from airflow.utils.session import provide_session
 from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user
 from tests.test_utils.config import conf_vars
 from tests.test_utils.db import clear_db_logs
@@ -47,43 +43,56 @@ def configured_app(minimal_app_for_api):
     delete_user(app, username="test_no_permissions")  # type: ignore
 
 
+@pytest.fixture
+def task_instance(session, create_task_instance, request):
+    return create_task_instance(
+        session=session,
+        dag_id="TEST_DAG_ID",
+        task_id="TEST_TASK_ID",
+        execution_date=request.instance.default_time,
+    )
+
+
+@pytest.fixture()
+def log_model(create_log_model, request):
+    return create_log_model(
+        event="TEST_EVENT",
+        when=request.instance.default_time,
+    )
+
+
+@pytest.fixture
+def create_log_model(create_task_instance, task_instance, session, request):
+    def maker(event, when, **kwargs):
+        log_model = Log(
+            event=event,
+            task_instance=task_instance,
+            **kwargs,
+        )
+        log_model.dttm = when
+
+        session.add(log_model)
+        session.flush()
+        return log_model
+
+    return maker
+
+
 class TestEventLogEndpoint:
     @pytest.fixture(autouse=True)
     def setup_attrs(self, configured_app) -> None:
         self.app = configured_app
         self.client = self.app.test_client()  # type:ignore
         clear_db_logs()
-        self.default_time = "2020-06-10T20:00:00+00:00"
-        self.default_time_2 = '2020-06-11T07:00:00+00:00'
+        self.default_time = timezone.parse("2020-06-10T20:00:00+00:00")
+        self.default_time_2 = timezone.parse('2020-06-11T07:00:00+00:00')
 
     def teardown_method(self) -> None:
         clear_db_logs()
 
-    def _create_task_instance(self):
-        dag = DAG(
-            'TEST_DAG_ID',
-            start_date=timezone.parse(self.default_time),
-            end_date=timezone.parse(self.default_time),
-        )
-        op1 = DummyOperator(
-            task_id="TEST_TASK_ID",
-            owner="airflow",
-        )
-        dag.add_task(op1)
-        ti = TaskInstance(task=op1, execution_date=timezone.parse(self.default_time))
-        return ti
-
 
 class TestGetEventLog(TestEventLogEndpoint):
-    @provide_session
-    def test_should_respond_200(self, session):
-        log_model = Log(
-            event='TEST_EVENT',
-            task_instance=self._create_task_instance(),
-        )
-        log_model.dttm = timezone.parse(self.default_time)
-        session.add(log_model)
-        session.commit()
+    def test_should_respond_200(self, log_model):
         event_log_id = log_model.id
         response = self.client.get(
             f"/api/v1/eventLogs/{event_log_id}", environ_overrides={'REMOTE_USER': "test"}
@@ -94,9 +103,9 @@ class TestGetEventLog(TestEventLogEndpoint):
             "event": "TEST_EVENT",
             "dag_id": "TEST_DAG_ID",
             "task_id": "TEST_TASK_ID",
-            "execution_date": self.default_time,
+            "execution_date": self.default_time.isoformat(),
             "owner": 'airflow',
-            "when": self.default_time,
+            "when": self.default_time.isoformat(),
             "extra": None,
         }
 
@@ -110,15 +119,7 @@ class TestGetEventLog(TestEventLogEndpoint):
             'type': EXCEPTIONS_LINK_MAP[404],
         } == response.json
 
-    @provide_session
-    def test_should_raises_401_unauthenticated(self, session):
-        log_model = Log(
-            event='TEST_EVENT',
-            task_instance=self._create_task_instance(),
-        )
-        log_model.dttm = timezone.parse(self.default_time)
-        session.add(log_model)
-        session.commit()
+    def test_should_raises_401_unauthenticated(self, log_model):
         event_log_id = log_model.id
 
         response = self.client.get(f"/api/v1/eventLogs/{event_log_id}")
@@ -133,21 +134,14 @@ class TestGetEventLog(TestEventLogEndpoint):
 
 
 class TestGetEventLogs(TestEventLogEndpoint):
-    def test_should_respond_200(self, session):
-        log_model_1 = Log(
-            event='TEST_EVENT_1',
-            task_instance=self._create_task_instance(),
-        )
-        log_model_2 = Log(
-            event='TEST_EVENT_2',
-            task_instance=self._create_task_instance(),
-        )
+    def test_should_respond_200(self, session, create_log_model):
+        log_model_1 = create_log_model(event='TEST_EVENT_1', when=self.default_time)
+        log_model_2 = create_log_model(event='TEST_EVENT_2', when=self.default_time_2)
         log_model_3 = Log(event="cli_scheduler", owner='root', extra='{"host_name": "e24b454f002a"}')
-        log_model_1.dttm = timezone.parse(self.default_time)
-        log_model_2.dttm = timezone.parse(self.default_time_2)
-        log_model_3.dttm = timezone.parse(self.default_time_2)
-        session.add_all([log_model_1, log_model_2, log_model_3])
-        session.commit()
+        log_model_3.dttm = self.default_time_2
+
+        session.add(log_model_3)
+        session.flush()
         response = self.client.get("/api/v1/eventLogs", environ_overrides={'REMOTE_USER': "test"})
         assert response.status_code == 200
         assert response.json == {
@@ -157,9 +151,9 @@ class TestGetEventLogs(TestEventLogEndpoint):
                     "event": "TEST_EVENT_1",
                     "dag_id": "TEST_DAG_ID",
                     "task_id": "TEST_TASK_ID",
-                    "execution_date": self.default_time,
+                    "execution_date": self.default_time.isoformat(),
                     "owner": 'airflow',
-                    "when": self.default_time,
+                    "when": self.default_time.isoformat(),
                     "extra": None,
                 },
                 {
@@ -167,9 +161,9 @@ class TestGetEventLogs(TestEventLogEndpoint):
                     "event": "TEST_EVENT_2",
                     "dag_id": "TEST_DAG_ID",
                     "task_id": "TEST_TASK_ID",
-                    "execution_date": self.default_time,
+                    "execution_date": self.default_time.isoformat(),
                     "owner": 'airflow',
-                    "when": self.default_time_2,
+                    "when": self.default_time_2.isoformat(),
                     "extra": None,
                 },
                 {
@@ -179,25 +173,20 @@ class TestGetEventLogs(TestEventLogEndpoint):
                     "task_id": None,
                     "execution_date": None,
                     "owner": 'root',
-                    "when": self.default_time_2,
+                    "when": self.default_time_2.isoformat(),
                     "extra": '{"host_name": "e24b454f002a"}',
                 },
             ],
             "total_entries": 3,
         }
 
-    def test_order_eventlogs_by_owner(self, session):
-        log_model_1 = Log(
-            event='TEST_EVENT_1',
-            task_instance=self._create_task_instance(),
-        )
-        log_model_2 = Log(event='TEST_EVENT_2', task_instance=self._create_task_instance(), owner="zsh")
+    def test_order_eventlogs_by_owner(self, create_log_model, session):
+        log_model_1 = create_log_model(event="TEST_EVENT_1", when=self.default_time)
+        log_model_2 = create_log_model(event="TEST_EVENT_2", when=self.default_time_2, owner='zsh')
         log_model_3 = Log(event="cli_scheduler", owner='root', extra='{"host_name": "e24b454f002a"}')
-        log_model_1.dttm = timezone.parse(self.default_time)
-        log_model_2.dttm = timezone.parse(self.default_time_2)
-        log_model_3.dttm = timezone.parse(self.default_time_2)
-        session.add_all([log_model_1, log_model_2, log_model_3])
-        session.commit()
+        log_model_3.dttm = self.default_time_2
+        session.add(log_model_3)
+        session.flush()
         response = self.client.get(
             "/api/v1/eventLogs?order_by=-owner", environ_overrides={'REMOTE_USER': "test"}
         )
@@ -209,9 +198,9 @@ class TestGetEventLogs(TestEventLogEndpoint):
                     "event": "TEST_EVENT_2",
                     "dag_id": "TEST_DAG_ID",
                     "task_id": "TEST_TASK_ID",
-                    "execution_date": self.default_time,
+                    "execution_date": self.default_time.isoformat(),
                     "owner": 'zsh',  # Order by name, sort order is descending(-)
-                    "when": self.default_time_2,
+                    "when": self.default_time_2.isoformat(),
                     "extra": None,
                 },
                 {
@@ -221,7 +210,7 @@ class TestGetEventLogs(TestEventLogEndpoint):
                     "task_id": None,
                     "execution_date": None,
                     "owner": 'root',
-                    "when": self.default_time_2,
+                    "when": self.default_time_2.isoformat(),
                     "extra": '{"host_name": "e24b454f002a"}',
                 },
                 {
@@ -229,37 +218,24 @@ class TestGetEventLogs(TestEventLogEndpoint):
                     "event": "TEST_EVENT_1",
                     "dag_id": "TEST_DAG_ID",
                     "task_id": "TEST_TASK_ID",
-                    "execution_date": self.default_time,
+                    "execution_date": self.default_time.isoformat(),
                     "owner": 'airflow',
-                    "when": self.default_time,
+                    "when": self.default_time.isoformat(),
                     "extra": None,
                 },
             ],
             "total_entries": 3,
         }
 
-    @provide_session
-    def test_should_raises_401_unauthenticated(self, session):
-        log_model_1 = Log(
-            event='TEST_EVENT_1',
-            task_instance=self._create_task_instance(),
-        )
-        log_model_2 = Log(
-            event='TEST_EVENT_2',
-            task_instance=self._create_task_instance(),
-        )
-        log_model_1.dttm = timezone.parse(self.default_time)
-        log_model_2.dttm = timezone.parse(self.default_time_2)
-        session.add_all([log_model_1, log_model_2])
-        session.commit()
-
+    def test_should_raises_401_unauthenticated(self, log_model):
         response = self.client.get("/api/v1/eventLogs")
 
         assert_401(response)
 
 
 class TestGetEventLogPagination(TestEventLogEndpoint):
-    @parameterized.expand(
+    @pytest.mark.parametrize(
+        ("url", "expected_events"),
         [
             ("api/v1/eventLogs?limit=1", ["TEST_EVENT_1"]),
             ("api/v1/eventLogs?limit=2", ["TEST_EVENT_1", "TEST_EVENT_2"]),
@@ -294,11 +270,10 @@ class TestGetEventLogPagination(TestEventLogEndpoint):
                 "api/v1/eventLogs?limit=2&offset=2",
                 ["TEST_EVENT_3", "TEST_EVENT_4"],
             ),
-        ]
+        ],
     )
-    @provide_session
-    def test_handle_limit_and_offset(self, url, expected_events, session):
-        log_models = self._create_event_logs(10)
+    def test_handle_limit_and_offset(self, url, expected_events, task_instance, session):
+        log_models = self._create_event_logs(task_instance, 10)
         session.add_all(log_models)
         session.commit()
 
@@ -309,11 +284,10 @@ class TestGetEventLogPagination(TestEventLogEndpoint):
         events = [event_log["event"] for event_log in response.json["event_logs"]]
         assert events == expected_events
 
-    @provide_session
-    def test_should_respect_page_size_limit_default(self, session):
-        log_models = self._create_event_logs(200)
+    def test_should_respect_page_size_limit_default(self, task_instance, session):
+        log_models = self._create_event_logs(task_instance, 200)
         session.add_all(log_models)
-        session.commit()
+        session.flush()
 
         response = self.client.get("/api/v1/eventLogs", environ_overrides={'REMOTE_USER': "test"})
         assert response.status_code == 200
@@ -321,10 +295,10 @@ class TestGetEventLogPagination(TestEventLogEndpoint):
         assert response.json["total_entries"] == 200
         assert len(response.json["event_logs"]) == 100  # default 100
 
-    def test_should_raise_400_for_invalid_order_by_name(self, session):
-        log_models = self._create_event_logs(200)
+    def test_should_raise_400_for_invalid_order_by_name(self, task_instance, session):
+        log_models = self._create_event_logs(task_instance, 200)
         session.add_all(log_models)
-        session.commit()
+        session.flush()
 
         response = self.client.get(
             "/api/v1/eventLogs?order_by=invalid", environ_overrides={'REMOTE_USER': "test"}
@@ -333,19 +307,15 @@ class TestGetEventLogPagination(TestEventLogEndpoint):
         msg = "Ordering with 'invalid' is disallowed or the attribute does not exist on the model"
         assert response.json['detail'] == msg
 
-    @provide_session
     @conf_vars({("api", "maximum_page_limit"): "150"})
-    def test_should_return_conf_max_if_req_max_above_conf(self, session):
-        log_models = self._create_event_logs(200)
+    def test_should_return_conf_max_if_req_max_above_conf(self, task_instance, session):
+        log_models = self._create_event_logs(task_instance, 200)
         session.add_all(log_models)
-        session.commit()
+        session.flush()
 
         response = self.client.get("/api/v1/eventLogs?limit=180", environ_overrides={'REMOTE_USER': "test"})
         assert response.status_code == 200
         assert len(response.json['event_logs']) == 150
 
-    def _create_event_logs(self, count):
-        return [
-            Log(event="TEST_EVENT_" + str(i), task_instance=self._create_task_instance())
-            for i in range(1, count + 1)
-        ]
+    def _create_event_logs(self, task_instance, count):
+        return [Log(event="TEST_EVENT_" + str(i), task_instance=task_instance) for i in range(1, count + 1)]
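
The fixtures above replace the per-class _create_task_instance helper: Log rows now hang off a TaskInstance produced by the shared create_task_instance factory, so the log's dag_id, task_id and execution_date come from a real DagRun. A distilled sketch of the factory body, assuming an open session and a TaskInstance already bound to a run:

    from airflow.models import Log

    def make_log(session, task_instance, event, when, **kwargs):
        log = Log(event=event, task_instance=task_instance, **kwargs)
        log.dttm = when  # pin the timestamp so assertions can compare isoformat() values
        session.add(log)
        session.flush()
        return log
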
diff --git a/tests/api_connexion/endpoints/test_log_endpoint.py b/tests/api_connexion/endpoints/test_log_endpoint.py
index 15caf19..87963c0 100644
--- a/tests/api_connexion/endpoints/test_log_endpoint.py
+++ b/tests/api_connexion/endpoints/test_log_endpoint.py
@@ -26,11 +26,9 @@ from itsdangerous.url_safe import URLSafeSerializer
 from airflow import DAG
 from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP
 from airflow.config_templates.airflow_local_settings import DEFAULT_LOGGING_CONFIG
-from airflow.models import DagRun, TaskInstance
 from airflow.operators.dummy import DummyOperator
 from airflow.security import permissions
 from airflow.utils import timezone
-from airflow.utils.session import create_session
 from airflow.utils.types import DagRunType
 from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user
 from tests.test_utils.db import clear_db_runs
@@ -66,24 +64,25 @@ class TestGetLog:
     default_time = "2020-06-10T20:00:00+00:00"
 
     @pytest.fixture(autouse=True)
-    def setup_attrs(self, configured_app, configure_loggers) -> None:
+    def setup_attrs(self, configured_app, configure_loggers, dag_maker, session) -> None:
         self.app = configured_app
         self.client = self.app.test_client()
         # Make sure that the configure_logging is not cached
         self.old_modules = dict(sys.modules)
-        self._prepare_db()
 
-    def _create_dagrun(self, session):
-        dagrun_model = DagRun(
-            dag_id=self.DAG_ID,
+        with dag_maker(self.DAG_ID, start_date=timezone.parse(self.default_time), session=session) as dag:
+            DummyOperator(task_id=self.TASK_ID)
+        dr = dag_maker.create_dagrun(
             run_id='TEST_DAG_RUN_ID',
             run_type=DagRunType.MANUAL,
             execution_date=timezone.parse(self.default_time),
             start_date=timezone.parse(self.default_time),
-            external_trigger=True,
         )
-        session.add(dagrun_model)
-        session.commit()
+
+        configured_app.dag_bag.bag_dag(dag, root_dag=dag)
+
+        self.ti = dr.task_instances[0]
+        self.ti.try_number = 1
 
     @pytest.fixture
     def configure_loggers(self, tmp_path):
@@ -109,25 +108,10 @@ class TestGetLog:
 
         logging.config.dictConfig(DEFAULT_LOGGING_CONFIG)
 
-    def _prepare_db(self):
-        dagbag = self.app.dag_bag
-        dag = DAG(self.DAG_ID, start_date=timezone.parse(self.default_time))
-        dag.sync_to_db()
-        dagbag.dags.pop(self.DAG_ID, None)
-        dagbag.bag_dag(dag=dag, root_dag=dag)
-        with create_session() as session:
-            self.ti = TaskInstance(
-                task=DummyOperator(task_id=self.TASK_ID, dag=dag),
-                execution_date=timezone.parse(self.default_time),
-            )
-            self.ti.try_number = 1
-            session.merge(self.ti)
-
     def teardown_method(self):
         clear_db_runs()
 
     def test_should_respond_200_json(self, session):
-        self._create_dagrun(session)
         key = self.app.config["SECRET_KEY"]
         serializer = URLSafeSerializer(key)
         token = serializer.dumps({"download_logs": False})
@@ -149,7 +133,6 @@ class TestGetLog:
         assert 200 == response.status_code
 
     def test_should_respond_200_text_plain(self, session):
-        self._create_dagrun(session)
         key = self.app.config["SECRET_KEY"]
         serializer = URLSafeSerializer(key)
         token = serializer.dumps({"download_logs": True})
@@ -170,8 +153,6 @@ class TestGetLog:
         )
 
     def test_get_logs_of_removed_task(self, session):
-        self._create_dagrun(session)
-
         # Recreate DAG without tasks
         dagbag = self.app.dag_bag
         dag = DAG(self.DAG_ID, start_date=timezone.parse(self.default_time))
@@ -198,7 +179,6 @@ class TestGetLog:
         )
 
     def test_get_logs_response_with_ti_equal_to_none(self, session):
-        self._create_dagrun(session)
         key = self.app.config["SECRET_KEY"]
         serializer = URLSafeSerializer(key)
         token = serializer.dumps({"download_logs": True})
@@ -208,11 +188,15 @@ class TestGetLog:
             f"taskInstances/Invalid-Task-ID/logs/1?token={token}",
             environ_overrides={'REMOTE_USER': "test"},
         )
-        assert response.status_code == 400
-        assert response.json['detail'] == "Task instance did not exist in the DB"
+        assert response.status_code == 404
+        assert response.json == {
+            'detail': None,
+            'status': 404,
+            'title': "TaskInstance not found",
+            'type': EXCEPTIONS_LINK_MAP[404],
+        }
 
     def test_get_logs_with_metadata_as_download_large_file(self, session):
-        self._create_dagrun(session)
         with mock.patch("airflow.utils.log.file_task_handler.FileTaskHandler.read") as read_mock:
             first_return = ([[('', '1st line')]], [{}])
             second_return = ([[('', '2nd line')]], [{'end_of_log': False}])
@@ -251,7 +235,6 @@ class TestGetLog:
         assert 'Task log handler does not support read logs.' in response.data.decode('utf-8')
 
     def test_bad_signature_raises(self, session):
-        self._create_dagrun(session)
         token = {"download_logs": False}
 
         response = self.client.get(
@@ -269,15 +252,16 @@ class TestGetLog:
 
     def test_raises_404_for_invalid_dag_run_id(self):
         response = self.client.get(
-            f"api/v1/dags/{self.DAG_ID}/dagRuns/TEST_DAG_RUN/"  # invalid dagrun_id
+            f"api/v1/dags/{self.DAG_ID}/dagRuns/NO_DAG_RUN/"  # invalid dagrun_id
             f"taskInstances/{self.TASK_ID}/logs/1?",
             headers={'Accept': 'application/json'},
             environ_overrides={'REMOTE_USER': "test"},
         )
+        assert response.status_code == 404
         assert response.json == {
             'detail': None,
             'status': 404,
-            'title': "DAG Run not found",
+            'title': "TaskInstance not found",
             'type': EXCEPTIONS_LINK_MAP[404],
         }
 
diff --git a/tests/api_connexion/endpoints/test_task_instance_endpoint.py b/tests/api_connexion/endpoints/test_task_instance_endpoint.py
index e359cd4..7696b76 100644
--- a/tests/api_connexion/endpoints/test_task_instance_endpoint.py
+++ b/tests/api_connexion/endpoints/test_task_instance_endpoint.py
@@ -19,8 +19,9 @@ from unittest import mock
 
 import pytest
 from parameterized import parameterized
+from sqlalchemy.orm import contains_eager
 
-from airflow.models import DagBag, DagRun, SlaMiss, TaskInstance
+from airflow.models import DagRun, SlaMiss, TaskInstance
 from airflow.security import permissions
 from airflow.utils.platform import getuser
 from airflow.utils.session import provide_session
@@ -58,7 +59,7 @@ def configured_app(minimal_app_for_api):
 
 class TestTaskInstanceEndpoint:
     @pytest.fixture(autouse=True)
-    def setup_attrs(self, configured_app) -> None:
+    def setup_attrs(self, configured_app, dagbag) -> None:
         self.default_time = DEFAULT_DATETIME_1
         self.ti_init = {
             "execution_date": self.default_time,
@@ -77,15 +78,13 @@ class TestTaskInstanceEndpoint:
         self.client = self.app.test_client()  # type:ignore
         clear_db_runs()
         clear_db_sla_miss()
-        DagBag(include_examples=True, read_dags_from_db=False).sync_to_db()
-        self.dagbag = DagBag(include_examples=True, read_dags_from_db=True)
+        self.dagbag = dagbag
 
     def create_task_instances(
         self,
         session,
         dag_id: str = "example_python_operator",
         update_extras: bool = True,
-        single_dag_run: bool = True,
         task_instances=None,
         dag_run_state=State.RUNNING,
     ):
@@ -97,6 +96,10 @@ class TestTaskInstanceEndpoint:
         if task_instances is not None:
             counter = min(len(task_instances), counter)
 
+        run_id = "TEST_DAG_RUN_ID"
+        execution_date = self.ti_init.pop("execution_date", self.default_time)
+        dr = None
+
         for i in range(counter):
             if task_instances is None:
                 pass
@@ -104,31 +107,28 @@ class TestTaskInstanceEndpoint:
                 self.ti_extras.update(task_instances[i])
             else:
                 self.ti_init.update(task_instances[i])
-            ti = TaskInstance(task=tasks[i], **self.ti_init)
 
-            for key, value in self.ti_extras.items():
-                setattr(ti, key, value)
-            session.add(ti)
+            if "execution_date" in self.ti_init:
+                run_id = f"TEST_DAG_RUN_ID_{i}"
+                execution_date = self.ti_init.pop("execution_date")
+                dr = None
 
-            if single_dag_run is False:
+            if not dr:
                 dr = DagRun(
+                    run_id=run_id,
                     dag_id=dag_id,
-                    run_id=f"TEST_DAG_RUN_ID_{i}",
-                    execution_date=self.ti_init["execution_date"],
-                    run_type=DagRunType.MANUAL.value,
+                    execution_date=execution_date,
+                    run_type=DagRunType.MANUAL,
                     state=dag_run_state,
                 )
                 session.add(dr)
+            ti = TaskInstance(task=tasks[i], **self.ti_init)
+            ti.dag_run = dr
+
+            for key, value in self.ti_extras.items():
+                setattr(ti, key, value)
+            session.add(ti)
 
-        if single_dag_run:
-            dr = DagRun(
-                dag_id=dag_id,
-                run_id="TEST_DAG_RUN_ID",
-                execution_date=self.default_time,
-                run_type=DagRunType.MANUAL.value,
-                state=dag_run_state,
-            )
-            session.add(dr)
         session.commit()
 
 
@@ -274,7 +274,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
                 ],
                 False,
                 (
-                    "/api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/"
+                    "/api/v1/dags/example_python_operator/dagRuns/~/"
                     f"taskInstances?execution_date_lte={DEFAULT_DATETIME_STR_1}"
                 ),
                 1,
@@ -288,26 +288,12 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
                 ],
                 True,
                 (
-                    "/api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances"
+                    "/api/v1/dags/example_python_operator/dagRuns/~/taskInstances"
                     f"?start_date_gte={DEFAULT_DATETIME_STR_1}&start_date_lte={DEFAULT_DATETIME_STR_2}"
                 ),
                 2,
             ),
             (
-                "test start date filter with ~",
-                [
-                    {"start_date": DEFAULT_DATETIME_1},
-                    {"start_date": DEFAULT_DATETIME_1 + dt.timedelta(days=1)},
-                    {"start_date": DEFAULT_DATETIME_1 + dt.timedelta(days=2)},
-                ],
-                True,
-                (
-                    "/api/v1/dags/~/dagRuns/~/taskInstances?start_date_gte"
-                    f"={DEFAULT_DATETIME_STR_1}&start_date_lte={DEFAULT_DATETIME_STR_2}"
-                ),
-                2,
-            ),
-            (
                 "test end date filter",
                 [
                     {"end_date": DEFAULT_DATETIME_1},
@@ -316,26 +302,12 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
                 ],
                 True,
                 (
-                    "/api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances?"
+                    "/api/v1/dags/example_python_operator/dagRuns/~/taskInstances?"
                     f"end_date_gte={DEFAULT_DATETIME_STR_1}&end_date_lte={DEFAULT_DATETIME_STR_2}"
                 ),
                 2,
             ),
             (
-                "test end date filter ~",
-                [
-                    {"end_date": DEFAULT_DATETIME_1},
-                    {"end_date": DEFAULT_DATETIME_1 + dt.timedelta(days=1)},
-                    {"end_date": DEFAULT_DATETIME_1 + dt.timedelta(days=2)},
-                ],
-                True,
-                (
-                    "/api/v1/dags/~/dagRuns/~/taskInstances?end_date_gte"
-                    f"={DEFAULT_DATETIME_STR_1}&end_date_lte={DEFAULT_DATETIME_STR_2}"
-                ),
-                2,
-            ),
-            (
                 "test duration filter",
                 [
                     {"duration": 100},
@@ -428,6 +400,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
     )
     @provide_session
     def test_should_respond_200(self, _, task_instances, update_extras, url, expected_ti, session):
+
         self.create_task_instances(
             session,
             update_extras=update_extras,
@@ -476,7 +449,6 @@ class TestGetTaskInstancesBatch(TestTaskInstanceEndpoint):
                     {"queue": "test_queue_3"},
                 ],
                 True,
-                True,
                 {"queue": ["test_queue_1", "test_queue_2"]},
                 2,
             ),
@@ -488,7 +460,6 @@ class TestGetTaskInstancesBatch(TestTaskInstanceEndpoint):
                     {"pool": "test_pool_3"},
                 ],
                 True,
-                True,
                 {"pool": ["test_pool_1", "test_pool_2"]},
                 2,
             ),
@@ -500,7 +471,6 @@ class TestGetTaskInstancesBatch(TestTaskInstanceEndpoint):
                     {"state": State.SUCCESS},
                 ],
                 False,
-                True,
                 {"state": ["running", "queued"]},
                 2,
             ),
@@ -512,7 +482,6 @@ class TestGetTaskInstancesBatch(TestTaskInstanceEndpoint):
                     {"duration": 200},
                 ],
                 True,
-                True,
                 {"duration_gte": 100, "duration_lte": 200},
                 3,
             ),
@@ -524,7 +493,6 @@ class TestGetTaskInstancesBatch(TestTaskInstanceEndpoint):
                     {"end_date": DEFAULT_DATETIME_1 + dt.timedelta(days=2)},
                 ],
                 True,
-                True,
                 {
                     "end_date_gte": DEFAULT_DATETIME_STR_1,
                     "end_date_lte": DEFAULT_DATETIME_STR_2,
@@ -539,7 +507,6 @@ class TestGetTaskInstancesBatch(TestTaskInstanceEndpoint):
                     {"start_date": DEFAULT_DATETIME_1 + dt.timedelta(days=2)},
                 ],
                 True,
-                True,
                 {
                     "start_date_gte": DEFAULT_DATETIME_STR_1,
                     "start_date_lte": DEFAULT_DATETIME_STR_2,
@@ -557,7 +524,6 @@ class TestGetTaskInstancesBatch(TestTaskInstanceEndpoint):
                     {"execution_date": DEFAULT_DATETIME_1 + dt.timedelta(days=5)},
                 ],
                 False,
-                True,
                 {
                     "execution_date_gte": DEFAULT_DATETIME_1,
                     "execution_date_lte": (DEFAULT_DATETIME_1 + dt.timedelta(days=2)),
@@ -567,14 +533,11 @@ class TestGetTaskInstancesBatch(TestTaskInstanceEndpoint):
         ]
     )
     @provide_session
-    def test_should_respond_200(
-        self, _, task_instances, update_extras, single_dag_run, payload, expected_ti_count, session
-    ):
+    def test_should_respond_200(self, _, task_instances, update_extras, payload, expected_ti_count, session):
         self.create_task_instances(
             session,
             update_extras=update_extras,
             task_instances=task_instances,
-            single_dag_run=single_dag_run,
         )
         response = self.client.post(
             "/api/v1/dags/~/dagRuns/~/taskInstances/list",
@@ -593,7 +556,6 @@ class TestGetTaskInstancesBatch(TestTaskInstanceEndpoint):
                     {"task": "test_1"},
                     {"task": "test_2"},
                 ],
-                True,
                 {"dag_ids": ["latest_only"]},
                 2,
             ),
@@ -601,7 +563,7 @@ class TestGetTaskInstancesBatch(TestTaskInstanceEndpoint):
     )
     @provide_session
     def test_should_respond_200_when_task_instance_properties_are_none(
-        self, _, task_instances, single_dag_run, payload, expected_ti_count, session
+        self, _, task_instances, payload, expected_ti_count, session
     ):
         self.ti_extras.update(
             {
@@ -614,7 +576,6 @@ class TestGetTaskInstancesBatch(TestTaskInstanceEndpoint):
             session,
             dag_id="latest_only",
             task_instances=task_instances,
-            single_dag_run=single_dag_run,
         )
         response = self.client.post(
             "/api/v1/dags/~/dagRuns/~/taskInstances/list",
@@ -878,7 +839,6 @@ class TestPostClearTaskInstances(TestTaskInstanceEndpoint):
             dag_id=main_dag,
             task_instances=task_instances,
             update_extras=False,
-            single_dag_run=False,
         )
         self.app.dag_bag.sync_to_db()
         response = self.client.post(
@@ -921,7 +881,6 @@ class TestPostClearTaskInstances(TestTaskInstanceEndpoint):
         self.create_task_instances(
             session,
             dag_id=dag_id,
-            single_dag_run=False,
             task_instances=task_instances,
             update_extras=False,
             dag_run_state=State.FAILED,
@@ -1020,7 +979,6 @@ class TestPostClearTaskInstances(TestTaskInstanceEndpoint):
             dag_id="example_python_operator",
             task_instances=task_instances,
             update_extras=False,
-            single_dag_run=False,
         )
         self.app.dag_bag.sync_to_db()
         response = self.client.post(
@@ -1037,7 +995,11 @@ class TestPostSetTaskInstanceState(TestTaskInstanceEndpoint):
     def test_should_assert_call_mocked_api(self, mock_set_task_instance_state, session):
         self.create_task_instances(session)
         mock_set_task_instance_state.return_value = (
-            session.query(TaskInstance).filter(TaskInstance.task_id == "print_the_context").all()
+            session.query(TaskInstance)
+            .join(TaskInstance.dag_run)
+            .options(contains_eager(TaskInstance.dag_run))
+            .filter(TaskInstance.task_id == "print_the_context")
+            .all()
         )
         response = self.client.post(
             "/api/v1/dags/example_python_operator/updateTaskInstancesState",
@@ -1074,6 +1036,7 @@ class TestPostSetTaskInstanceState(TestTaskInstanceEndpoint):
             state='failed',
             task_id='print_the_context',
             upstream=True,
+            session=session,
         )
 
     def test_should_raises_401_unauthenticated(self):
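
With the DagRun foreign key in place, create_task_instances above always attaches each TaskInstance to a run, which is why the single_dag_run switch disappears from the parametrized cases. A minimal sketch of the linkage, with illustrative values:

    from airflow.models import DagRun, TaskInstance
    from airflow.utils.state import State
    from airflow.utils.types import DagRunType

    def add_ti(session, task, execution_date):
        dr = DagRun(
            run_id="TEST_DAG_RUN_ID",
            dag_id=task.dag_id,
            execution_date=execution_date,
            run_type=DagRunType.MANUAL,
            state=State.RUNNING,
        )
        session.add(dr)
        ti = TaskInstance(task=task, state=State.RUNNING)  # no execution_date kwarg
        ti.dag_run = dr  # run_id flows from the relationship
        session.add(ti)
        session.commit()
        return ti
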
diff --git a/tests/api_connexion/schemas/test_dag_run_schema.py b/tests/api_connexion/schemas/test_dag_run_schema.py
index b5333f1..ba5acae 100644
--- a/tests/api_connexion/schemas/test_dag_run_schema.py
+++ b/tests/api_connexion/schemas/test_dag_run_schema.py
@@ -34,11 +34,14 @@ from tests.test_utils.db import clear_db_runs
 
 DEFAULT_TIME = "2020-06-09T13:59:56.336000+00:00"
 
+SECOND_TIME = "2020-06-10T13:59:56.336000+00:00"
+
 
 class TestDAGRunBase(unittest.TestCase):
     def setUp(self) -> None:
         clear_db_runs()
         self.default_time = DEFAULT_TIME
+        self.second_time = SECOND_TIME
 
     def tearDown(self) -> None:
         clear_db_runs()
@@ -135,7 +138,7 @@ class TestDagRunCollection(TestDAGRunBase):
         dagrun_model_2 = DagRun(
             run_id="my-dag-run-2",
             state='running',
-            execution_date=timezone.parse(self.default_time),
+            execution_date=timezone.parse(self.second_time),
             start_date=timezone.parse(self.default_time),
             run_type=DagRunType.MANUAL.value,
         )
@@ -162,8 +165,8 @@ class TestDagRunCollection(TestDAGRunBase):
                     "dag_run_id": "my-dag-run-2",
                     "end_date": None,
                     "state": "running",
-                    "execution_date": self.default_time,
-                    "logical_date": self.default_time,
+                    "execution_date": self.second_time,
+                    "logical_date": self.second_time,
                     "external_trigger": True,
                     "start_date": self.default_time,
                     "conf": {},
diff --git a/tests/api_connexion/schemas/test_event_log_schema.py b/tests/api_connexion/schemas/test_event_log_schema.py
index 597ecc7..4517ecb 100644
--- a/tests/api_connexion/schemas/test_event_log_schema.py
+++ b/tests/api_connexion/schemas/test_event_log_schema.py
@@ -15,72 +15,58 @@
 # specific language governing permissions and limitations
 # under the License.
 
-import unittest
+import pytest
 
-from airflow import DAG
 from airflow.api_connexion.schemas.event_log_schema import (
     EventLogCollection,
     event_log_collection_schema,
     event_log_schema,
 )
-from airflow.models import Log, TaskInstance
-from airflow.operators.dummy import DummyOperator
+from airflow.models import Log
 from airflow.utils import timezone
-from airflow.utils.session import create_session, provide_session
 
 
-class TestEventLogSchemaBase(unittest.TestCase):
-    def setUp(self) -> None:
-        with create_session() as session:
-            session.query(Log).delete()
-        self.default_time = "2020-06-09T13:00:00+00:00"
-        self.default_time2 = '2020-06-11T07:00:00+00:00'
+@pytest.fixture
+def task_instance(session, create_task_instance, request):
+    return create_task_instance(
+        session=session,
+        dag_id="TEST_DAG_ID",
+        task_id="TEST_TASK_ID",
+        execution_date=request.instance.default_time,
+    )
 
-    def tearDown(self) -> None:
-        with create_session() as session:
-            session.query(Log).delete()
 
-    def _create_task_instance(self):
-        with DAG(
-            'TEST_DAG_ID',
-            start_date=timezone.parse(self.default_time),
-            end_date=timezone.parse(self.default_time),
-        ):
-            op1 = DummyOperator(task_id="TEST_TASK_ID", owner="airflow")
-        return TaskInstance(task=op1, execution_date=timezone.parse(self.default_time))
+class TestEventLogSchemaBase:
+    @pytest.fixture(autouse=True)
+    def set_attrs(self):
+        self.default_time = timezone.parse("2020-06-09T13:00:00+00:00")
+        self.default_time2 = timezone.parse('2020-06-11T07:00:00+00:00')
 
 
 class TestEventLogSchema(TestEventLogSchemaBase):
-    @provide_session
-    def test_serialize(self, session):
-        event_log_model = Log(event="TEST_EVENT", task_instance=self._create_task_instance())
-        session.add(event_log_model)
-        session.commit()
-        event_log_model.dttm = timezone.parse(self.default_time)
-        log_model = session.query(Log).first()
-        deserialized_log = event_log_schema.dump(log_model)
+    def test_serialize(self, task_instance):
+        event_log_model = Log(event="TEST_EVENT", task_instance=task_instance)
+        event_log_model.dttm = self.default_time
+        deserialized_log = event_log_schema.dump(event_log_model)
         assert deserialized_log == {
             "event_log_id": event_log_model.id,
             "event": "TEST_EVENT",
             "dag_id": "TEST_DAG_ID",
             "task_id": "TEST_TASK_ID",
-            "execution_date": self.default_time,
+            "execution_date": self.default_time.isoformat(),
             "owner": 'airflow',
-            "when": self.default_time,
+            "when": self.default_time.isoformat(),
             "extra": None,
         }
 
 
 class TestEventLogCollection(TestEventLogSchemaBase):
-    @provide_session
-    def test_serialize(self, session):
-        event_log_model_1 = Log(event="TEST_EVENT_1", task_instance=self._create_task_instance())
-        event_log_model_2 = Log(event="TEST_EVENT_2", task_instance=self._create_task_instance())
+    def test_serialize(self, task_instance):
+        event_log_model_1 = Log(event="TEST_EVENT_1", task_instance=task_instance)
+        event_log_model_2 = Log(event="TEST_EVENT_2", task_instance=task_instance)
         event_logs = [event_log_model_1, event_log_model_2]
-        session.add_all(event_logs)
-        session.commit()
-        event_log_model_1.dttm = timezone.parse(self.default_time)
-        event_log_model_2.dttm = timezone.parse(self.default_time2)
+        event_log_model_1.dttm = self.default_time
+        event_log_model_2.dttm = self.default_time2
         instance = EventLogCollection(event_logs=event_logs, total_entries=2)
         deserialized_event_logs = event_log_collection_schema.dump(instance)
         assert deserialized_event_logs == {
@@ -90,9 +76,9 @@ class TestEventLogCollection(TestEventLogSchemaBase):
                     "event": "TEST_EVENT_1",
                     "dag_id": "TEST_DAG_ID",
                     "task_id": "TEST_TASK_ID",
-                    "execution_date": self.default_time,
+                    "execution_date": self.default_time.isoformat(),
                     "owner": 'airflow',
-                    "when": self.default_time,
+                    "when": self.default_time.isoformat(),
                     "extra": None,
                 },
                 {
@@ -100,9 +86,9 @@ class TestEventLogCollection(TestEventLogSchemaBase):
                     "event": "TEST_EVENT_2",
                     "dag_id": "TEST_DAG_ID",
                     "task_id": "TEST_TASK_ID",
-                    "execution_date": self.default_time,
+                    "execution_date": self.default_time.isoformat(),
                     "owner": 'airflow',
-                    "when": self.default_time2,
+                    "when": self.default_time2.isoformat(),
                     "extra": None,
                 },
             ],
diff --git a/tests/api_connexion/schemas/test_task_instance_schema.py b/tests/api_connexion/schemas/test_task_instance_schema.py
index 73895ae..883d936 100644
--- a/tests/api_connexion/schemas/test_task_instance_schema.py
+++ b/tests/api_connexion/schemas/test_task_instance_schema.py
@@ -27,25 +27,29 @@ from airflow.api_connexion.schemas.task_instance_schema import (
     set_task_instance_state_form,
     task_instance_schema,
 )
-from airflow.models import DAG, SlaMiss, TaskInstance as TI
+from airflow.models import SlaMiss, TaskInstance as TI
 from airflow.operators.dummy import DummyOperator
 from airflow.utils.platform import getuser
-from airflow.utils.session import create_session, provide_session
 from airflow.utils.state import State
 from airflow.utils.timezone import datetime
 
 
-class TestTaskInstanceSchema(unittest.TestCase):
-    def setUp(self):
+class TestTaskInstanceSchema:
+    @pytest.fixture(autouse=True)
+    def set_attrs(self, session, dag_maker):
         self.default_time = datetime(2020, 1, 1)
-        with DAG(dag_id="TEST_DAG_ID"):
+        with dag_maker(dag_id="TEST_DAG_ID", session=session):
             self.task = DummyOperator(task_id="TEST_TASK_ID", start_date=self.default_time)
 
+        self.dr = dag_maker.create_dagrun(execution_date=self.default_time)
+        session.flush()
+
         self.default_ti_init = {
-            "execution_date": self.default_time,
+            "run_id": None,
             "state": State.RUNNING,
         }
         self.default_ti_extras = {
+            "dag_run": self.dr,
             "start_date": self.default_time + dt.timedelta(days=1),
             "end_date": self.default_time + dt.timedelta(days=2),
             "pid": 100,
@@ -54,18 +58,14 @@ class TestTaskInstanceSchema(unittest.TestCase):
             "queue": "default_queue",
         }
 
-    def tearDown(self):
-        with create_session() as session:
-            session.query(TI).delete()
-            session.query(SlaMiss).delete()
+        yield
+
+        session.rollback()
 
-    @provide_session
     def test_task_instance_schema_without_sla(self, session):
         ti = TI(task=self.task, **self.default_ti_init)
         for key, value in self.default_ti_extras.items():
             setattr(ti, key, value)
-        session.add(ti)
-        session.commit()
         serialized_ti = task_instance_schema.dump((ti, None))
         expected_json = {
             "dag_id": "TEST_DAG_ID",
@@ -91,19 +91,17 @@ class TestTaskInstanceSchema(unittest.TestCase):
         }
         assert serialized_ti == expected_json
 
-    @provide_session
     def test_task_instance_schema_with_sla(self, session):
-        ti = TI(task=self.task, **self.default_ti_init)
-        for key, value in self.default_ti_extras.items():
-            setattr(ti, key, value)
         sla_miss = SlaMiss(
             task_id="TEST_TASK_ID",
             dag_id="TEST_DAG_ID",
             execution_date=self.default_time,
         )
-        session.add(ti)
         session.add(sla_miss)
-        session.commit()
+        session.flush()
+        ti = TI(task=self.task, **self.default_ti_init)
+        for key, value in self.default_ti_extras.items():
+            setattr(ti, key, value)
         serialized_ti = task_instance_schema.dump((ti, sla_miss))
         expected_json = {
             "dag_id": "TEST_DAG_ID",
@@ -177,19 +175,17 @@ class TestClearTaskInstanceFormSchema(unittest.TestCase):
             clear_task_instance_form.load(payload)
 
 
-class TestSetTaskInstanceStateFormSchema(unittest.TestCase):
-    def setUp(self) -> None:
-        super().setUp()
-        self.current_input = {
-            "dry_run": True,
-            "task_id": "print_the_context",
-            "execution_date": "2020-01-01T00:00:00+00:00",
-            "include_upstream": True,
-            "include_downstream": True,
-            "include_future": True,
-            "include_past": True,
-            "new_state": "failed",
-        }
+class TestSetTaskInstanceStateFormSchema:
+    current_input = {
+        "dry_run": True,
+        "task_id": "print_the_context",
+        "execution_date": "2020-01-01T00:00:00+00:00",
+        "include_upstream": True,
+        "include_downstream": True,
+        "include_future": True,
+        "include_past": True,
+        "new_state": "failed",
+    }
 
     def test_success(self):
         result = set_task_instance_state_form.load(self.current_input)
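
The schema tests above now build an unsaved TaskInstance with run_id=None, point it at the DagRun created by dag_maker, and dump it straight through the schema. A condensed sketch of that pattern:

    from airflow.api_connexion.schemas.task_instance_schema import task_instance_schema
    from airflow.models import TaskInstance
    from airflow.utils.state import State

    def dump_ti(task, dag_run):
        ti = TaskInstance(task=task, run_id=None, state=State.RUNNING)
        ti.dag_run = dag_run  # execution_date is resolved through the run
        return task_instance_schema.dump((ti, None))  # (ti, sla_miss) pair, as above
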
diff --git a/tests/cli/commands/test_dag_command.py b/tests/cli/commands/test_dag_command.py
index fd8404a..4bd7013 100644
--- a/tests/cli/commands/test_dag_command.py
+++ b/tests/cli/commands/test_dag_command.py
@@ -39,7 +39,7 @@ from tests.test_utils.db import clear_db_dags, clear_db_runs
 
 dag_folder_path = '/'.join(os.path.realpath(__file__).split('/')[:-1])
 
-DEFAULT_DATE = timezone.make_aware(datetime(2015, 1, 1))
+DEFAULT_DATE = timezone.make_aware(datetime(2015, 1, 1), timezone=timezone.utc)
 TEST_DAG_FOLDER = os.path.join(os.path.dirname(dag_folder_path), 'dags')
 TEST_DAG_ID = 'unit_tests'
 
@@ -357,7 +357,7 @@ class TestCliDags(unittest.TestCase):
         assert "airflow" in out
         assert "paused" in out
         assert "airflow/example_dags/example_complex.py" in out
-        assert "False" in out
+        assert "- dag_id:" in out
 
     def test_cli_list_dag_runs(self):
         dag_command.dag_trigger(
diff --git a/tests/cli/commands/test_task_command.py b/tests/cli/commands/test_task_command.py
index 9231e62..983416e 100644
--- a/tests/cli/commands/test_task_command.py
+++ b/tests/cli/commands/test_task_command.py
@@ -23,7 +23,7 @@ import os
 import re
 import unittest
 from contextlib import redirect_stdout
-from datetime import datetime, timedelta
+from datetime import datetime
 from unittest import mock
 
 import pytest
@@ -32,17 +32,17 @@ from parameterized import parameterized
 from airflow.cli import cli_parser
 from airflow.cli.commands import task_command
 from airflow.configuration import conf
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, DagRunNotFound
 from airflow.models import DagBag, DagRun, TaskInstance
 from airflow.utils import timezone
-from airflow.utils.cli import get_dag
+from airflow.utils.dates import days_ago
 from airflow.utils.session import create_session
 from airflow.utils.state import State
 from airflow.utils.types import DagRunType
 from tests.test_utils.config import conf_vars
-from tests.test_utils.db import clear_db_pools, clear_db_runs
+from tests.test_utils.db import clear_db_runs
 
-DEFAULT_DATE = timezone.make_aware(datetime(2016, 1, 1))
+DEFAULT_DATE = days_ago(1)
 ROOT_FOLDER = os.path.realpath(
     os.path.join(os.path.dirname(os.path.realpath(__file__)), os.pardir, os.pardir)
 )
@@ -58,13 +58,22 @@ def reset(dag_id):
 
 # TODO: Check if tests needs side effects - locally there's missing DAG
 class TestCliTasks(unittest.TestCase):
+    run_id = 'TEST_RUN_ID'
+    dag_id = 'example_python_operator'
+
     @classmethod
     def setUpClass(cls):
         cls.dagbag = DagBag(include_examples=True)
         cls.parser = cli_parser.get_parser()
         clear_db_runs()
 
-    def tearDown(self) -> None:
+        cls.dag = cls.dagbag.get_dag(cls.dag_id)
+        cls.dag_run = cls.dag.create_dagrun(
+            state=State.NONE, run_id=cls.run_id, run_type=DagRunType.MANUAL, execution_date=DEFAULT_DATE
+        )
+
+    @classmethod
+    def tearDownClass(cls) -> None:
         clear_db_runs()
 
     def test_cli_list_tasks(self):
@@ -89,76 +98,33 @@ class TestCliTasks(unittest.TestCase):
 
     def test_test_with_existing_dag_run(self):
         """Test the `airflow test` command"""
-        dag_id = 'example_python_operator'
-        run_id = 'TEST_RUN_ID'
         task_id = 'print_the_context'
-        dag = self.dagbag.get_dag(dag_id)
-
-        dag.create_dagrun(state=State.NONE, run_id=run_id, run_type=DagRunType.MANUAL, external_trigger=True)
 
-        args = self.parser.parse_args(["tasks", "test", dag_id, task_id, run_id])
+        args = self.parser.parse_args(["tasks", "test", self.dag_id, task_id, DEFAULT_DATE.isoformat()])
 
         with redirect_stdout(io.StringIO()) as stdout:
             task_command.task_test(args)
 
         # Check that prints, and log messages, are shown
-        assert f"Marking task as SUCCESS. dag_id={dag_id}, task_id={task_id}" in stdout.getvalue()
-
-    @mock.patch("airflow.cli.commands.task_command.LocalTaskJob")
-    def test_run_naive_taskinstance(self, mock_local_job):
-        """
-        Test that we can run naive (non-localized) task instances
-        """
-        naive_date = datetime(2016, 1, 1)
-        dag_id = 'test_run_ignores_all_dependencies'
-
-        dag = self.dagbag.get_dag('test_run_ignores_all_dependencies')
-
-        task0_id = 'test_run_dependent_task'
-        args0 = [
-            'tasks',
-            'run',
-            '--ignore-all-dependencies',
-            '--local',
-            dag_id,
-            task0_id,
-            naive_date.isoformat(),
-        ]
-
-        task_command.task_run(self.parser.parse_args(args0), dag=dag)
-        mock_local_job.assert_called_once_with(
-            task_instance=mock.ANY,
-            mark_success=False,
-            ignore_all_deps=True,
-            ignore_depends_on_past=False,
-            ignore_task_deps=False,
-            ignore_ti_state=False,
-            pickle_id=None,
-            pool=None,
-        )
+        assert f"Marking task as SUCCESS. dag_id={self.dag_id}, task_id={task_id}" in stdout.getvalue()
 
     @mock.patch("airflow.cli.commands.task_command.LocalTaskJob")
     def test_run_with_existing_dag_run_id(self, mock_local_job):
         """
         Test that we can run with existing dag_run_id
         """
-        dag_id = 'test_run_ignores_all_dependencies'
-
-        dag = self.dagbag.get_dag(dag_id)
-        task0_id = 'test_run_dependent_task'
-        run_id = 'TEST_RUN_ID'
-        dag.create_dagrun(state=State.NONE, run_id=run_id, run_type=DagRunType.MANUAL, external_trigger=True)
+        task0_id = self.dag.task_ids[0]
         args0 = [
             'tasks',
             'run',
             '--ignore-all-dependencies',
             '--local',
-            dag_id,
+            self.dag_id,
             task0_id,
-            run_id,
+            self.run_id,
         ]
 
-        task_command.task_run(self.parser.parse_args(args0), dag=dag)
+        task_command.task_run(self.parser.parse_args(args0), dag=self.dag)
         mock_local_job.assert_called_once_with(
             task_instance=mock.ANY,
             mark_success=False,
@@ -188,21 +154,8 @@ class TestCliTasks(unittest.TestCase):
             task0_id,
             run_id,
         ]
-        with self.assertRaises(AirflowException) as err:
+        with self.assertRaises(DagRunNotFound):
             task_command.task_run(self.parser.parse_args(args0), dag=dag)
-        assert str(err.exception) == f"DagRun with run_id: {run_id} not found"
-
-    def test_cli_test(self):
-        task_command.task_test(
-            self.parser.parse_args(
-                ['tasks', 'test', 'example_bash_operator', 'runme_0', DEFAULT_DATE.isoformat()]
-            )
-        )
-        task_command.task_test(
-            self.parser.parse_args(
-                ['tasks', 'test', 'example_bash_operator', 'runme_0', '--dry-run', DEFAULT_DATE.isoformat()]
-            )
-        )
 
     def test_cli_test_with_params(self):
         task_command.task_test(
@@ -251,13 +204,6 @@ class TestCliTasks(unittest.TestCase):
         assert 'foo=bar' in output
         assert 'AIRFLOW_TEST_MODE=True' in output
 
-    def test_cli_run(self):
-        task_command.task_run(
-            self.parser.parse_args(
-                ['tasks', 'run', 'example_bash_operator', 'runme_0', '--local', DEFAULT_DATE.isoformat()]
-            )
-        )
-
     @parameterized.expand(
         [
             ("--ignore-all-dependencies",),
@@ -307,7 +253,7 @@ class TestCliTasks(unittest.TestCase):
         """
         with redirect_stdout(io.StringIO()) as stdout:
             task_command.task_render(
-                self.parser.parse_args(['tasks', 'render', 'tutorial', 'templated', DEFAULT_DATE.isoformat()])
+                self.parser.parse_args(['tasks', 'render', 'tutorial', 'templated', '2016-01-01'])
             )
 
         output = stdout.getvalue()
@@ -326,7 +272,6 @@ class TestCliTasks(unittest.TestCase):
             AirflowException,
             match=re.escape("You cannot use the --pickle option when using DAG.cli() method."),
         ):
-            dag = self.dagbag.get_dag('test_run_ignores_all_dependencies')
             task_command.task_run(
                 self.parser.parse_args(
                     [
@@ -339,13 +284,13 @@ class TestCliTasks(unittest.TestCase):
                         pickle_id,
                     ]
                 ),
-                dag,
+                self.dag,
             )
 
     def test_task_state(self):
         task_command.task_state(
             self.parser.parse_args(
-                ['tasks', 'state', 'example_bash_operator', 'runme_0', DEFAULT_DATE.isoformat()]
+                ['tasks', 'state', self.dag_id, 'print_the_context', DEFAULT_DATE.isoformat()]
             )
         )
 
@@ -395,7 +340,7 @@ class TestCliTasks(unittest.TestCase):
         """
         task_states_for_dag_run should return an AirflowException when invalid dag id is passed
         """
-        with pytest.raises(AirflowException, match="DagRun does not exist."):
+        with pytest.raises(DagRunNotFound):
             default_date2 = timezone.make_aware(datetime(2016, 1, 9))
             task_command.task_states_for_dag_run(
                 self.parser.parse_args(
@@ -426,31 +371,6 @@ class TestCliTasks(unittest.TestCase):
         )
         task_command.task_clear(args)
 
-    @pytest.mark.quarantined
-    def test_local_run(self):
-        args = self.parser.parse_args(
-            [
-                'tasks',
-                'run',
-                'example_python_operator',
-                'print_the_context',
-                '2018-04-27T08:39:51.298439+00:00',
-                '--interactive',
-                '--subdir',
-                '/root/dags/example_python_operator.py',
-            ]
-        )
-
-        dag = get_dag(args.subdir, args.dag_id)
-        reset(dag.dag_id)
-
-        task_command.task_run(args)
-        task = dag.get_task(task_id=args.task_id)
-        ti = TaskInstance(task, args.execution_date)
-        ti.refresh_from_db()
-        state = ti.current_state()
-        assert state == State.SUCCESS
-
 
 # For this test memory spins out of control on Python 3.6. TODO(potiuk): FIXME")
 @pytest.mark.quarantined
@@ -650,66 +570,3 @@ class TestLogsfromTaskRunCommand(unittest.TestCase):
                 assert captured.output == ["WARNING:foo.bar:not redirected"]
 
         settings.DONOT_MODIFY_HANDLERS = old_value
-
-
-class TestCliTaskBackfill(unittest.TestCase):
-    @classmethod
-    def setUpClass(cls):
-        cls.dagbag = DagBag(include_examples=True)
-
-    def setUp(self):
-        clear_db_runs()
-        clear_db_pools()
-
-        self.parser = cli_parser.get_parser()
-
-    def test_run_ignores_all_dependencies(self):
-        """
-        Test that run respects ignore_all_dependencies
-        """
-        dag_id = 'test_run_ignores_all_dependencies'
-
-        dag = self.dagbag.get_dag('test_run_ignores_all_dependencies')
-        dag.clear()
-
-        task0_id = 'test_run_dependent_task'
-        args0 = ['tasks', 'run', '--ignore-all-dependencies', dag_id, task0_id, DEFAULT_DATE.isoformat()]
-        task_command.task_run(self.parser.parse_args(args0))
-        ti_dependent0 = TaskInstance(task=dag.get_task(task0_id), execution_date=DEFAULT_DATE)
-
-        ti_dependent0.refresh_from_db()
-        assert ti_dependent0.state == State.FAILED
-
-        task1_id = 'test_run_dependency_task'
-        args1 = [
-            'tasks',
-            'run',
-            '--ignore-all-dependencies',
-            dag_id,
-            task1_id,
-            (DEFAULT_DATE + timedelta(days=1)).isoformat(),
-        ]
-        task_command.task_run(self.parser.parse_args(args1))
-
-        ti_dependency = TaskInstance(
-            task=dag.get_task(task1_id), execution_date=DEFAULT_DATE + timedelta(days=1)
-        )
-        ti_dependency.refresh_from_db()
-        assert ti_dependency.state == State.FAILED
-
-        task2_id = 'test_run_dependent_task'
-        args2 = [
-            'tasks',
-            'run',
-            '--ignore-all-dependencies',
-            dag_id,
-            task2_id,
-            (DEFAULT_DATE + timedelta(days=1)).isoformat(),
-        ]
-        task_command.task_run(self.parser.parse_args(args2))
-
-        ti_dependent = TaskInstance(
-            task=dag.get_task(task2_id), execution_date=DEFAULT_DATE + timedelta(days=1)
-        )
-        ti_dependent.refresh_from_db()
-        assert ti_dependent.state == State.SUCCESS
diff --git a/tests/conftest.py b/tests/conftest.py
index 745ec9f..b18a472 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -14,10 +14,11 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import json
 import os
 import subprocess
 import sys
-from contextlib import ExitStack
+from contextlib import ExitStack, suppress
 from datetime import datetime, timedelta
 
 import freezegun
@@ -466,8 +467,8 @@ def dag_maker(request):
 
     want_serialized = False
 
-    # Allow changing default serialized behaviour with `@ptest.mark.need_serialized_dag` or
-    # `@ptest.mark.need_serialized_dag(False)`
+    # Allow changing default serialized behaviour with `@pytest.mark.need_serialized_dag` or
+    # `@pytest.mark.need_serialized_dag(False)`
     serialized_marker = request.node.get_closest_marker("need_serialized_dag")
     if serialized_marker:
         (want_serialized,) = serialized_marker.args or (True,)
@@ -488,6 +489,15 @@ def dag_maker(request):
         def _serialized_dag(self):
             return self.serialized_model.dag
 
+        def get_serialized_data(self):
+            try:
+                data = self.serialized_model.data
+            except AttributeError:
+                raise RuntimeError("DAG serialization not requested")
+            if isinstance(data, str):
+                return json.loads(data)
+            return data
+
         def __exit__(self, type, value, traceback):
             from airflow.models import DagModel
             from airflow.models.serialized_dag import SerializedDagModel
@@ -497,7 +507,7 @@ def dag_maker(request):
             if type is not None:
                 return
 
-            dag.clear()
+            dag.clear(session=self.session)
             dag.sync_to_db(self.session)
             self.dag_model = self.session.query(DagModel).get(dag.dag_id)
 
@@ -511,6 +521,7 @@ def dag_maker(request):
                 self.dagbag.bag_dag(self.dag, self.dag)
 
         def create_dagrun(self, **kwargs):
+            from airflow.timetables.base import DataInterval
             from airflow.utils.state import State
 
             dag = self.dag
@@ -525,7 +536,13 @@ def dag_maker(request):
             # explicitly, or pass run_type for inference in dag.create_dagrun().
             if "run_id" not in kwargs and "run_type" not in kwargs:
                 kwargs["run_id"] = "test"
+            # Fill in data_interval if not provided.
+            if not kwargs.get("data_interval"):
+                kwargs["data_interval"] = DataInterval.exact(kwargs["execution_date"])
+
             self.dag_run = dag.create_dagrun(**kwargs)
+            for ti in self.dag_run.task_instances:
+                ti.refresh_from_task(dag.get_task(ti.task_id))
             return self.dag_run
 
         def __call__(
@@ -587,7 +604,8 @@ def dag_maker(request):
         yield factory
     finally:
         factory.cleanup()
-        del factory.session
+        with suppress(AttributeError):
+            del factory.session
 
 
 @pytest.fixture
@@ -622,6 +640,7 @@ def create_dummy_dag(dag_maker):
         on_failure_callback=None,
         on_retry_callback=None,
         email=None,
+        with_dagrun=True,
         **kwargs,
     ):
         with dag_maker(dag_id, **kwargs) as dag:
@@ -637,7 +656,69 @@ def create_dummy_dag(dag_maker):
                 pool=pool,
                 trigger_rule=trigger_rule,
             )
-        dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED)
+        if with_dagrun:
+            dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED)
         return dag, op
 
     return create_dag
+
+
+@pytest.fixture
+def create_task_instance(dag_maker, create_dummy_dag):
+    """
+    Create a TaskInstance and its associated DB rows (DagRun, DagModel, etc.)
+
+    Uses ``create_dummy_dag`` to create the dag structure.
+    """
+
+    def maker(execution_date=None, dagrun_state=None, state=None, run_id='test', **kwargs):
+        if execution_date is None:
+            from airflow.utils import timezone
+
+            execution_date = timezone.utcnow()
+        create_dummy_dag(with_dagrun=False, **kwargs)
+
+        dr = dag_maker.create_dagrun(execution_date=execution_date, state=dagrun_state, run_id=run_id)
+        ti = dr.task_instances[0]
+        ti.state = state
+
+        return ti
+
+    return maker
+
+
+@pytest.fixture()
+def create_task_instance_of_operator(dag_maker):
+    def _create_task_instance(
+        operator_class,
+        *,
+        dag_id,
+        execution_date=None,
+        session=None,
+        **operator_kwargs,
+    ):
+        with dag_maker(dag_id=dag_id, session=session):
+            operator_class(**operator_kwargs)
+        (ti,) = dag_maker.create_dagrun(execution_date=execution_date).task_instances
+        return ti
+
+    return _create_task_instance
+
+
+@pytest.fixture()
+def create_task_of_operator(dag_maker):
+    def _create_task_of_operator(operator_class, *, dag_id, session=None, **operator_kwargs):
+        with dag_maker(dag_id=dag_id, session=session):
+            task = operator_class(**operator_kwargs)
+        return task
+
+    return _create_task_of_operator
+
+
+@pytest.fixture
+def session():
+    from airflow.utils.session import create_session
+
+    with create_session() as session:
+        yield session
+        session.rollback()
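
A minimal usage sketch (not from the patch) for the create_task_instance fixture defined
above. The dag id, date and state are illustrative, and the sketch assumes it runs inside
the Airflow test suite, so tests/conftest.py is collected and a metadata database is
initialised.

    from airflow.utils import timezone
    from airflow.utils.state import State


    def test_ti_is_keyed_to_a_dag_run(create_task_instance):
        # The fixture creates the DAG, DagModel, DagRun and TaskInstance rows in one call;
        # run_id defaults to "test" (see the fixture definition above).
        ti = create_task_instance(
            dag_id="example_fixture_usage",  # hypothetical dag id
            execution_date=timezone.datetime(2021, 9, 1),
            state=State.QUEUED,
        )
        # The task instance is identified by run_id rather than execution_date.
        assert ti.run_id == "test"
        assert ti.state == State.QUEUED
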
diff --git a/tests/core/test_core.py b/tests/core/test_core.py
index d8a5687..9131fbf 100644
--- a/tests/core/test_core.py
+++ b/tests/core/test_core.py
@@ -17,7 +17,6 @@
 # under the License.
 
 import logging
-import multiprocessing
 import os
 import signal
 from datetime import timedelta
@@ -29,13 +28,13 @@ import pytest
 from airflow import settings
 from airflow.exceptions import AirflowException, AirflowTaskTimeout
 from airflow.hooks.base import BaseHook
-from airflow.jobs.local_task_job import LocalTaskJob
 from airflow.models import DagBag, TaskFail, TaskInstance
 from airflow.models.baseoperator import BaseOperator
 from airflow.operators.bash import BashOperator
 from airflow.operators.check_operator import CheckOperator, ValueCheckOperator
 from airflow.operators.dummy import DummyOperator
 from airflow.operators.python import PythonOperator
+from airflow.utils.dates import days_ago
 from airflow.utils.state import State
 from airflow.utils.timezone import datetime
 from airflow.utils.types import DagRunType
@@ -108,10 +107,18 @@ class TestCore:
 
         captain_hook.run("drop table operator_test_table")
 
-    def test_clear_api(self):
+    def test_clear_api(self, session):
         task = self.dag_bash.tasks[0]
-        task.clear(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, upstream=True, downstream=True)
-        ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
+
+        dr = self.dag_bash.create_dagrun(
+            run_type=DagRunType.MANUAL,
+            state=State.RUNNING,
+            execution_date=days_ago(1),
+            session=session,
+        )
+        task.clear(start_date=dr.execution_date, end_date=dr.execution_date, upstream=True, downstream=True)
+        ti = dr.get_task_instance(task.task_id, session=session)
+        ti.task = task
         ti.are_dependents_done()
 
     def test_illegal_args(self, dag_maker):
@@ -268,15 +275,16 @@ class TestCore:
         dag_maker.create_dagrun()
         op.resolve_template_files()
 
-    def test_task_get_template(self):
-        ti = TaskInstance(task=self.runme_0, execution_date=DEFAULT_DATE)
-        ti.dag = self.dag_bash
-        self.dag_bash.create_dagrun(
+    def test_task_get_template(self, session):
+        dr = self.dag_bash.create_dagrun(
             run_type=DagRunType.MANUAL,
             state=State.RUNNING,
             execution_date=DEFAULT_DATE,
             data_interval=(DEFAULT_DATE, DEFAULT_DATE + timedelta(days=1)),
+            session=session,
         )
+        ti = TaskInstance(task=self.runme_0, run_id=dr.run_id)
+        ti.dag = self.dag_bash
         ti.run(ignore_ti_state=True)
         context = ti.get_template_context()
 
@@ -314,64 +322,12 @@ class TestCore:
             assert value == expected_value
             assert [str(m.message) for m in recorder] == [message]
 
-    def test_local_task_job(self):
-        TI = TaskInstance
-        ti = TI(task=self.runme_0, execution_date=DEFAULT_DATE)
-        job = LocalTaskJob(task_instance=ti, ignore_ti_state=True)
-        job.run()
-
-    def test_raw_job(self):
-        TI = TaskInstance
-        ti = TI(task=self.runme_0, execution_date=DEFAULT_DATE)
-        ti.dag = self.dag_bash
-        self.dag_bash.create_dagrun(
-            run_type=DagRunType.MANUAL, state=State.RUNNING, execution_date=DEFAULT_DATE
-        )
-        ti.run(ignore_ti_state=True)
-
     def test_bad_trigger_rule(self, dag_maker):
         with pytest.raises(AirflowException):
             with dag_maker():
                 DummyOperator(task_id='test_bad_trigger', trigger_rule="non_existent")
             dag_maker.create_dagrun()
 
-    def test_terminate_task(self):
-        """If a task instance's db state get deleted, it should fail"""
-        from airflow.executors.sequential_executor import SequentialExecutor
-
-        TI = TaskInstance
-        dag = self.dagbag.dags.get('test_utils')
-        task = dag.task_dict.get('sleeps_forever')
-
-        ti = TI(task=task, execution_date=DEFAULT_DATE)
-        job = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor())
-
-        # Running task instance asynchronously
-        proc = multiprocessing.Process(target=job.run)
-        proc.start()
-        sleep(5)
-        settings.engine.dispose()
-        session = settings.Session()
-        ti.refresh_from_db(session=session)
-        # making sure it's actually running
-        assert State.RUNNING == ti.state
-        ti = (
-            session.query(TI)
-            .filter_by(dag_id=task.dag_id, task_id=task.task_id, execution_date=DEFAULT_DATE)
-            .one()
-        )
-
-        # deleting the instance should result in a failure
-        session.delete(ti)
-        session.commit()
-        # waiting for the async task to finish
-        proc.join()
-
-        # making sure that the task ended up as failed
-        ti.refresh_from_db(session=session)
-        assert State.FAILED == ti.state
-        session.close()
-
     def test_task_fail_duration(self, dag_maker):
         """If a task fails, the duration should be recorded in TaskFail"""
         with dag_maker() as dag:
@@ -382,6 +338,7 @@ class TestCore:
                 execution_timeout=timedelta(seconds=3),
                 retry_delay=timedelta(seconds=0),
             )
+        dag_maker.create_dagrun()
         session = settings.Session()
         try:
             op1.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
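
The construction pattern the reworked core tests follow, shown as a standalone sketch
(not from the patch): a DagRun is created first and the TaskInstance is keyed to its
run_id, replacing the old TaskInstance(task, execution_date=...) form. Names and dates
are illustrative, and an initialised Airflow metadata database is assumed.

    from datetime import timedelta

    from airflow.models import DAG, TaskInstance
    from airflow.operators.dummy import DummyOperator
    from airflow.utils.state import State
    from airflow.utils.timezone import datetime
    from airflow.utils.types import DagRunType

    START = datetime(2016, 1, 1)
    with DAG("example_run_id_keyed_ti", start_date=START, schedule_interval="@daily") as dag:
        task = DummyOperator(task_id="do_nothing")

    # The DagRun owns the execution_date and data_interval ...
    dag_run = dag.create_dagrun(
        run_type=DagRunType.MANUAL,
        state=State.RUNNING,
        execution_date=START,
        data_interval=(START, START + timedelta(days=1)),
    )
    # ... and the TaskInstance only needs the run_id to identify its run.
    ti = TaskInstance(task=task, run_id=dag_run.run_id)
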
diff --git a/tests/core/test_sentry.py b/tests/core/test_sentry.py
index 44a39d9..5a5fa56 100644
--- a/tests/core/test_sentry.py
+++ b/tests/core/test_sentry.py
@@ -18,14 +18,12 @@
 
 import datetime
 import importlib
-import unittest
-from unittest.mock import MagicMock, Mock
 
+import pytest
 from freezegun import freeze_time
 from sentry_sdk import configure_scope
 
-from airflow.models import TaskInstance
-from airflow.settings import Session
+from airflow.operators.python import PythonOperator
 from airflow.utils import timezone
 from airflow.utils.state import State
 from tests.test_utils.config import conf_vars
@@ -33,7 +31,7 @@ from tests.test_utils.config import conf_vars
 EXECUTION_DATE = timezone.utcnow()
 DAG_ID = "test_dag"
 TASK_ID = "test_task"
-OPERATOR = "test_operator"
+OPERATOR = "PythonOperator"
 TRY_NUMBER = 1
 STATE = State.SUCCESS
 TEST_SCOPE = {
@@ -60,46 +58,49 @@ CRUMB = {
 }
 
 
-class TestSentryHook(unittest.TestCase):
-    @conf_vars({('sentry', 'sentry_on'): 'True'})
-    def setUp(self):
-        from airflow import sentry
+class TestSentryHook:
+    @pytest.fixture
+    def task_instance(self, dag_maker):
+        # Build a real DAG and task via dag_maker (previously these were mocked)
+        with dag_maker(DAG_ID):
+            task = PythonOperator(task_id=TASK_ID, python_callable=int)
 
-        importlib.reload(sentry)
-        self.sentry = sentry.ConfiguredSentry()
+        dr = dag_maker.create_dagrun(execution_date=EXECUTION_DATE)
+        ti = dr.task_instances[0]
+        ti.state = STATE
+        ti.task = task
+        dag_maker.session.flush()
 
-        # Mock the Dag
-        self.dag = Mock(dag_id=DAG_ID, params=[])
-        self.dag.task_ids = [TASK_ID]
+        yield ti
 
-        # Mock the task
-        self.task = Mock(dag=self.dag, dag_id=DAG_ID, task_id=TASK_ID, params=[], pool_slots=1)
-        self.task.__class__.__name__ = OPERATOR
+        dag_maker.session.rollback()
 
-        self.ti = TaskInstance(self.task, execution_date=EXECUTION_DATE)
-        self.ti.operator = OPERATOR
-        self.ti.state = STATE
+    @pytest.fixture
+    def sentry(self):
+        with conf_vars({('sentry', 'sentry_on'): 'True'}):
+            from airflow import sentry
 
-        self.dag.get_task_instances = MagicMock(return_value=[self.ti])
+            importlib.reload(sentry)
+            yield sentry.Sentry
 
-        self.session = Session()
+        importlib.reload(sentry)
 
-    def test_add_tagging(self):
+    def test_add_tagging(self, sentry, task_instance):
         """
         Test adding tags.
         """
-        self.sentry.add_tagging(task_instance=self.ti)
+        sentry.add_tagging(task_instance=task_instance)
         with configure_scope() as scope:
             for key, value in scope._tags.items():
                 assert TEST_SCOPE[key] == value
 
     @freeze_time(CRUMB_DATE.isoformat())
-    def test_add_breadcrumbs(self):
+    def test_add_breadcrumbs(self, sentry, task_instance):
         """
         Test adding breadcrumbs.
         """
-        self.sentry.add_tagging(task_instance=self.ti)
-        self.sentry.add_breadcrumbs(task_instance=self.ti, session=self.session)
+        sentry.add_tagging(task_instance=task_instance)
+        sentry.add_breadcrumbs(task_instance=task_instance)
 
         with configure_scope() as scope:
             test_crumb = scope._breadcrumbs.pop()
diff --git a/tests/dag_processing/test_manager.py b/tests/dag_processing/test_manager.py
index ac2d4dc..2c62939 100644
--- a/tests/dag_processing/test_manager.py
+++ b/tests/dag_processing/test_manager.py
@@ -51,7 +51,8 @@ from airflow.utils import timezone
 from airflow.utils.callback_requests import CallbackRequest, TaskCallbackRequest
 from airflow.utils.net import get_hostname
 from airflow.utils.session import create_session
-from airflow.utils.state import State
+from airflow.utils.state import DagRunState, State
+from airflow.utils.types import DagRunType
 from tests.core.test_logging_config import SETTINGS_FILE_VALID, settings_context
 from tests.test_utils.config import conf_vars
 from tests.test_utils.db import clear_db_dags, clear_db_runs, clear_db_serialized_dags
@@ -110,6 +111,9 @@ class TestDagFileProcessorManager:
     def setup_method(self):
         clear_db_runs()
 
+    def teardown_class(self):
+        clear_db_runs()
+
     def run_processor_manager_one_loop(self, manager, parent_pipe):
         if not manager._async_mode:
             parent_pipe.send(DagParsingSignal.AGENT_RUN_ONCE)
@@ -432,16 +436,23 @@ class TestDagFileProcessorManager:
             dag.sync_to_db()
             task = dag.get_task(task_id='run_this_first')
 
-            ti = TI(task, DEFAULT_DATE, State.RUNNING)
+            dag_run = dag.create_dagrun(
+                state=DagRunState.RUNNING,
+                execution_date=DEFAULT_DATE,
+                run_type=DagRunType.SCHEDULED,
+                session=session,
+            )
+
+            ti = TI(task, run_id=dag_run.run_id, state=State.RUNNING)
             local_job = LJ(ti)
             local_job.state = State.SHUTDOWN
 
             session.add(local_job)
-            session.commit()
+            session.flush()
 
             ti.job_id = local_job.id
             session.add(ti)
-            session.commit()
+            session.flush()
 
             manager._last_zombie_query_time = timezone.utcnow() - timedelta(
                 seconds=manager._zombie_threshold_secs + 1
@@ -455,7 +466,7 @@ class TestDagFileProcessorManager:
             assert isinstance(requests[0].simple_task_instance, SimpleTaskInstance)
             assert ti.dag_id == requests[0].simple_task_instance.dag_id
             assert ti.task_id == requests[0].simple_task_instance.task_id
-            assert ti.execution_date == requests[0].simple_task_instance.execution_date
+            assert ti.run_id == requests[0].simple_task_instance.run_id
 
             session.query(TI).delete()
             session.query(LJ).delete()
@@ -475,19 +486,26 @@ class TestDagFileProcessorManager:
                 session.query(LJ).delete()
                 dag = dagbag.get_dag('test_example_bash_operator')
                 dag.sync_to_db()
+
+                dag_run = dag.create_dagrun(
+                    state=DagRunState.RUNNING,
+                    execution_date=DEFAULT_DATE,
+                    run_type=DagRunType.SCHEDULED,
+                    session=session,
+                )
                 task = dag.get_task(task_id='run_this_last')
 
-                ti = TI(task, DEFAULT_DATE, State.RUNNING)
+                ti = TI(task, run_id=dag_run.run_id, state=State.RUNNING)
                 local_job = LJ(ti)
                 local_job.state = State.SHUTDOWN
                 session.add(local_job)
-                session.commit()
+                session.flush()
 
                 # TODO: If there was an actual Relationship between TI and Job
                 # we wouldn't need this extra commit
                 session.add(ti)
                 ti.job_id = local_job.id
-                session.commit()
+                session.flush()
 
                 expected_failure_callback_requests = [
                     TaskCallbackRequest(
diff --git a/tests/dag_processing/test_processor.py b/tests/dag_processing/test_processor.py
index 54f0e9e..43d4b86 100644
--- a/tests/dag_processing/test_processor.py
+++ b/tests/dag_processing/test_processor.py
@@ -19,7 +19,6 @@
 
 import datetime
 import os
-from tempfile import NamedTemporaryFile
 from unittest import mock
 from unittest.mock import MagicMock, patch
 from zipfile import ZipFile
@@ -37,6 +36,7 @@ from airflow.utils.callback_requests import TaskCallbackRequest
 from airflow.utils.dates import days_ago
 from airflow.utils.session import create_session
 from airflow.utils.state import State
+from airflow.utils.types import DagRunType
 from tests.test_utils.config import conf_vars, env_vars
 from tests.test_utils.db import (
     clear_db_dags,
@@ -82,9 +82,10 @@ class TestDagFileProcessor:
         clear_db_jobs()
         clear_db_serialized_dags()
 
-    def setup_method(self):
+    def setup_class(self):
         self.clean_db()
 
+    def setup_method(self):
         # Speed up some tests by not running the tasks, just look at what we
         # enqueue!
         self.null_exec = MockExecutor()
@@ -329,23 +330,24 @@ class TestDagFileProcessor:
         with create_session() as session:
             session.query(TaskInstance).delete()
             dag = dagbag.get_dag('example_branch_operator')
+            dagrun = dag.create_dagrun(
+                state=State.RUNNING,
+                execution_date=DEFAULT_DATE,
+                run_type=DagRunType.SCHEDULED,
+                session=session,
+            )
             task = dag.get_task(task_id='run_this_first')
-
-            ti = TaskInstance(task, DEFAULT_DATE, State.RUNNING)
-
+            ti = TaskInstance(task, run_id=dagrun.run_id, state=State.RUNNING)
             session.add(ti)
-            session.commit()
 
-            requests = [
-                TaskCallbackRequest(
-                    full_filepath="A", simple_task_instance=SimpleTaskInstance(ti), msg="Message"
-                )
-            ]
-            dag_file_processor.execute_callbacks(dagbag, requests)
-            mock_ti_handle_failure.assert_called_once_with(
-                error="Message",
-                test_mode=conf.getboolean('core', 'unit_test_mode'),
-            )
+        requests = [
+            TaskCallbackRequest(full_filepath="A", simple_task_instance=SimpleTaskInstance(ti), msg="Message")
+        ]
+        dag_file_processor.execute_callbacks(dagbag, requests)
+        mock_ti_handle_failure.assert_called_once_with(
+            error="Message",
+            test_mode=conf.getboolean('core', 'unit_test_mode'),
+        )
 
     def test_failure_callbacks_should_not_drop_hostname(self):
         dagbag = DagBag(dag_folder="/dev/null", include_examples=True, read_dags_from_db=False)
@@ -355,36 +357,46 @@ class TestDagFileProcessor:
         with create_session() as session:
             dag = dagbag.get_dag('example_branch_operator')
             task = dag.get_task(task_id='run_this_first')
-
-            ti = TaskInstance(task, DEFAULT_DATE, State.RUNNING)
+            dagrun = dag.create_dagrun(
+                state=State.RUNNING,
+                execution_date=DEFAULT_DATE,
+                run_type=DagRunType.SCHEDULED,
+                session=session,
+            )
+            ti = TaskInstance(task, run_id=dagrun.run_id, state=State.RUNNING)
             ti.hostname = "test_hostname"
             session.add(ti)
 
+        requests = [
+            TaskCallbackRequest(full_filepath="A", simple_task_instance=SimpleTaskInstance(ti), msg="Message")
+        ]
+        dag_file_processor.execute_callbacks(dagbag, requests)
+
         with create_session() as session:
-            requests = [
-                TaskCallbackRequest(
-                    full_filepath="A", simple_task_instance=SimpleTaskInstance(ti), msg="Message"
-                )
-            ]
-            dag_file_processor.execute_callbacks(dagbag, requests)
             tis = session.query(TaskInstance)
             assert tis[0].hostname == "test_hostname"
 
-    def test_process_file_should_failure_callback(self):
+    def test_process_file_should_failure_callback(self, monkeypatch, tmp_path):
+        callback_file = tmp_path.joinpath("callback.txt")
+        callback_file.touch()
+        monkeypatch.setenv("AIRFLOW_CALLBACK_FILE", str(callback_file))
         dag_file = os.path.join(
             os.path.dirname(os.path.realpath(__file__)), '../dags/test_on_failure_callback.py'
         )
         dagbag = DagBag(dag_folder=dag_file, include_examples=False)
         dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock())
-        with create_session() as session, NamedTemporaryFile(delete=False) as callback_file:
-            session.query(TaskInstance).delete()
-            dag = dagbag.get_dag('test_om_failure_callback_dag')
-            task = dag.get_task(task_id='test_om_failure_callback_task')
-
-            ti = TaskInstance(task, DEFAULT_DATE, State.RUNNING)
 
-            session.add(ti)
-            session.commit()
+        dag = dagbag.get_dag('test_om_failure_callback_dag')
+        task = dag.get_task(task_id='test_om_failure_callback_task')
+        with create_session() as session:
+            dagrun = dag.create_dagrun(
+                state=State.RUNNING,
+                execution_date=DEFAULT_DATE,
+                run_type=DagRunType.SCHEDULED,
+                session=session,
+            )
+            (ti,) = dagrun.task_instances
+            ti.refresh_from_task(task)
 
             requests = [
                 TaskCallbackRequest(
@@ -393,14 +405,9 @@ class TestDagFileProcessor:
                     msg="Message",
                 )
             ]
-            callback_file.close()
-
-            with mock.patch.dict("os.environ", {"AIRFLOW_CALLBACK_FILE": callback_file.name}):
-                dag_file_processor.process_file(dag_file, requests)
-            with open(callback_file.name) as callback_file2:
-                content = callback_file2.read()
-            assert "Callback fired" == content
-            os.remove(callback_file.name)
+            dag_file_processor.process_file(dag_file, requests, session=session)
+
+        assert "Callback fired" == callback_file.read_text()
 
     @conf_vars({("core", "dagbag_import_error_tracebacks"): "False"})
     def test_add_unparseable_file_before_sched_start_creates_import_error(self, tmpdir):
diff --git a/tests/executors/test_base_executor.py b/tests/executors/test_base_executor.py
index 2251878..561d551 100644
--- a/tests/executors/test_base_executor.py
+++ b/tests/executors/test_base_executor.py
@@ -15,60 +15,59 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
-import unittest
-from datetime import datetime, timedelta
+from datetime import timedelta
 from unittest import mock
 
 from airflow.executors.base_executor import BaseExecutor
 from airflow.models.baseoperator import BaseOperator
-from airflow.models.dag import DAG
-from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
+from airflow.models.taskinstance import TaskInstanceKey
+from airflow.utils import timezone
 from airflow.utils.state import State
 
 
-class TestBaseExecutor(unittest.TestCase):
-    def test_get_event_buffer(self):
-        executor = BaseExecutor()
+def test_get_event_buffer():
+    executor = BaseExecutor()
+
+    date = timezone.utcnow()
+    try_number = 1
+    key1 = TaskInstanceKey("my_dag1", "my_task1", date, try_number)
+    key2 = TaskInstanceKey("my_dag2", "my_task1", date, try_number)
+    key3 = TaskInstanceKey("my_dag2", "my_task2", date, try_number)
+    state = State.SUCCESS
+    executor.event_buffer[key1] = state, None
+    executor.event_buffer[key2] = state, None
+    executor.event_buffer[key3] = state, None
+
+    assert len(executor.get_event_buffer(("my_dag1",))) == 1
+    assert len(executor.get_event_buffer()) == 2
+    assert len(executor.event_buffer) == 0
+
 
-        date = datetime.utcnow()
-        try_number = 1
-        key1 = TaskInstanceKey("my_dag1", "my_task1", date, try_number)
-        key2 = TaskInstanceKey("my_dag2", "my_task1", date, try_number)
-        key3 = TaskInstanceKey("my_dag2", "my_task2", date, try_number)
-        state = State.SUCCESS
-        executor.event_buffer[key1] = state, None
-        executor.event_buffer[key2] = state, None
-        executor.event_buffer[key3] = state, None
+@mock.patch('airflow.executors.base_executor.BaseExecutor.sync')
+@mock.patch('airflow.executors.base_executor.BaseExecutor.trigger_tasks')
+@mock.patch('airflow.executors.base_executor.Stats.gauge')
+def test_gauge_executor_metrics(mock_stats_gauge, mock_trigger_tasks, mock_sync):
+    executor = BaseExecutor()
+    executor.heartbeat()
+    calls = [
+        mock.call('executor.open_slots', mock.ANY),
+        mock.call('executor.queued_tasks', mock.ANY),
+        mock.call('executor.running_tasks', mock.ANY),
+    ]
+    mock_stats_gauge.assert_has_calls(calls)
 
-        assert len(executor.get_event_buffer(("my_dag1",))) == 1
-        assert len(executor.get_event_buffer()) == 2
-        assert len(executor.event_buffer) == 0
 
-    @mock.patch('airflow.executors.base_executor.BaseExecutor.sync')
-    @mock.patch('airflow.executors.base_executor.BaseExecutor.trigger_tasks')
-    @mock.patch('airflow.executors.base_executor.Stats.gauge')
-    def test_gauge_executor_metrics(self, mock_stats_gauge, mock_trigger_tasks, mock_sync):
-        executor = BaseExecutor()
-        executor.heartbeat()
-        calls = [
-            mock.call('executor.open_slots', mock.ANY),
-            mock.call('executor.queued_tasks', mock.ANY),
-            mock.call('executor.running_tasks', mock.ANY),
-        ]
-        mock_stats_gauge.assert_has_calls(calls)
+def test_try_adopt_task_instances(dag_maker):
+    date = timezone.utcnow()
+    start_date = date - timedelta(days=2)
 
-    def test_try_adopt_task_instances(self):
-        date = datetime.utcnow()
-        start_date = datetime.utcnow() - timedelta(days=2)
+    with dag_maker("test_try_adopt_task_instances"):
+        BaseOperator(task_id="task_1", start_date=start_date)
+        BaseOperator(task_id="task_2", start_date=start_date)
+        BaseOperator(task_id="task_3", start_date=start_date)
 
-        with DAG("test_try_adopt_task_instances"):
-            task_1 = BaseOperator(task_id="task_1", start_date=start_date)
-            task_2 = BaseOperator(task_id="task_2", start_date=start_date)
-            task_3 = BaseOperator(task_id="task_3", start_date=start_date)
+    dagrun = dag_maker.create_dagrun(execution_date=date)
+    tis = dagrun.task_instances
 
-        key1 = TaskInstance(task=task_1, execution_date=date)
-        key2 = TaskInstance(task=task_2, execution_date=date)
-        key3 = TaskInstance(task=task_3, execution_date=date)
-        tis = [key1, key2, key3]
-        assert BaseExecutor().try_adopt_task_instances(tis) == tis
+    assert [ti.task_id for ti in tis] == ["task_1", "task_2", "task_3"]
+    assert BaseExecutor().try_adopt_task_instances(tis) == tis
diff --git a/tests/executors/test_celery_executor.py b/tests/executors/test_celery_executor.py
index 23a338b..498c8ce 100644
--- a/tests/executors/test_celery_executor.py
+++ b/tests/executors/test_celery_executor.py
@@ -184,7 +184,7 @@ class TestCeleryExecutor(unittest.TestCase):
                 'command',
                 1,
                 None,
-                SimpleTaskInstance(ti=TaskInstance(task=task, execution_date=datetime.now())),
+                SimpleTaskInstance(ti=TaskInstance(task=task, run_id=None)),
             )
             key = ('fail', 'fake_simple_ti', when, 0)
             executor.queued_tasks[key] = value_tuple
@@ -217,7 +217,7 @@ class TestCeleryExecutor(unittest.TestCase):
                 'command',
                 1,
                 None,
-                SimpleTaskInstance(ti=TaskInstance(task=task, execution_date=datetime.now())),
+                SimpleTaskInstance(ti=TaskInstance(task=task, run_id=None)),
             )
             key = ('fail', 'fake_simple_ti', when, 0)
             executor.queued_tasks[key] = value_tuple
@@ -300,13 +300,12 @@ class TestCeleryExecutor(unittest.TestCase):
 
     @pytest.mark.backend("mysql", "postgres")
     def test_try_adopt_task_instances_none(self):
-        date = datetime.utcnow()
         start_date = datetime.utcnow() - timedelta(days=2)
 
         with DAG("test_try_adopt_task_instances_none"):
             task_1 = BaseOperator(task_id="task_1", start_date=start_date)
 
-        key1 = TaskInstance(task=task_1, execution_date=date)
+        key1 = TaskInstance(task=task_1, run_id=None)
         tis = [key1]
         executor = celery_executor.CeleryExecutor()
 
@@ -314,7 +313,6 @@ class TestCeleryExecutor(unittest.TestCase):
 
     @pytest.mark.backend("mysql", "postgres")
     def test_try_adopt_task_instances(self):
-        exec_date = timezone.utcnow() - timedelta(minutes=2)
         start_date = timezone.utcnow() - timedelta(days=2)
         queued_dttm = timezone.utcnow() - timedelta(minutes=1)
 
@@ -324,11 +322,11 @@ class TestCeleryExecutor(unittest.TestCase):
             task_1 = BaseOperator(task_id="task_1", start_date=start_date)
             task_2 = BaseOperator(task_id="task_2", start_date=start_date)
 
-        ti1 = TaskInstance(task=task_1, execution_date=exec_date)
+        ti1 = TaskInstance(task=task_1, run_id=None)
         ti1.external_executor_id = '231'
         ti1.queued_dttm = queued_dttm
         ti1.state = State.QUEUED
-        ti2 = TaskInstance(task=task_2, execution_date=exec_date)
+        ti2 = TaskInstance(task=task_2, run_id=None)
         ti2.external_executor_id = '232'
         ti2.queued_dttm = queued_dttm
         ti2.state = State.QUEUED
@@ -341,8 +339,8 @@ class TestCeleryExecutor(unittest.TestCase):
 
         not_adopted_tis = executor.try_adopt_task_instances(tis)
 
-        key_1 = TaskInstanceKey(dag.dag_id, task_1.task_id, exec_date, try_number)
-        key_2 = TaskInstanceKey(dag.dag_id, task_2.task_id, exec_date, try_number)
+        key_1 = TaskInstanceKey(dag.dag_id, task_1.task_id, None, try_number)
+        key_2 = TaskInstanceKey(dag.dag_id, task_2.task_id, None, try_number)
         assert executor.running == {key_1, key_2}
         assert dict(executor.adopted_task_timeouts) == {
             key_1: queued_dttm + executor.task_adoption_timeout,
@@ -353,7 +351,6 @@ class TestCeleryExecutor(unittest.TestCase):
 
     @pytest.mark.backend("mysql", "postgres")
     def test_check_for_stalled_adopted_tasks(self):
-        exec_date = timezone.utcnow() - timedelta(minutes=40)
         start_date = timezone.utcnow() - timedelta(days=2)
         queued_dttm = timezone.utcnow() - timedelta(minutes=30)
 
@@ -363,8 +360,8 @@ class TestCeleryExecutor(unittest.TestCase):
             task_1 = BaseOperator(task_id="task_1", start_date=start_date)
             task_2 = BaseOperator(task_id="task_2", start_date=start_date)
 
-        key_1 = TaskInstanceKey(dag.dag_id, task_1.task_id, exec_date, try_number)
-        key_2 = TaskInstanceKey(dag.dag_id, task_2.task_id, exec_date, try_number)
+        key_1 = TaskInstanceKey(dag.dag_id, task_1.task_id, "runid", try_number)
+        key_2 = TaskInstanceKey(dag.dag_id, task_2.task_id, "runid", try_number)
 
         executor = celery_executor.CeleryExecutor()
         executor.adopted_task_timeouts = {
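
For reference (not from the patch): TaskInstanceKey is a plain named tuple, so the key
change the executor tests assert on can be shown without a database. Its third field is
now a run_id string instead of an execution_date; the values below are illustrative.

    from airflow.models.taskinstance import TaskInstanceKey

    # (dag_id, task_id, run_id, try_number) -- previously the third field was a datetime.
    key = TaskInstanceKey("my_dag", "task_1", "manual__2021-09-01T00:00:00+00:00", 1)
    assert key.run_id == "manual__2021-09-01T00:00:00+00:00"
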
diff --git a/tests/executors/test_kubernetes_executor.py b/tests/executors/test_kubernetes_executor.py
index e48d6ce..025b956 100644
--- a/tests/executors/test_kubernetes_executor.py
+++ b/tests/executors/test_kubernetes_executor.py
@@ -42,7 +42,7 @@ try:
     )
     from airflow.kubernetes import pod_generator
     from airflow.kubernetes.kubernetes_helper_functions import annotations_to_key
-    from airflow.kubernetes.pod_generator import PodGenerator, datetime_to_label_safe_datestring
+    from airflow.kubernetes.pod_generator import PodGenerator
     from airflow.utils.state import State
 except ImportError:
     AirflowKubernetesScheduler = None  # type: ignore
@@ -226,7 +226,7 @@ class TestKubernetesExecutor(unittest.TestCase):
             # Execute a task while the Api Throws errors
             try_number = 1
             kubernetes_executor.execute_async(
-                key=('dag', 'task', datetime.utcnow(), try_number),
+                key=('dag', 'task', 'run_id', try_number),
                 queue=None,
                 command=['airflow', 'tasks', 'run', 'true', 'some_parameter'],
             )
@@ -298,10 +298,8 @@ class TestKubernetesExecutor(unittest.TestCase):
             assert executor.event_buffer == {}
             assert executor.task_queue.empty()
 
-            execution_date = datetime.utcnow()
-
             executor.execute_async(
-                key=('dag', 'task', execution_date, 1),
+                key=('dag', 'task', 'run_id', 1),
                 queue=None,
                 command=['airflow', 'tasks', 'run', 'true', 'some_parameter'],
                 executor_config={
@@ -333,7 +331,7 @@ class TestKubernetesExecutor(unittest.TestCase):
                         namespace="default",
                         annotations={
                             'dag_id': 'dag',
-                            'execution_date': execution_date.isoformat(),
+                            'run_id': 'run_id',
                             'task_id': 'task',
                             'try_number': '1',
                         },
@@ -341,7 +339,7 @@ class TestKubernetesExecutor(unittest.TestCase):
                             'airflow-worker': '5',
                             'airflow_version': mock.ANY,
                             'dag_id': 'dag',
-                            'execution_date': datetime_to_label_safe_datestring(execution_date),
+                            'run_id': 'run_id',
                             'kubernetes_executor': 'True',
                             'mylabel': 'foo',
                             'release': 'stable',
@@ -370,7 +368,7 @@ class TestKubernetesExecutor(unittest.TestCase):
     def test_change_state_running(self, mock_get_kube_client, mock_kubernetes_job_watcher):
         executor = self.kubernetes_executor
         executor.start()
-        key = ('dag_id', 'task_id', 'ex_time', 'try_number1')
+        key = ('dag_id', 'task_id', 'run_id', 'try_number1')
         executor._change_state(key, State.RUNNING, 'pod_id', 'default')
         assert executor.event_buffer[key][0] == State.RUNNING
 
@@ -380,8 +378,7 @@ class TestKubernetesExecutor(unittest.TestCase):
     def test_change_state_success(self, mock_delete_pod, mock_get_kube_client, mock_kubernetes_job_watcher):
         executor = self.kubernetes_executor
         executor.start()
-        test_time = timezone.utcnow()
-        key = ('dag_id', 'task_id', test_time, 'try_number2')
+        key = ('dag_id', 'task_id', 'run_id', 'try_number2')
         executor._change_state(key, State.SUCCESS, 'pod_id', 'default')
         assert executor.event_buffer[key][0] == State.SUCCESS
         mock_delete_pod.assert_called_once_with('pod_id', 'default')
@@ -396,8 +393,7 @@ class TestKubernetesExecutor(unittest.TestCase):
         executor.kube_config.delete_worker_pods = False
         executor.kube_config.delete_worker_pods_on_failure = False
         executor.start()
-        test_time = timezone.utcnow()
-        key = ('dag_id', 'task_id', test_time, 'try_number3')
+        key = ('dag_id', 'task_id', 'run_id', 'try_number3')
         executor._change_state(key, State.FAILED, 'pod_id', 'default')
         assert executor.event_buffer[key][0] == State.FAILED
         mock_delete_pod.assert_not_called()
@@ -408,13 +404,12 @@ class TestKubernetesExecutor(unittest.TestCase):
     def test_change_state_skip_pod_deletion(
         self, mock_delete_pod, mock_get_kube_client, mock_kubernetes_job_watcher
     ):
-        test_time = timezone.utcnow()
         executor = self.kubernetes_executor
         executor.kube_config.delete_worker_pods = False
         executor.kube_config.delete_worker_pods_on_failure = False
 
         executor.start()
-        key = ('dag_id', 'task_id', test_time, 'try_number2')
+        key = ('dag_id', 'task_id', 'run_id', 'try_number2')
         executor._change_state(key, State.SUCCESS, 'pod_id', 'default')
         assert executor.event_buffer[key][0] == State.SUCCESS
         mock_delete_pod.assert_not_called()
@@ -429,7 +424,7 @@ class TestKubernetesExecutor(unittest.TestCase):
         executor.kube_config.delete_worker_pods_on_failure = True
 
         executor.start()
-        key = ('dag_id', 'task_id', 'ex_time', 'try_number2')
+        key = ('dag_id', 'task_id', 'run_id', 'try_number2')
         executor._change_state(key, State.FAILED, 'pod_id', 'test-namespace')
         assert executor.event_buffer[key][0] == State.FAILED
         mock_delete_pod.assert_called_once_with('pod_id', 'test-namespace')
@@ -442,7 +437,7 @@ class TestKubernetesExecutor(unittest.TestCase):
         ti_key = annotations_to_key(
             {
                 'dag_id': 'dag',
-                'execution_date': datetime.utcnow().isoformat(),
+                'run_id': 'run_id',
                 'task_id': 'task',
                 'try_number': '1',
             }
@@ -525,7 +520,7 @@ class TestKubernetesExecutor(unittest.TestCase):
         executor.scheduler_job_id = "modified"
         annotations = {
             'dag_id': 'dag',
-            'execution_date': datetime.utcnow().isoformat(),
+            'run_id': 'run_id',
             'task_id': 'task',
             'try_number': '1',
         }
@@ -566,7 +561,7 @@ class TestKubernetesExecutor(unittest.TestCase):
                 labels={"airflow-worker": "bar"},
                 annotations={
                     'dag_id': 'dag',
-                    'execution_date': datetime.utcnow().isoformat(),
+                    'run_id': 'run_id',
                     'task_id': 'task',
                     'try_number': '1',
                 },
@@ -680,8 +675,9 @@ class TestKubernetesJobWatcher(unittest.TestCase):
         self.core_annotations = {
             "dag_id": "dag",
             "task_id": "task",
-            "execution_date": "dt",
+            "run_id": "run_id",
             "try_number": "1",
+            "execution_date": None,
         }
         self.pod = k8s.V1Pod(
             metadata=k8s.V1ObjectMeta(
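
For reference (not from the patch): the annotations the Kubernetes executor tests above
attach to worker pods, and the helper they exercise to turn them back into a task-instance
key. The run_id value is illustrative; execution_date survives only as an optional
compatibility annotation, and resolving a key from it instead of from run_id is assumed to
need a database lookup, so this sketch assumes the Airflow test environment.

    from airflow.kubernetes.kubernetes_helper_functions import annotations_to_key

    annotations = {
        "dag_id": "dag",
        "task_id": "task",
        "run_id": "manual__2021-09-01T00:00:00+00:00",  # replaces the execution_date annotation
        "try_number": "1",
    }
    ti_key = annotations_to_key(annotations)
    assert ti_key.run_id == annotations["run_id"]
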
diff --git a/tests/jobs/test_backfill_job.py b/tests/jobs/test_backfill_job.py
index 31826a8..32ae794 100644
--- a/tests/jobs/test_backfill_job.py
+++ b/tests/jobs/test_backfill_job.py
@@ -24,7 +24,6 @@ import threading
 from unittest.mock import patch
 
 import pytest
-import sqlalchemy
 
 from airflow import settings
 from airflow.cli import cli_parser
@@ -191,7 +190,7 @@ class TestBackfillJob:
             ("run_this_last", end_date),
         ]
         assert [
-            ((dag.dag_id, task_id, when, 1), (State.SUCCESS, None))
+            ((dag.dag_id, task_id, f'backfill__{when.isoformat()}', 1), (State.SUCCESS, None))
             for (task_id, when) in expected_execution_order
         ] == executor.sorted_tasks
 
@@ -268,7 +267,7 @@ class TestBackfillJob:
 
         job.run()
         assert [
-            ((dag_id, task_id, DEFAULT_DATE, 1), (State.SUCCESS, None))
+            ((dag_id, task_id, f'backfill__{DEFAULT_DATE.isoformat()}', 1), (State.SUCCESS, None))
             for task_id in expected_execution_order
         ] == executor.sorted_tasks
 
@@ -707,13 +706,13 @@ class TestBackfillJob:
             },
         ) as dag:
             task1 = DummyOperator(task_id="task1")
-        dag_maker.create_dagrun()
+        dr = dag_maker.create_dagrun()
 
         executor = MockExecutor(parallelism=16)
         executor.mock_task_results[
-            TaskInstanceKey(dag.dag_id, task1.task_id, DEFAULT_DATE, try_number=1)
+            TaskInstanceKey(dag.dag_id, task1.task_id, dr.run_id, try_number=1)
         ] = State.UP_FOR_RETRY
-        executor.mock_task_fail(dag.dag_id, task1.task_id, DEFAULT_DATE, try_number=2)
+        executor.mock_task_fail(dag.dag_id, task1.task_id, dr.run_id, try_number=2)
         job = BackfillJob(
             dag=dag,
             executor=executor,
@@ -739,7 +738,8 @@ class TestBackfillJob:
             op1.set_downstream(op3)
             op4.set_downstream(op5)
             op3.set_downstream(op4)
-        dag_maker.create_dagrun()
+        runid0 = f'backfill__{DEFAULT_DATE.isoformat()}'
+        dag_maker.create_dagrun(run_id=runid0)
 
         executor = MockExecutor(parallelism=16)
         job = BackfillJob(
@@ -750,25 +750,24 @@ class TestBackfillJob:
         )
         job.run()
 
-        date0 = DEFAULT_DATE
-        date1 = date0 + datetime.timedelta(days=1)
-        date2 = date1 + datetime.timedelta(days=1)
+        runid1 = f'backfill__{(DEFAULT_DATE + datetime.timedelta(days=1)).isoformat()}'
+        runid2 = f'backfill__{(DEFAULT_DATE + datetime.timedelta(days=2)).isoformat()}'
 
         # test executor history keeps a list
         history = executor.history
 
         assert [sorted(item[-1].key[1:3] for item in batch) for batch in history] == [
             [
-                ('leave1', date0),
-                ('leave1', date1),
-                ('leave1', date2),
-                ('leave2', date0),
-                ('leave2', date1),
-                ('leave2', date2),
+                ('leave1', runid0),
+                ('leave1', runid1),
+                ('leave1', runid2),
+                ('leave2', runid0),
+                ('leave2', runid1),
+                ('leave2', runid2),
             ],
-            [('upstream_level_1', date0), ('upstream_level_1', date1), ('upstream_level_1', date2)],
-            [('upstream_level_2', date0), ('upstream_level_2', date1), ('upstream_level_2', date2)],
-            [('upstream_level_3', date0), ('upstream_level_3', date1), ('upstream_level_3', date2)],
+            [('upstream_level_1', runid0), ('upstream_level_1', runid1), ('upstream_level_1', runid2)],
+            [('upstream_level_2', runid0), ('upstream_level_2', runid1), ('upstream_level_2', runid2)],
+            [('upstream_level_3', runid0), ('upstream_level_3', runid1), ('upstream_level_3', runid2)],
         ]
 
     def test_backfill_pooled_tasks(self):
@@ -1045,13 +1044,6 @@ class TestBackfillJob:
         job = BackfillJob(dag=sub_dag, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, executor=executor)
         job.run()
 
-        with pytest.raises(sqlalchemy.orm.exc.NoResultFound):
-            dr.refresh_from_db()
-        # the run_id should have changed, so a refresh won't work
-        drs = DagRun.find(dag_id=dag.dag_id, execution_date=DEFAULT_DATE)
-        dr = drs[0]
-
-        assert DagRun.generate_run_id(DagRunType.BACKFILL_JOB, DEFAULT_DATE) == dr.run_id
         for ti in dr.get_task_instances():
             if ti.task_id == 'leave1' or ti.task_id == 'leave2':
                 assert State.SUCCESS == ti.state
@@ -1097,11 +1089,7 @@ class TestBackfillJob:
         with pytest.raises(AirflowException, match='Some task instances failed'):
             job.run()
 
-        with pytest.raises(sqlalchemy.orm.exc.NoResultFound):
-            dr.refresh_from_db()
-        # the run_id should have changed, so a refresh won't work
-        drs = DagRun.find(dag_id=dag.dag_id, execution_date=DEFAULT_DATE)
-        dr = drs[0]
+        dr.refresh_from_db()
 
         assert dr.state == State.FAILED
 
@@ -1210,18 +1198,22 @@ class TestBackfillJob:
         dag = self.dagbag.get_dag('example_subdag_operator')
         subdag = dag.get_task('section-1').subdag
 
+        session = settings.Session()
         executor = MockExecutor()
         job = BackfillJob(
             dag=subdag, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, executor=executor, donot_pickle=True
         )
+        dr = DagRun(
+            dag_id=subdag.dag_id, execution_date=DEFAULT_DATE, run_id="test", run_type=DagRunType.BACKFILL_JOB
+        )
+        session.add(dr)
 
         removed_task_ti = TI(
-            task=DummyOperator(task_id='removed_task'), execution_date=DEFAULT_DATE, state=State.REMOVED
+            task=DummyOperator(task_id='removed_task'), run_id=dr.run_id, state=State.REMOVED
         )
         removed_task_ti.dag_id = subdag.dag_id
+        dr.task_instances.append(removed_task_ti)
 
-        session = settings.Session()
-        session.merge(removed_task_ti)
         session.commit()
 
         with timeout(seconds=30):
@@ -1378,8 +1370,9 @@ class TestBackfillJob:
         session = settings.Session()
         tis = (
             session.query(TI)
+            .join(TI.dag_run)
             .filter(TI.dag_id == 'test_start_date_scheduling' and TI.task_id == 'dummy')
-            .order_by(TI.execution_date)
+            .order_by(DagRun.execution_date)
             .all()
         )
 
@@ -1397,7 +1390,7 @@ class TestBackfillJob:
         states_to_reset = [State.QUEUED, State.SCHEDULED, State.NONE]
 
         tasks = []
-        with dag_maker(dag_id=prefix, start_date=DEFAULT_DATE, schedule_interval="@daily") as dag:
+        with dag_maker(dag_id=prefix) as dag:
             for i in range(len(states)):
                 task_id = f"{prefix}_task_{i}"
                 task = DummyOperator(task_id=task_id)
@@ -1452,7 +1445,7 @@ class TestBackfillJob:
         for state, ti in zip(states, dr2_tis):
             assert state == ti.state
 
-    def test_reset_orphaned_tasks_specified_dagrun(self, dag_maker):
+    def test_reset_orphaned_tasks_specified_dagrun(self, session, dag_maker):
         """Try to reset when we specify a dagrun and ensure nothing else is."""
         dag_id = 'test_reset_orphaned_tasks_specified_dagrun'
         task_id = dag_id + '_task'
@@ -1460,14 +1453,14 @@ class TestBackfillJob:
             dag_id=dag_id,
             start_date=DEFAULT_DATE,
             schedule_interval='@daily',
+            session=session,
         ) as dag:
             DummyOperator(task_id=task_id, dag=dag)
 
         job = BackfillJob(dag=dag)
-        session = settings.Session()
         # make two dagruns, only reset for one
         dr1 = dag_maker.create_dagrun(state=State.SUCCESS)
-        dr2 = dag.create_dagrun(run_id='test2', state=State.RUNNING)
+        dr2 = dag.create_dagrun(run_id='test2', state=State.RUNNING, session=session)
         ti1 = dr1.get_task_instances(session=session)[0]
         ti2 = dr2.get_task_instances(session=session)[0]
         ti1.state = State.SCHEDULED
@@ -1477,7 +1470,7 @@ class TestBackfillJob:
         session.merge(ti2)
         session.merge(dr1)
         session.merge(dr2)
-        session.commit()
+        session.flush()
 
         num_reset_tis = job.reset_state_for_orphaned_tasks(filter_by_dag_run=dr2, session=session)
         assert 1 == num_reset_tis
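
For reference (not from the patch): the run_id convention the backfill assertions above
are written against, reduced to plain string formatting. Backfill runs are named
"backfill__" plus the ISO-8601 execution date; the dates are illustrative.

    from datetime import timedelta

    from airflow.utils import timezone

    execution_date = timezone.datetime(2016, 1, 1)
    run_ids = [f"backfill__{(execution_date + timedelta(days=n)).isoformat()}" for n in range(3)]
    # ['backfill__2016-01-01T00:00:00+00:00',
    #  'backfill__2016-01-02T00:00:00+00:00',
    #  'backfill__2016-01-03T00:00:00+00:00']
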
diff --git a/tests/jobs/test_local_task_job.py b/tests/jobs/test_local_task_job.py
index 34b8526..1574fd6 100644
--- a/tests/jobs/test_local_task_job.py
+++ b/tests/jobs/test_local_task_job.py
@@ -244,7 +244,7 @@ class TestLocalTaskJob:
             dag = self.dagbag.get_dag(dag_id)
             task = dag.get_task(task_id)
 
-            dag.create_dagrun(
+            dr = dag.create_dagrun(
                 run_id="test_heartbeat_failed_fast_run",
                 state=State.RUNNING,
                 execution_date=DEFAULT_DATE,
@@ -252,9 +252,9 @@ class TestLocalTaskJob:
                 session=session,
             )
 
-            ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
-            ti.refresh_from_db()
-            ti.state = State.RUNNING
+            ti = dr.task_instances[0]
+            ti.refresh_from_task(task)
+            ti.state = State.QUEUED
             ti.hostname = get_hostname()
             ti.pid = 1
             session.commit()
@@ -291,11 +291,12 @@ class TestLocalTaskJob:
             time.sleep(10)
 
         with dag_maker('test_mark_success'):
-            task1 = PythonOperator(task_id="task1", python_callable=task_function)
-        dag_maker.create_dagrun()
+            task = PythonOperator(task_id="task1", python_callable=task_function)
+        dr = dag_maker.create_dagrun()
+
+        ti = dr.task_instances[0]
+        ti.refresh_from_task(task)
 
-        ti = TaskInstance(task=task1, execution_date=DEFAULT_DATE)
-        ti.refresh_from_db()
         job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True)
 
         def dummy_return_code(*args, **kwargs):
@@ -335,7 +336,7 @@ class TestLocalTaskJob:
         session.merge(ti)
         session.commit()
 
-        ti_run = TaskInstance(task=task, execution_date=DEFAULT_DATE)
+        ti_run = TaskInstance(task=task, run_id=dr.run_id)
         ti_run.refresh_from_db()
         job1 = LocalTaskJob(task_instance=ti_run, executor=SequentialExecutor())
         with patch.object(StandardTaskRunner, 'start', return_value=None) as mock_method:
@@ -671,14 +672,14 @@ class TestLocalTaskJob:
 
             dag_run = dag.create_dagrun(run_id='test_dagrun_fast_follow', state=State.RUNNING)
 
-            task_instance_a = TaskInstance(task_a, dag_run.execution_date, init_state['A'])
+            task_instance_a = TaskInstance(task_a, run_id=dag_run.run_id, state=init_state['A'])
 
-            task_instance_b = TaskInstance(task_b, dag_run.execution_date, init_state['B'])
+            task_instance_b = TaskInstance(task_b, run_id=dag_run.run_id, state=init_state['B'])
 
-            task_instance_c = TaskInstance(task_c, dag_run.execution_date, init_state['C'])
+            task_instance_c = TaskInstance(task_c, run_id=dag_run.run_id, state=init_state['C'])
 
             if 'D' in init_state:
-                task_instance_d = TaskInstance(task_d, dag_run.execution_date, init_state['D'])
+                task_instance_d = TaskInstance(task_d, run_id=dag_run.run_id, state=init_state['D'])
                 session.merge(task_instance_d)
 
             session.merge(task_instance_a)
@@ -731,8 +732,9 @@ class TestLocalTaskJob:
                 retries=1,
                 on_retry_callback=retry_callback,
             )
-        ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
-        ti.refresh_from_db()
+        dr = dag_maker.create_dagrun()
+        ti = dr.task_instances[0]
+        ti.refresh_from_task(task)
         job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor())
         settings.engine.dispose()
         with timeout(10):
@@ -814,20 +816,20 @@ def clean_db_helper():
 
 
 @pytest.mark.usefixtures("clean_db_helper")
-class TestLocalTaskJobPerformance:
-    @pytest.mark.parametrize("return_codes", [[0], 9 * [None] + [0]])  # type: ignore
-    @mock.patch("airflow.jobs.local_task_job.get_task_runner")
-    def test_number_of_queries_single_loop(self, mock_get_task_runner, return_codes, dag_maker):
-        unique_prefix = str(uuid.uuid4())
-        with dag_maker(dag_id=f'{unique_prefix}_test_number_of_queries'):
-            task = DummyOperator(task_id='test_state_succeeded1')
+@pytest.mark.parametrize("return_codes", [[0], 9 * [None] + [0]])
+@mock.patch("airflow.jobs.local_task_job.get_task_runner")
+def test_number_of_queries_single_loop(mock_get_task_runner, return_codes, dag_maker):
+    mock_get_task_runner.return_value.return_code.side_effects = return_codes
 
-        dag_maker.create_dagrun(run_id=unique_prefix, state=State.NONE)
+    unique_prefix = str(uuid.uuid4())
+    with dag_maker(dag_id=f'{unique_prefix}_test_number_of_queries'):
+        task = DummyOperator(task_id='test_state_succeeded1')
 
-        ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
+    dr = dag_maker.create_dagrun(run_id=unique_prefix, state=State.NONE)
 
-        mock_get_task_runner.return_value.return_code.side_effects = return_codes
+    ti = dr.task_instances[0]
+    ti.refresh_from_task(task)
 
-        job = LocalTaskJob(task_instance=ti, executor=MockExecutor())
-        with assert_queries_count(18):
-            job.run()
+    job = LocalTaskJob(task_instance=ti, executor=MockExecutor())
+    with assert_queries_count(25):
+        job.run()
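
The local task job changes above all follow the same pattern: rather than constructing a TaskInstance keyed by execution_date, each test asks dag_maker for a DagRun and picks the task instance off that run, re-attaching the in-memory task where needed. A minimal sketch of the pattern, assuming the pytest dag_maker fixture from this test suite (the dag_id below is illustrative):

    from airflow.operators.dummy import DummyOperator

    with dag_maker(dag_id='example_run_id_keyed_ti'):
        task = DummyOperator(task_id='dummy')

    dr = dag_maker.create_dagrun()   # the DagRun now owns its task instances
    ti = dr.task_instances[0]        # fetch the TI through the run ...
    ti.refresh_from_task(task)       # ... and attach the freshly built task object
    # previously: ti = TaskInstance(task=task, execution_date=DEFAULT_DATE); ti.refresh_from_db()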
diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py
index 0be4478..dfcb67e 100644
--- a/tests/jobs/test_scheduler_job.py
+++ b/tests/jobs/test_scheduler_job.py
@@ -184,15 +184,14 @@ class TestSchedulerJob:
     @mock.patch('airflow.jobs.scheduler_job.Stats.incr')
     def test_process_executor_events(self, mock_stats_incr, mock_task_callback, dag_maker):
         dag_id = "test_process_executor_events"
-        dag_id2 = "test_process_executor_events_2"
         task_id_1 = 'dummy_task'
 
         with dag_maker(dag_id=dag_id, fileloc='/test_path1/'):
             task1 = DummyOperator(task_id=task_id_1)
-        with dag_maker(dag_id=dag_id2, fileloc='/test_path1/'):
-            DummyOperator(task_id=task_id_1)
+        ti1 = dag_maker.create_dagrun().get_task_instance(task1.task_id)
 
         mock_stats_incr.reset_mock()
+
         executor = MockExecutor(do_update=False)
         task_callback = mock.MagicMock()
         mock_task_callback.return_value = task_callback
@@ -201,7 +200,6 @@ class TestSchedulerJob:
 
         session = settings.Session()
 
-        ti1 = TaskInstance(task1, DEFAULT_DATE)
         ti1.state = State.QUEUED
         session.merge(ti1)
         session.commit()
@@ -215,7 +213,7 @@ class TestSchedulerJob:
             full_filepath='/test_path1/',
             simple_task_instance=mock.ANY,
             msg='Executor reports task instance '
-            '<TaskInstance: test_process_executor_events.dummy_task 2016-01-01 00:00:00+00:00 [queued]> '
+            '<TaskInstance: test_process_executor_events.dummy_task test [queued]> '
             'finished (failed) although the task says its queued. (Info: None) '
             'Was the task killed externally?',
         )
@@ -235,22 +233,23 @@ class TestSchedulerJob:
         mock_stats_incr.assert_called_once_with('scheduler.tasks.killed_externally')
 
     def test_process_executor_events_uses_inmemory_try_number(self, dag_maker):
-        execution_date = DEFAULT_DATE
         dag_id = "dag_id"
         task_id = "task_id"
         try_number = 42
 
+        with dag_maker(dag_id=dag_id):
+            DummyOperator(task_id=task_id)
+
+        dr = dag_maker.create_dagrun()
+
         executor = MagicMock()
         self.scheduler_job = SchedulerJob(executor=executor)
         self.scheduler_job.processor_agent = MagicMock()
-        event_buffer = {TaskInstanceKey(dag_id, task_id, execution_date, try_number): (State.SUCCESS, None)}
+        event_buffer = {TaskInstanceKey(dag_id, task_id, dr.run_id, try_number): (State.SUCCESS, None)}
         executor.get_event_buffer.return_value = event_buffer
 
-        with dag_maker(dag_id=dag_id):
-            task = DummyOperator(task_id=task_id)
-
         with create_session() as session:
-            ti = TaskInstance(task, DEFAULT_DATE)
+            ti = dr.task_instances[0]
             ti.state = State.SUCCESS
             session.merge(ti)
 
@@ -259,55 +258,25 @@ class TestSchedulerJob:
         # task instance key
         assert event_buffer == {}
 
-    def test_execute_task_instances_is_paused_wont_execute(self, dag_maker):
+    def test_execute_task_instances_is_paused_wont_execute(self, session, dag_maker):
         dag_id = 'SchedulerJobTest.test_execute_task_instances_is_paused_wont_execute'
         task_id_1 = 'dummy_task'
 
-        with dag_maker(dag_id=dag_id) as dag:
-            task1 = DummyOperator(task_id=task_id_1)
+        with dag_maker(dag_id=dag_id, session=session) as dag:
+            DummyOperator(task_id=task_id_1)
         assert isinstance(dag, SerializedDAG)
 
         self.scheduler_job = SchedulerJob(subdir=os.devnull)
-        session = settings.Session()
 
         dr1 = dag_maker.create_dagrun(run_type=DagRunType.BACKFILL_JOB)
-        ti1 = TaskInstance(task1, DEFAULT_DATE)
+        (ti1,) = dr1.task_instances
         ti1.state = State.SCHEDULED
-        session.merge(ti1)
-        session.merge(dr1)
-        session.flush()
 
         self.scheduler_job._critical_section_execute_task_instances(session)
-        session.flush()
-        ti1.refresh_from_db()
+        ti1.refresh_from_db(session=session)
         assert State.SCHEDULED == ti1.state
         session.rollback()
 
-    def test_execute_task_instances_no_dagrun_task_will_execute(self, dag_maker):
-        """
-        Tests that tasks without dagrun still get executed.
-        """
-        dag_id = 'SchedulerJobTest.test_execute_task_instances_no_dagrun_task_will_execute'
-        task_id_1 = 'dummy_task'
-
-        with dag_maker(dag_id=dag_id):
-            task1 = DummyOperator(task_id=task_id_1)
-
-        self.scheduler_job = SchedulerJob(subdir=os.devnull)
-        session = settings.Session()
-
-        ti1 = TaskInstance(task1, DEFAULT_DATE)
-        ti1.state = State.SCHEDULED
-        ti1.execution_date = ti1.execution_date + datetime.timedelta(days=1)
-        session.merge(ti1)
-        session.flush()
-
-        self.scheduler_job._critical_section_execute_task_instances(session)
-        session.flush()
-        ti1.refresh_from_db()
-        assert State.QUEUED == ti1.state
-        session.rollback()
-
     def test_execute_task_instances_backfill_tasks_wont_execute(self, dag_maker):
         """
         Tests that backfill tasks won't get executed.
@@ -323,11 +292,10 @@ class TestSchedulerJob:
 
         dr1 = dag_maker.create_dagrun(run_type=DagRunType.BACKFILL_JOB)
 
-        ti1 = TaskInstance(task1, dr1.execution_date)
+        ti1 = TaskInstance(task1, run_id=dr1.run_id)
         ti1.refresh_from_db()
         ti1.state = State.SCHEDULED
         session.merge(ti1)
-        session.merge(dr1)
         session.flush()
 
         assert dr1.is_backfill
@@ -338,8 +306,8 @@ class TestSchedulerJob:
         assert State.SCHEDULED == ti1.state
         session.rollback()
 
-    def test_find_executable_task_instances_backfill_nodagrun(self, dag_maker):
-        dag_id = 'SchedulerJobTest.test_find_executable_task_instances_backfill_nodagrun'
+    def test_find_executable_task_instances_backfill(self, dag_maker):
+        dag_id = 'SchedulerJobTest.test_find_executable_task_instances_backfill'
         task_id_1 = 'dummy'
         with dag_maker(dag_id=dag_id, max_active_tasks=16) as dag:
             task1 = DummyOperator(task_id=task_id_1)
@@ -354,24 +322,20 @@ class TestSchedulerJob:
             state=State.RUNNING,
         )
 
-        ti_no_dagrun = TaskInstance(task1, DEFAULT_DATE - datetime.timedelta(days=1))
-        ti_backfill = TaskInstance(task1, dr2.execution_date)
-        ti_with_dagrun = TaskInstance(task1, dr1.execution_date)
+        ti_backfill = dr2.get_task_instance(task1.task_id)
+        ti_with_dagrun = dr1.get_task_instance(task1.task_id)
         # ti_with_paused
-        ti_no_dagrun.state = State.SCHEDULED
         ti_backfill.state = State.SCHEDULED
         ti_with_dagrun.state = State.SCHEDULED
 
         session.merge(dr2)
-        session.merge(ti_no_dagrun)
         session.merge(ti_backfill)
         session.merge(ti_with_dagrun)
         session.flush()
 
         res = self.scheduler_job._executable_task_instances_to_queued(max_tis=32, session=session)
-        assert 2 == len(res)
+        assert 1 == len(res)
         res_keys = map(lambda x: x.key, res)
-        assert ti_no_dagrun.key in res_keys
         assert ti_with_dagrun.key in res_keys
         session.rollback()
 
@@ -379,26 +343,21 @@ class TestSchedulerJob:
         dag_id = 'SchedulerJobTest.test_find_executable_task_instances_pool'
         task_id_1 = 'dummy'
         task_id_2 = 'dummydummy'
-        with dag_maker(dag_id=dag_id, max_active_tasks=16) as dag:
-            task1 = DummyOperator(task_id=task_id_1, pool='a')
-            task2 = DummyOperator(task_id=task_id_2, pool='b')
+        session = settings.Session()
+        with dag_maker(dag_id=dag_id, max_active_tasks=16, session=session) as dag:
+            DummyOperator(task_id=task_id_1, pool='a')
+            DummyOperator(task_id=task_id_2, pool='b')
 
         self.scheduler_job = SchedulerJob(subdir=os.devnull)
-        session = settings.Session()
 
         dr1 = dag_maker.create_dagrun()
-        dr2 = dag.create_dagrun(
+        dr2 = dag_maker.create_dagrun(
             run_type=DagRunType.SCHEDULED,
             execution_date=dag.following_schedule(dr1.execution_date),
             state=State.RUNNING,
         )
 
-        tis = [
-            TaskInstance(task1, dr1.execution_date),
-            TaskInstance(task2, dr1.execution_date),
-            TaskInstance(task1, dr2.execution_date),
-            TaskInstance(task2, dr2.execution_date),
-        ]
+        tis = dr1.task_instances + dr2.task_instances
         for ti in tis:
             ti.state = State.SCHEDULED
             session.merge(ti)
@@ -428,21 +387,20 @@ class TestSchedulerJob:
         dag_id_1 = 'SchedulerJobTest.test_find_executable_task_instances_order_execution_date-a'
         dag_id_2 = 'SchedulerJobTest.test_find_executable_task_instances_order_execution_date-b'
         task_id = 'task-a'
-        with dag_maker(dag_id=dag_id_1, max_active_tasks=16):
-            dag1_task = DummyOperator(task_id=task_id)
+        session = settings.Session()
+        with dag_maker(dag_id=dag_id_1, max_active_tasks=16, session=session):
+            DummyOperator(task_id=task_id)
         dr1 = dag_maker.create_dagrun(execution_date=DEFAULT_DATE + timedelta(hours=1))
 
-        with dag_maker(dag_id=dag_id_2, max_active_tasks=16):
-            dag2_task = DummyOperator(task_id=task_id)
+        with dag_maker(dag_id=dag_id_2, max_active_tasks=16, session=session):
+            DummyOperator(task_id=task_id)
         dr2 = dag_maker.create_dagrun()
 
+        dr1 = session.merge(dr1, load=False)
+
         self.scheduler_job = SchedulerJob(subdir=os.devnull)
-        session = settings.Session()
 
-        tis = [
-            TaskInstance(dag1_task, dr1.execution_date),
-            TaskInstance(dag2_task, dr2.execution_date),
-        ]
+        tis = dr1.task_instances + dr2.task_instances
         for ti in tis:
             ti.state = State.SCHEDULED
             session.merge(ti)
@@ -457,21 +415,20 @@ class TestSchedulerJob:
         dag_id_1 = 'SchedulerJobTest.test_find_executable_task_instances_order_priority-a'
         dag_id_2 = 'SchedulerJobTest.test_find_executable_task_instances_order_priority-b'
         task_id = 'task-a'
-        with dag_maker(dag_id=dag_id_1, max_active_tasks=16):
-            dag1_task = DummyOperator(task_id=task_id, priority_weight=1)
+        session = settings.Session()
+        with dag_maker(dag_id=dag_id_1, max_active_tasks=16, session=session):
+            DummyOperator(task_id=task_id, priority_weight=1)
         dr1 = dag_maker.create_dagrun()
 
-        with dag_maker(dag_id=dag_id_2, max_active_tasks=16):
-            dag2_task = DummyOperator(task_id=task_id, priority_weight=4)
+        with dag_maker(dag_id=dag_id_2, max_active_tasks=16, session=session):
+            DummyOperator(task_id=task_id, priority_weight=4)
         dr2 = dag_maker.create_dagrun()
 
+        dr1 = session.merge(dr1, load=False)
+
         self.scheduler_job = SchedulerJob(subdir=os.devnull)
-        session = settings.Session()
 
-        tis = [
-            TaskInstance(dag1_task, dr1.execution_date),
-            TaskInstance(dag2_task, dr2.execution_date),
-        ]
+        tis = dr1.task_instances + dr2.task_instances
         for ti in tis:
             ti.state = State.SCHEDULED
             session.merge(ti)
@@ -486,21 +443,19 @@ class TestSchedulerJob:
         dag_id_1 = 'SchedulerJobTest.test_find_executable_task_instances_order_execution_date_and_priority-a'
         dag_id_2 = 'SchedulerJobTest.test_find_executable_task_instances_order_execution_date_and_priority-b'
         task_id = 'task-a'
-        with dag_maker(dag_id=dag_id_1, max_active_tasks=16):
-            dag1_task = DummyOperator(task_id=task_id, priority_weight=1)
+        session = settings.Session()
+        with dag_maker(dag_id=dag_id_1, max_active_tasks=16, session=session):
+            DummyOperator(task_id=task_id, priority_weight=1)
         dr1 = dag_maker.create_dagrun()
 
-        with dag_maker(dag_id=dag_id_2, max_active_tasks=16):
-            dag2_task = DummyOperator(task_id=task_id, priority_weight=4)
+        with dag_maker(dag_id=dag_id_2, max_active_tasks=16, session=session):
+            DummyOperator(task_id=task_id, priority_weight=4)
         dr2 = dag_maker.create_dagrun(execution_date=DEFAULT_DATE + timedelta(hours=1))
 
+        dr1 = session.merge(dr1, load=False)
         self.scheduler_job = SchedulerJob(subdir=os.devnull)
-        session = settings.Session()
 
-        tis = [
-            TaskInstance(dag1_task, dr1.execution_date),
-            TaskInstance(dag2_task, dr2.execution_date),
-        ]
+        tis = dr1.task_instances + dr2.task_instances
         for ti in tis:
             ti.state = State.SCHEDULED
             session.merge(ti)
@@ -530,13 +485,11 @@ class TestSchedulerJob:
             state=State.RUNNING,
         )
 
-        ti1 = TaskInstance(task=op1, execution_date=dr1.execution_date)
-        ti2 = TaskInstance(task=op2, execution_date=dr2.execution_date)
+        ti1 = dr1.get_task_instance(op1.task_id, session)
+        ti2 = dr2.get_task_instance(op2.task_id, session)
         ti1.state = State.SCHEDULED
         ti2.state = State.SCHEDULED
 
-        session.merge(ti1)
-        session.merge(ti2)
         session.flush()
 
         # Two tasks w/o pool up for execution and our default pool size is 1
@@ -544,7 +497,6 @@ class TestSchedulerJob:
         assert 1 == len(res)
 
         ti2.state = State.RUNNING
-        session.merge(ti2)
         session.flush()
 
         # One task w/o pool up for execution and one task running
@@ -556,16 +508,15 @@ class TestSchedulerJob:
 
     def test_nonexistent_pool(self, dag_maker):
         dag_id = 'SchedulerJobTest.test_nonexistent_pool'
-        task_id = 'dummy_wrong_pool'
         with dag_maker(dag_id=dag_id, max_active_tasks=16):
-            task = DummyOperator(task_id=task_id, pool="this_pool_doesnt_exist")
+            DummyOperator(task_id="dummy_wrong_pool", pool="this_pool_doesnt_exist")
 
         self.scheduler_job = SchedulerJob(subdir=os.devnull)
         session = settings.Session()
 
         dr = dag_maker.create_dagrun()
 
-        ti = TaskInstance(task, dr.execution_date)
+        ti = dr.task_instances[0]
         ti.state = State.SCHEDULED
         session.merge(ti)
         session.commit()
@@ -577,15 +528,14 @@ class TestSchedulerJob:
 
     def test_infinite_pool(self, dag_maker):
         dag_id = 'SchedulerJobTest.test_infinite_pool'
-        task_id = 'dummy'
         with dag_maker(dag_id=dag_id, concurrency=16):
-            task = DummyOperator(task_id=task_id, pool="infinite_pool")
+            DummyOperator(task_id="dummy", pool="infinite_pool")
 
         self.scheduler_job = SchedulerJob(subdir=os.devnull)
         session = settings.Session()
 
         dr = dag_maker.create_dagrun()
-        ti = TaskInstance(task, dr.execution_date)
+        ti = dr.task_instances[0]
         ti.state = State.SCHEDULED
         session.merge(ti)
         infinite_pool = Pool(pool='infinite_pool', slots=-1, description='infinite pool')
@@ -642,28 +592,27 @@ class TestSchedulerJob:
 
     def test_find_executable_task_instances_concurrency(self, dag_maker):
         dag_id = 'SchedulerJobTest.test_find_executable_task_instances_concurrency'
-        task_id_1 = 'dummy'
-        with dag_maker(dag_id=dag_id, max_active_tasks=2) as dag:
-            task1 = DummyOperator(task_id=task_id_1)
+        session = settings.Session()
+        with dag_maker(dag_id=dag_id, max_active_tasks=2, session=session) as dag:
+            DummyOperator(task_id='dummy')
 
         self.scheduler_job = SchedulerJob(subdir=os.devnull)
-        session = settings.Session()
 
         dr1 = dag_maker.create_dagrun()
-        dr2 = dag.create_dagrun(
+        dr2 = dag_maker.create_dagrun(
             run_type=DagRunType.SCHEDULED,
             execution_date=dag.following_schedule(dr1.execution_date),
             state=State.RUNNING,
         )
-        dr3 = dag.create_dagrun(
+        dr3 = dag_maker.create_dagrun(
             run_type=DagRunType.SCHEDULED,
             execution_date=dag.following_schedule(dr2.execution_date),
             state=State.RUNNING,
         )
 
-        ti1 = TaskInstance(task1, dr1.execution_date)
-        ti2 = TaskInstance(task1, dr2.execution_date)
-        ti3 = TaskInstance(task1, dr3.execution_date)
+        ti1 = dr1.task_instances[0]
+        ti2 = dr2.task_instances[0]
+        ti3 = dr3.task_instances[0]
         ti1.state = State.RUNNING
         ti2.state = State.SCHEDULED
         ti3.state = State.SCHEDULED
@@ -700,9 +649,9 @@ class TestSchedulerJob:
 
         dag_run = dag_maker.create_dagrun()
 
-        ti1 = TaskInstance(task1, dag_run.execution_date)
-        ti2 = TaskInstance(task2, dag_run.execution_date)
-        ti3 = TaskInstance(task3, dag_run.execution_date)
+        ti1 = dag_run.get_task_instance(task1.task_id)
+        ti2 = dag_run.get_task_instance(task2.task_id)
+        ti3 = dag_run.get_task_instance(task3.task_id)
         ti1.state = State.RUNNING
         ti2.state = State.QUEUED
         ti3.state = State.SCHEDULED
@@ -744,8 +693,8 @@ class TestSchedulerJob:
             state=State.RUNNING,
         )
 
-        ti1_1 = TaskInstance(task1, dr1.execution_date)
-        ti2 = TaskInstance(task2, dr1.execution_date)
+        ti1_1 = dr1.get_task_instance(task1.task_id)
+        ti2 = dr1.get_task_instance(task2.task_id)
 
         ti1_1.state = State.SCHEDULED
         ti2.state = State.SCHEDULED
@@ -759,7 +708,7 @@ class TestSchedulerJob:
 
         ti1_1.state = State.RUNNING
         ti2.state = State.RUNNING
-        ti1_2 = TaskInstance(task1, dr2.execution_date)
+        ti1_2 = dr2.get_task_instance(task1.task_id)
         ti1_2.state = State.SCHEDULED
         session.merge(ti1_1)
         session.merge(ti2)
@@ -771,7 +720,7 @@ class TestSchedulerJob:
         assert 1 == len(res)
 
         ti1_2.state = State.RUNNING
-        ti1_3 = TaskInstance(task1, dr3.execution_date)
+        ti1_3 = dr3.get_task_instance(task1.task_id)
         ti1_3.state = State.SCHEDULED
         session.merge(ti1_2)
         session.merge(ti1_3)
@@ -830,9 +779,9 @@ class TestSchedulerJob:
             state=State.RUNNING,
         )
 
-        ti1 = TaskInstance(task1, dr1.execution_date)
-        ti2 = TaskInstance(task1, dr2.execution_date)
-        ti3 = TaskInstance(task1, dr3.execution_date)
+        ti1 = dr1.get_task_instance(task1.task_id)
+        ti2 = dr2.get_task_instance(task1.task_id)
+        ti3 = dr3.get_task_instance(task1.task_id)
         ti1.state = State.RUNNING
         ti2.state = State.RUNNING
         ti3.state = State.RUNNING
@@ -850,18 +799,14 @@ class TestSchedulerJob:
     def test_enqueue_task_instances_with_queued_state(self, dag_maker):
         dag_id = 'SchedulerJobTest.test_enqueue_task_instances_with_queued_state'
         task_id_1 = 'dummy'
-        with dag_maker(dag_id=dag_id, start_date=DEFAULT_DATE):
+        session = settings.Session()
+        with dag_maker(dag_id=dag_id, start_date=DEFAULT_DATE, session=session):
             task1 = DummyOperator(task_id=task_id_1)
 
         self.scheduler_job = SchedulerJob(subdir=os.devnull)
-        session = settings.Session()
-
-        dag_model = dag_maker.dag_model
 
         dr1 = dag_maker.create_dagrun()
-        ti1 = TaskInstance(task1, dr1.execution_date)
-        ti1.dag_model = dag_model
-        session.merge(ti1)
+        ti1 = dr1.get_task_instance(task1.task_id, session)
 
         with patch.object(BaseExecutor, 'queue_command') as mock_queue_command:
             self.scheduler_job._enqueue_task_instances_with_queued_state([ti1])
@@ -873,49 +818,41 @@ class TestSchedulerJob:
         dag_id = 'SchedulerJobTest.test_execute_task_instances'
         task_id_1 = 'dummy_task'
         task_id_2 = 'dummy_task_nonexistent_queue'
+        session = settings.Session()
         # important that len(tasks) is less than max_active_tasks
         # because before scheduler._execute_task_instances would only
         # check the num tasks once so if max_active_tasks was 3,
         # we could execute arbitrarily many tasks in the second run
-        with dag_maker(dag_id=dag_id, max_active_tasks=3) as dag:
+        with dag_maker(dag_id=dag_id, max_active_tasks=3, session=session) as dag:
             task1 = DummyOperator(task_id=task_id_1)
             task2 = DummyOperator(task_id=task_id_2)
 
         self.scheduler_job = SchedulerJob(subdir=os.devnull)
-        session = settings.Session()
 
         # create first dag run with 1 running and 1 queued
 
         dr1 = dag_maker.create_dagrun()
 
-        ti1 = TaskInstance(task1, dr1.execution_date)
-        ti2 = TaskInstance(task2, dr1.execution_date)
-        ti1.refresh_from_db()
-        ti2.refresh_from_db()
+        ti1 = dr1.get_task_instance(task1.task_id, session)
+        ti2 = dr1.get_task_instance(task2.task_id, session)
         ti1.state = State.RUNNING
         ti2.state = State.RUNNING
-        session.merge(ti1)
-        session.merge(ti2)
         session.flush()
 
         assert State.RUNNING == dr1.state
         assert 2 == DAG.get_num_task_instances(dag_id, dag.task_ids, states=[State.RUNNING], session=session)
 
         # create second dag run
-        dr2 = dag.create_dagrun(
+        dr2 = dag_maker.create_dagrun(
             run_type=DagRunType.SCHEDULED,
             execution_date=dag.following_schedule(dr1.execution_date),
             state=State.RUNNING,
         )
-        ti3 = TaskInstance(task1, dr2.execution_date)
-        ti4 = TaskInstance(task2, dr2.execution_date)
-        ti3.refresh_from_db()
-        ti4.refresh_from_db()
+        ti3 = dr2.get_task_instance(task1.task_id, session)
+        ti4 = dr2.get_task_instance(task2.task_id, session)
         # manually set to scheduled so we can pick them up
         ti3.state = State.SCHEDULED
         ti4.state = State.SCHEDULED
-        session.merge(ti3)
-        session.merge(ti4)
         session.flush()
 
         assert State.RUNNING == dr2.state
@@ -939,36 +876,30 @@ class TestSchedulerJob:
         dag_id = 'SchedulerJobTest.test_execute_task_instances_limit'
         task_id_1 = 'dummy_task'
         task_id_2 = 'dummy_task_2'
+        session = settings.Session()
         # important that len(tasks) is less than max_active_tasks
         # because before scheduler._execute_task_instances would only
         # check the num tasks once so if max_active_tasks was 3,
         # we could execute arbitrarily many tasks in the second run
-        with dag_maker(dag_id=dag_id, max_active_tasks=16) as dag:
+        with dag_maker(dag_id=dag_id, max_active_tasks=16, session=session) as dag:
             task1 = DummyOperator(task_id=task_id_1)
             task2 = DummyOperator(task_id=task_id_2)
 
         self.scheduler_job = SchedulerJob(subdir=os.devnull)
-        session = settings.Session()
 
         date = dag.start_date
         tis = []
         for _ in range(0, 4):
             date = dag.following_schedule(date)
-            dr = dag.create_dagrun(
+            dr = dag_maker.create_dagrun(
                 run_type=DagRunType.SCHEDULED,
                 execution_date=date,
                 state=State.RUNNING,
             )
-            ti1 = TaskInstance(task1, dr.execution_date)
-            ti2 = TaskInstance(task2, dr.execution_date)
-            tis.append(ti1)
-            tis.append(ti2)
-            ti1.refresh_from_db()
-            ti2.refresh_from_db()
+            ti1 = dr.get_task_instance(task1.task_id, session)
+            ti2 = dr.get_task_instance(task2.task_id, session)
             ti1.state = State.SCHEDULED
             ti2.state = State.SCHEDULED
-            session.merge(ti1)
-            session.merge(ti2)
             session.flush()
         self.scheduler_job.max_tis_per_query = 2
         res = self.scheduler_job._critical_section_execute_task_instances(session)
@@ -995,33 +926,27 @@ class TestSchedulerJob:
         dag_id = 'SchedulerJobTest.test_execute_task_instances_unlimited'
         task_id_1 = 'dummy_task'
         task_id_2 = 'dummy_task_2'
+        session = settings.Session()
 
-        with dag_maker(dag_id=dag_id, max_active_tasks=1024) as dag:
+        with dag_maker(dag_id=dag_id, max_active_tasks=1024, session=session) as dag:
             task1 = DummyOperator(task_id=task_id_1)
             task2 = DummyOperator(task_id=task_id_2)
 
         self.scheduler_job = SchedulerJob(subdir=os.devnull)
-        session = settings.Session()
 
         date = dag.start_date
-        tis = []
         for _ in range(0, 20):
             date = dag.following_schedule(date)
-            dr = dag.create_dagrun(
+            dr = dag_maker.create_dagrun(
                 run_type=DagRunType.SCHEDULED,
                 execution_date=date,
                 state=State.RUNNING,
             )
-            ti1 = TaskInstance(task1, dr.execution_date)
-            ti2 = TaskInstance(task2, dr.execution_date)
-            tis.append(ti1)
-            tis.append(ti2)
-            ti1.refresh_from_db()
-            ti2.refresh_from_db()
+            date = dag.following_schedule(date)
+            ti1 = dr.get_task_instance(task1.task_id, session)
+            ti2 = dr.get_task_instance(task2.task_id, session)
             ti1.state = State.SCHEDULED
             ti2.state = State.SCHEDULED
-            session.merge(ti1)
-            session.merge(ti2)
             session.flush()
         self.scheduler_job.max_tis_per_query = 0
         self.scheduler_job.executor = MagicMock(slots_available=36)
@@ -1031,86 +956,6 @@ class TestSchedulerJob:
         assert res == 36
         session.rollback()
 
-    def test_change_state_for_tis_without_dagrun(self, dag_maker):
-        with dag_maker(dag_id='test_change_state_for_tis_without_dagrun'):
-            DummyOperator(task_id='dummy')
-            DummyOperator(task_id='dummy_b')
-        dr1 = dag_maker.create_dagrun()
-
-        with dag_maker(dag_id='test_change_state_for_tis_without_dagrun_dont_change'):
-            DummyOperator(task_id='dummy')
-        dr2 = dag_maker.create_dagrun()
-
-        # Using dag_maker for below dag will create a dagrun and we don't want a dagrun
-        with dag_maker(dag_id='test_change_state_for_tis_without_dagrun_no_dagrun') as dag3:
-            DummyOperator(task_id='dummy')
-
-        session = settings.Session()
-
-        ti1a = dr1.get_task_instance(task_id='dummy', session=session)
-        ti1a.state = State.SCHEDULED
-        ti1b = dr1.get_task_instance(task_id='dummy_b', session=session)
-        ti1b.state = State.SUCCESS
-        session.commit()
-
-        ti2 = dr2.get_task_instance(task_id='dummy', session=session)
-        ti2.state = State.SCHEDULED
-        session.commit()
-
-        ti3 = TaskInstance(dag3.get_task('dummy'), DEFAULT_DATE)
-        ti3.state = State.SCHEDULED
-        session.merge(ti3)
-        session.commit()
-
-        self.scheduler_job = SchedulerJob(num_runs=0)
-        self.scheduler_job.dagbag.collect_dags_from_db()
-
-        self.scheduler_job._change_state_for_tis_without_dagrun(
-            old_states=[State.SCHEDULED, State.QUEUED], new_state=State.NONE, session=session
-        )
-
-        ti1a = dr1.get_task_instance(task_id='dummy', session=session)
-        ti1a.refresh_from_db(session=session)
-        assert ti1a.state == State.SCHEDULED
-
-        ti1b = dr1.get_task_instance(task_id='dummy_b', session=session)
-        ti1b.refresh_from_db(session=session)
-        assert ti1b.state == State.SUCCESS
-
-        ti2 = dr2.get_task_instance(task_id='dummy', session=session)
-        ti2.refresh_from_db(session=session)
-        assert ti2.state == State.SCHEDULED
-
-        ti3.refresh_from_db(session=session)
-        assert ti3.state == State.NONE
-        assert ti3.start_date is not None
-        assert ti3.end_date is None
-        assert ti3.duration is None
-
-        dr1.refresh_from_db(session=session)
-        dr1.state = State.FAILED
-
-        # Push the changes to DB
-        session.merge(dr1)
-        session.commit()
-
-        self.scheduler_job._change_state_for_tis_without_dagrun(
-            old_states=[State.SCHEDULED, State.QUEUED], new_state=State.NONE, session=session
-        )
-
-        # Clear the session objects
-        session.expunge_all()
-        ti1a.refresh_from_db(session=session)
-        assert ti1a.state == State.NONE
-
-        # don't touch ti1b
-        ti1b.refresh_from_db(session=session)
-        assert ti1b.state == State.SUCCESS
-
-        # don't touch ti2
-        ti2.refresh_from_db(session=session)
-        assert ti2.state == State.SCHEDULED
-
     def test_adopt_or_reset_orphaned_tasks(self, dag_maker):
         session = settings.Session()
         with dag_maker('test_execute_helper_reset_orphaned_tasks') as dag:
@@ -1143,60 +988,6 @@ class TestSchedulerJob:
         ti2 = dr2.get_task_instance(task_id=op1.task_id, session=session)
         assert ti2.state == State.QUEUED, "Tasks run by Backfill Jobs should not be reset"
 
-    @pytest.mark.parametrize(
-        "initial_task_state, expected_task_state",
-        [
-            [State.UP_FOR_RETRY, State.FAILED],
-            [State.QUEUED, State.NONE],
-            [State.SCHEDULED, State.NONE],
-            [State.UP_FOR_RESCHEDULE, State.NONE],
-        ],
-    )
-    def test_scheduler_loop_should_change_state_for_tis_without_dagrun(
-        self, initial_task_state, expected_task_state, dag_maker
-    ):
-        session = settings.Session()
-        dag_id = 'test_execute_helper_should_change_state_for_tis_without_dagrun'
-        with dag_maker(
-            dag_id,
-            start_date=DEFAULT_DATE + timedelta(days=1),
-        ):
-            op1 = DummyOperator(task_id='op1')
-
-        # Create DAG run with FAILED state
-        dr = dag_maker.create_dagrun(
-            state=State.FAILED,
-            execution_date=DEFAULT_DATE + timedelta(days=1),
-            start_date=DEFAULT_DATE + timedelta(days=1),
-        )
-        ti = dr.get_task_instance(task_id=op1.task_id, session=session)
-        ti.state = initial_task_state
-        session.commit()
-
-        # This poll interval is large, bug the scheduler doesn't sleep that
-        # long, instead we hit the clean_tis_without_dagrun interval instead
-        self.scheduler_job = SchedulerJob(num_runs=2, processor_poll_interval=30)
-        self.scheduler_job.dagbag = dag_maker.dagbag
-        executor = MockExecutor(do_update=False)
-        executor.queued_tasks
-        self.scheduler_job.executor = executor
-        processor = mock.MagicMock()
-        processor.done = False
-        self.scheduler_job.processor_agent = processor
-
-        with mock.patch.object(settings, "USE_JOB_SCHEDULE", False), conf_vars(
-            {('scheduler', 'clean_tis_without_dagrun_interval'): '0.001'}
-        ):
-            self.scheduler_job._run_scheduler_loop()
-
-        ti = dr.get_task_instance(task_id=op1.task_id, session=session)
-        assert ti.state == expected_task_state
-        assert ti.start_date is not None
-        if expected_task_state in State.finished:
-            assert ti.end_date is not None
-            assert ti.start_date == ti.end_date
-            assert ti.duration is not None
-
     @mock.patch('airflow.jobs.scheduler_job.DagFileProcessorAgent')
     def test_executor_end_called(self, mock_processor_agent):
         """
@@ -1289,7 +1080,7 @@ class TestSchedulerJob:
             full_filepath=dr.dag.fileloc,
             dag_id=dr.dag_id,
             is_failure_callback=True,
-            execution_date=dr.execution_date,
+            run_id=dr.run_id,
             msg="timed_out",
         )
 
@@ -1330,7 +1121,7 @@ class TestSchedulerJob:
             full_filepath=dr.dag.fileloc,
             dag_id=dr.dag_id,
             is_failure_callback=True,
-            execution_date=dr.execution_date,
+            run_id=dr.run_id,
             msg="timed_out",
         )
 
@@ -1374,7 +1165,7 @@ class TestSchedulerJob:
             full_filepath=dag.fileloc,
             dag_id=dr.dag_id,
             is_failure_callback=bool(state == State.FAILED),
-            execution_date=dr.execution_date,
+            run_id=dr.run_id,
             msg=expected_callback_msg,
         )
 
@@ -1531,7 +1322,7 @@ class TestSchedulerJob:
         for tid, state in expected_task_states.items():
             if state != State.FAILED:
                 continue
-            self.null_exec.mock_task_fail(dag_id, tid, ex_date)
+            self.null_exec.mock_task_fail(dag_id, tid, dr.run_id)
 
         try:
             dag = DagBag().get_dag(dag.dag_id)
@@ -1541,13 +1332,6 @@ class TestSchedulerJob:
         except AirflowException:
             pass
 
-        # test tasks
-        for task_id, expected_state in expected_task_states.items():
-            task = dag.get_task(task_id)
-            ti = TaskInstance(task, ex_date)
-            ti.refresh_from_db()
-            assert ti.state == expected_state
-
         # load dagrun
         dr = DagRun.find(dag_id=dag_id, execution_date=ex_date)
         dr = dr[0]
@@ -1555,6 +1339,11 @@ class TestSchedulerJob:
 
         assert dr.state == dagrun_state
 
+        # test tasks
+        for task_id, expected_state in expected_task_states.items():
+            ti = dr.get_task_instance(task_id)
+            assert ti.state == expected_state
+
     def test_dagrun_fail(self):
         """
         DagRuns with one failed and one incomplete root task -> FAILED
@@ -1607,7 +1396,7 @@ class TestSchedulerJob:
             execution_date=DEFAULT_DATE,
             state=State.RUNNING,
         )
-        self.null_exec.mock_task_fail(dag_id, 'test_dagrun_fail', DEFAULT_DATE)
+        self.null_exec.mock_task_fail(dag_id, 'test_dagrun_fail', dr.run_id)
 
         with pytest.raises(AirflowException):
             dag.run(start_date=dr.execution_date, end_date=dr.execution_date, executor=self.null_exec)
@@ -1710,7 +1499,10 @@ class TestSchedulerJob:
             # one task ran
             assert len(session.query(TaskInstance).filter(TaskInstance.dag_id == dag_id).all()) == 1
             assert [
-                (TaskInstanceKey(dag.dag_id, 'dummy', DEFAULT_DATE, 1), (State.SUCCESS, None)),
+                (
+                    TaskInstanceKey(dag.dag_id, 'dummy', f'backfill__{DEFAULT_DATE.isoformat()}', 1),
+                    (State.SUCCESS, None),
+                ),
             ] == bf_exec.sorted_tasks
             session.commit()
 
@@ -1835,13 +1627,14 @@ class TestSchedulerJob:
 
         assert len(task_instances_list) == 1
 
-    def test_scheduler_verify_pool_full_2_slots_per_task(self, dag_maker):
+    @pytest.mark.need_serialized_dag
+    def test_scheduler_verify_pool_full_2_slots_per_task(self, dag_maker, session):
         """
         Test task instances not queued when pool is full.
 
         Variation with non-default pool_slots
         """
-        with dag_maker(dag_id='test_scheduler_verify_pool_full_2_slots_per_task') as dag:
+        with dag_maker(dag_id='test_scheduler_verify_pool_full_2_slots_per_task', session=session) as dag:
             BashOperator(
                 task_id='dummy',
                 pool='test_scheduler_verify_pool_full_2_slots_per_task',
@@ -1849,7 +1642,6 @@ class TestSchedulerJob:
                 bash_command='echo hi',
             )
 
-        session = settings.Session()
         pool = Pool(pool='test_scheduler_verify_pool_full_2_slots_per_task', slots=6)
         session.add(pool)
         session.flush()
@@ -1865,6 +1657,7 @@ class TestSchedulerJob:
                 run_type=DagRunType.SCHEDULED,
                 execution_date=date,
                 state=State.RUNNING,
+                session=session,
             )
             self.scheduler_job._schedule_dag_run(dr, session)
 
@@ -1979,7 +1772,10 @@ class TestSchedulerJob:
         self.scheduler_job.processor_agent = mock.MagicMock()
 
         dr = dag_maker.create_dagrun()
-        self.scheduler_job._schedule_dag_run(dr, session)
+        for ti in dr.task_instances:
+            ti.state = State.SCHEDULED
+            session.merge(ti)
+        session.flush()
 
         task_instances_list = self.scheduler_job._executable_task_instances_to_queued(
             max_tis=32, session=session
@@ -2358,25 +2154,6 @@ class TestSchedulerJob:
         assert 0 == self.scheduler_job.adopt_or_reset_orphaned_tasks(session=session)
         session.rollback()
 
-    def test_reset_orphaned_tasks_nonexistent_dagrun(self, dag_maker):
-        """Make sure a task in an orphaned state is not reset if it has no dagrun."""
-        dag_id = 'test_reset_orphaned_tasks_nonexistent_dagrun'
-        with dag_maker(dag_id=dag_id, schedule_interval='@daily'):
-            task_id = dag_id + '_task'
-            task = DummyOperator(task_id=task_id)
-
-        self.scheduler_job = SchedulerJob(subdir=os.devnull)
-        session = settings.Session()
-
-        ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
-        ti.refresh_from_db()
-        ti.state = State.SCHEDULED
-        session.merge(ti)
-        session.flush()
-
-        assert 0 == self.scheduler_job.adopt_or_reset_orphaned_tasks(session=session)
-        session.rollback()
-
     def test_reset_orphaned_tasks_no_orphans(self, dag_maker):
         dag_id = 'test_reset_orphaned_tasks_no_orphans'
         with dag_maker(dag_id=dag_id, schedule_interval='@daily'):
@@ -2770,58 +2547,52 @@ class TestSchedulerJob:
             session=session,
         )
 
-        dag.sync_to_db(session=session)
         self.scheduler_job = SchedulerJob(subdir=os.devnull)
         self.scheduler_job.executor = MockExecutor()
         self.scheduler_job.processor_agent = mock.MagicMock(spec=DagFileProcessorAgent)
 
         self.scheduler_job._do_scheduling(session)
-        session.add(run1)
+        run1 = session.merge(run1)
         session.refresh(run1)
         assert run1.state == State.FAILED
         assert run1_ti.state == State.SKIPPED
 
         # Run scheduling again to assert run2 has started
         self.scheduler_job._do_scheduling(session)
-        session.add(run2)
+        run2 = session.merge(run2)
         session.refresh(run2)
         assert run2.state == State.RUNNING
         run2_ti = run2.get_task_instance(task1.task_id, session)
         assert run2_ti.state == State.QUEUED
 
-    def test_do_schedule_max_active_runs_task_removed(self, dag_maker):
+    def test_do_schedule_max_active_runs_task_removed(self, session, dag_maker):
         """Test that tasks in removed state don't count as actively running."""
-
         with dag_maker(
             dag_id='test_do_schedule_max_active_runs_task_removed',
             start_date=DEFAULT_DATE,
             schedule_interval='@once',
             max_active_runs=1,
-        ) as dag:
+            session=session,
+        ):
             # Can't use DummyOperator as that goes straight to success
-            task1 = BashOperator(task_id='dummy1', bash_command='true')
-
-        session = settings.Session()
-        session.add(TaskInstance(task1, DEFAULT_DATE, State.REMOVED))
-        session.flush()
+            BashOperator(task_id='dummy1', bash_command='true')
 
-        run1 = dag.create_dagrun(
+        run1 = dag_maker.create_dagrun(
             run_type=DagRunType.SCHEDULED,
             execution_date=DEFAULT_DATE + timedelta(hours=1),
             state=State.RUNNING,
-            session=session,
         )
 
-        dag.sync_to_db(session=session)  # Update the date fields
-
         self.scheduler_job = SchedulerJob(subdir=os.devnull)
         self.scheduler_job.executor = MockExecutor(do_update=False)
         self.scheduler_job.processor_agent = mock.MagicMock(spec=DagFileProcessorAgent)
 
         num_queued = self.scheduler_job._do_scheduling(session)
-
         assert num_queued == 1
-        ti = run1.get_task_instance(task1.task_id, session)
+
+        session.flush()
+        ti = run1.task_instances[0]
+        ti.refresh_from_db(session=session)
         assert ti.state == State.QUEUED
 
     def test_do_schedule_max_active_runs_and_manual_trigger(self, dag_maker):
@@ -2858,7 +2629,7 @@ class TestSchedulerJob:
         num_queued = self.scheduler_job._do_scheduling(session)
         # Add it back in to the session so we can refresh it. (_do_scheduling does an expunge_all to reduce
         # memory)
-        session.add(dag_run)
+        dag_run = session.merge(dag_run)
         session.refresh(dag_run)
 
         assert num_queued == 2
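
Two details recur in the scheduler test changes above: TaskInstanceKey is now built from a run_id instead of an execution_date, and runs created by backfill show up with a 'backfill__<execution_date>' run_id. A hedged sketch of both, reusing names from this diff (dag_id, task_id, try_number and DEFAULT_DATE are placeholders):

    from airflow.models.taskinstance import TaskInstanceKey

    # executor event-buffer keys are now keyed by run_id
    key = TaskInstanceKey(dag_id, task_id, dr.run_id, try_number)

    # backfill-created runs appear with a run_id of this shape in the assertions above
    backfill_key = TaskInstanceKey(dag.dag_id, 'dummy', f'backfill__{DEFAULT_DATE.isoformat()}', 1)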
diff --git a/tests/jobs/test_triggerer_job.py b/tests/jobs/test_triggerer_job.py
index 6718450..5adc91f 100644
--- a/tests/jobs/test_triggerer_job.py
+++ b/tests/jobs/test_triggerer_job.py
@@ -22,10 +22,8 @@ import time
 
 import pytest
 
-from airflow import DAG
 from airflow.jobs.triggerer_job import TriggererJob
 from airflow.models import Trigger
-from airflow.models.taskinstance import TaskInstance
 from airflow.operators.dummy import DummyOperator
 from airflow.triggers.base import TriggerEvent
 from airflow.triggers.temporal import TimeDeltaTrigger
@@ -293,7 +291,7 @@ def test_trigger_cleanup(session):
 
 
 @pytest.mark.skipif(sys.version_info.minor <= 6 and sys.version_info.major <= 3, reason="No triggerer on 3.6")
-def test_invalid_trigger(session):
+def test_invalid_trigger(session, dag_maker):
     """
     Checks that the triggerer will correctly fail task instances that depend on
     triggers that can't even be loaded.
@@ -305,22 +303,14 @@ def test_invalid_trigger(session):
     session.commit()
 
     # Create the test DAG and task
-    with DAG(
-        dag_id='test_invalid_trigger',
-        start_date=timezone.datetime(2016, 1, 1),
-        schedule_interval='@once',
-        max_active_runs=1,
-    ):
-        task1 = DummyOperator(task_id='dummy1')
+    with dag_maker(dag_id='test_invalid_trigger', session=session):
+        DummyOperator(task_id='dummy1')
 
+    dr = dag_maker.create_dagrun()
+    task_instance = dr.task_instances[0]
     # Make a task instance based on that and tie it to the trigger
-    task_instance = TaskInstance(
-        task1,
-        execution_date=timezone.datetime(2016, 1, 1),
-        state=TaskInstanceState.DEFERRED,
-    )
+    task_instance.state = TaskInstanceState.DEFERRED
     task_instance.trigger_id = 1
-    session.add(task_instance)
     session.commit()
 
     # Make a TriggererJob and have it retrieve DB tasks
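
In the triggerer test above, the deferred task instance is no longer constructed by hand; the test takes the instance that dag_maker's DagRun already persisted and flips it into the deferred state. A small sketch of that step, assuming the same dag_maker and session fixtures (trigger_id=1 refers to the invalid trigger row created earlier in the test):

    from airflow.utils.state import TaskInstanceState

    dr = dag_maker.create_dagrun()
    task_instance = dr.task_instances[0]
    task_instance.state = TaskInstanceState.DEFERRED   # waiting on a trigger
    task_instance.trigger_id = 1                       # tie it to the broken trigger
    session.commit()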
diff --git a/tests/lineage/test_lineage.py b/tests/lineage/test_lineage.py
index b5ebbea..171fad1 100644
--- a/tests/lineage/test_lineage.py
+++ b/tests/lineage/test_lineage.py
@@ -15,15 +15,15 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-import unittest
 from unittest import mock
 
 from airflow.lineage import AUTO, apply_lineage, get_backend, prepare_lineage
 from airflow.lineage.backend import LineageBackend
 from airflow.lineage.entities import File
-from airflow.models import DAG, TaskInstance as TI
+from airflow.models import TaskInstance as TI
 from airflow.operators.dummy import DummyOperator
 from airflow.utils import timezone
+from airflow.utils.types import DagRunType
 from tests.test_utils.config import conf_vars
 
 DEFAULT_DATE = timezone.datetime(2016, 1, 1)
@@ -34,10 +34,8 @@ class CustomLineageBackend(LineageBackend):
         pass
 
 
-class TestLineage(unittest.TestCase):
-    def test_lineage(self):
-        dag = DAG(dag_id='test_prepare_lineage', start_date=DEFAULT_DATE)
-
+class TestLineage:
+    def test_lineage(self, dag_maker):
         f1s = "/tmp/does_not_exist_1-{}"
         f2s = "/tmp/does_not_exist_2-{}"
         f3s = "/tmp/does_not_exist_3"
@@ -45,7 +43,7 @@ class TestLineage(unittest.TestCase):
         file2 = File(f2s.format("{{ execution_date }}"))
         file3 = File(f3s)
 
-        with dag:
+        with dag_maker(dag_id='test_prepare_lineage', start_date=DEFAULT_DATE) as dag:
             op1 = DummyOperator(
                 task_id='leave1',
                 inlets=file1,
@@ -64,12 +62,13 @@ class TestLineage(unittest.TestCase):
             op4.set_downstream(op5)
 
         dag.clear()
+        dag_run = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED)
 
         # execution_date is set in the context in order to avoid creating task instances
-        ctx1 = {"ti": TI(task=op1, execution_date=DEFAULT_DATE), "execution_date": DEFAULT_DATE}
-        ctx2 = {"ti": TI(task=op2, execution_date=DEFAULT_DATE), "execution_date": DEFAULT_DATE}
-        ctx3 = {"ti": TI(task=op3, execution_date=DEFAULT_DATE), "execution_date": DEFAULT_DATE}
-        ctx5 = {"ti": TI(task=op5, execution_date=DEFAULT_DATE), "execution_date": DEFAULT_DATE}
+        ctx1 = {"ti": TI(task=op1, run_id=dag_run.run_id), "execution_date": DEFAULT_DATE}
+        ctx2 = {"ti": TI(task=op2, run_id=dag_run.run_id), "execution_date": DEFAULT_DATE}
+        ctx3 = {"ti": TI(task=op3, run_id=dag_run.run_id), "execution_date": DEFAULT_DATE}
+        ctx5 = {"ti": TI(task=op5, run_id=dag_run.run_id), "execution_date": DEFAULT_DATE}
 
         # prepare with manual inlets and outlets
         op1.pre_execute(ctx1)
@@ -99,13 +98,12 @@ class TestLineage(unittest.TestCase):
         assert len(op5.inlets) == 2
         op5.post_execute(ctx5)
 
-    def test_lineage_render(self):
+    def test_lineage_render(self, dag_maker):
         # tests inlets / outlets are rendered if they are added
         # after initialization
-        dag = DAG(dag_id='test_lineage_render', start_date=DEFAULT_DATE)
-
-        with dag:
+        with dag_maker(dag_id='test_lineage_render', start_date=DEFAULT_DATE):
             op1 = DummyOperator(task_id='task1')
+        dag_run = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED)
 
         f1s = "/tmp/does_not_exist_1-{}"
         file1 = File(f1s.format("{{ execution_date }}"))
@@ -114,14 +112,14 @@ class TestLineage(unittest.TestCase):
         op1.outlets.append(file1)
 
         # execution_date is set in the context in order to avoid creating task instances
-        ctx1 = {"ti": TI(task=op1, execution_date=DEFAULT_DATE), "execution_date": DEFAULT_DATE}
+        ctx1 = {"ti": TI(task=op1, run_id=dag_run.run_id), "execution_date": DEFAULT_DATE}
 
         op1.pre_execute(ctx1)
         assert op1.inlets[0].url == f1s.format(DEFAULT_DATE)
         assert op1.outlets[0].url == f1s.format(DEFAULT_DATE)
 
     @mock.patch("airflow.lineage.get_backend")
-    def test_lineage_is_sent_to_backend(self, mock_get_backend):
+    def test_lineage_is_sent_to_backend(self, mock_get_backend, dag_maker):
         class TestBackend(LineageBackend):
             def send_lineage(self, operator, inlets=None, outlets=None, context=None):
                 assert len(inlets) == 1
@@ -132,17 +130,17 @@ class TestLineage(unittest.TestCase):
 
         mock_get_backend.return_value = TestBackend()
 
-        dag = DAG(dag_id='test_lineage_is_sent_to_backend', start_date=DEFAULT_DATE)
-
-        with dag:
+        with dag_maker(dag_id='test_lineage_is_sent_to_backend', start_date=DEFAULT_DATE):
             op1 = DummyOperator(task_id='task1')
+        dag_run = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED)
 
         file1 = File("/tmp/some_file")
 
         op1.inlets.append(file1)
         op1.outlets.append(file1)
 
-        ctx1 = {"ti": TI(task=op1, execution_date=DEFAULT_DATE), "execution_date": DEFAULT_DATE}
+        (ti,) = dag_run.task_instances
+        ctx1 = {"ti": ti, "execution_date": DEFAULT_DATE}
 
         prep = prepare_lineage(func)
         prep(op1, ctx1)
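
The lineage tests keep passing execution_date in the template context, but the TaskInstance placed in that context is now keyed by run_id and therefore needs a real DagRun behind it. A minimal sketch of building such a context, assuming the dag_maker fixture and the op1/DEFAULT_DATE names used above (TI is the TaskInstance alias imported in this file):

    from airflow.utils.types import DagRunType

    dag_run = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED)
    ctx = {"ti": TI(task=op1, run_id=dag_run.run_id), "execution_date": DEFAULT_DATE}
    op1.pre_execute(ctx)   # inlets/outlets are prepared against this run-backed TI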
diff --git a/tests/models/test_baseoperator.py b/tests/models/test_baseoperator.py
index bf3ea4a..a9662a9 100644
--- a/tests/models/test_baseoperator.py
+++ b/tests/models/test_baseoperator.py
@@ -565,6 +565,16 @@ class TestBaseOperatorMethods(unittest.TestCase):
         op_copy.post_execute({})
         assert called
 
+    def test_task_naive_datetime(self):
+        naive_datetime = DEFAULT_DATE.replace(tzinfo=None)
+
+        op_no_dag = DummyOperator(
+            task_id='test_task_naive_datetime', start_date=naive_datetime, end_date=naive_datetime
+        )
+
+        assert op_no_dag.start_date.tzinfo
+        assert op_no_dag.end_date.tzinfo
+
 
 class CustomOp(DummyOperator):
     template_fields = ("field", "field2")
diff --git a/tests/models/test_cleartasks.py b/tests/models/test_cleartasks.py
index 4f64347..e883bcc 100644
--- a/tests/models/test_cleartasks.py
+++ b/tests/models/test_cleartasks.py
@@ -17,44 +17,45 @@
 # under the License.
 
 import datetime
-import unittest
 
-from parameterized import parameterized
+import pytest
 
 from airflow import settings
 from airflow.models import DAG, TaskInstance as TI, TaskReschedule, clear_task_instances
 from airflow.operators.dummy import DummyOperator
 from airflow.sensors.python import PythonSensor
 from airflow.utils.session import create_session
-from airflow.utils.state import State
+from airflow.utils.state import State, TaskInstanceState
 from airflow.utils.types import DagRunType
 from tests.models import DEFAULT_DATE
 from tests.test_utils import db
 
 
-class TestClearTasks(unittest.TestCase):
-    def setUp(self) -> None:
+class TestClearTasks:
+    @pytest.fixture(autouse=True, scope="class")
+    def clean(self):
         db.clear_db_runs()
 
-    def tearDown(self):
+        yield
+
         db.clear_db_runs()
 
-    def test_clear_task_instances(self):
-        dag = DAG(
+    def test_clear_task_instances(self, dag_maker):
+        with dag_maker(
             'test_clear_task_instances',
             start_date=DEFAULT_DATE,
             end_date=DEFAULT_DATE + datetime.timedelta(days=10),
-        )
-        task0 = DummyOperator(task_id='0', owner='test', dag=dag)
-        task1 = DummyOperator(task_id='1', owner='test', dag=dag, retries=2)
-        ti0 = TI(task=task0, execution_date=DEFAULT_DATE)
-        ti1 = TI(task=task1, execution_date=DEFAULT_DATE)
+        ) as dag:
+            task0 = DummyOperator(task_id='0')
+            task1 = DummyOperator(task_id='1', retries=2)
 
-        dag.create_dagrun(
-            execution_date=ti0.execution_date,
+        dr = dag_maker.create_dagrun(
             state=State.RUNNING,
             run_type=DagRunType.SCHEDULED,
         )
+        ti0, ti1 = dr.task_instances
+        ti0.refresh_from_task(task0)
+        ti1.refresh_from_task(task1)
 
         ti0.run()
         ti1.run()
@@ -66,19 +67,22 @@ class TestClearTasks(unittest.TestCase):
         ti0.refresh_from_db()
         ti1.refresh_from_db()
         # Next try to run will be try 2
+        assert ti0.state is None
         assert ti0.try_number == 2
         assert ti0.max_tries == 1
+        assert ti1.state is None
         assert ti1.try_number == 2
         assert ti1.max_tries == 3
 
-    def test_clear_task_instances_external_executor_id(self):
-        dag = DAG(
+    def test_clear_task_instances_external_executor_id(self, dag_maker):
+        with dag_maker(
             'test_clear_task_instances_external_executor_id',
             start_date=DEFAULT_DATE,
             end_date=DEFAULT_DATE + datetime.timedelta(days=10),
-        )
-        task0 = DummyOperator(task_id='task0', owner='test', dag=dag)
-        ti0 = TI(task=task0, execution_date=DEFAULT_DATE)
+        ) as dag:
+            DummyOperator(task_id='task0')
+
+        ti0 = dag_maker.create_dagrun().task_instances[0]
         ti0.state = State.SUCCESS
         ti0.external_executor_id = "some_external_executor_id"
 
@@ -94,58 +98,60 @@ class TestClearTasks(unittest.TestCase):
             assert ti0.state is None
             assert ti0.external_executor_id is None
 
-    @parameterized.expand([(State.QUEUED, None), (State.RUNNING, DEFAULT_DATE)])
-    def test_clear_task_instances_dr_state(self, state, last_scheduling):
+    @pytest.mark.parametrize(
+        ["state", "last_scheduling"], [(State.QUEUED, None), (State.RUNNING, DEFAULT_DATE)]
+    )
+    def test_clear_task_instances_dr_state(self, state, last_scheduling, dag_maker):
         """Test that DR state is set to None after clear.
         And that DR.last_scheduling_decision is handled OK.
         start_date is also set to None
         """
-        dag = DAG(
+        with dag_maker(
             'test_clear_task_instances',
             start_date=DEFAULT_DATE,
             end_date=DEFAULT_DATE + datetime.timedelta(days=10),
-        )
-        task0 = DummyOperator(task_id='0', owner='test', dag=dag)
-        task1 = DummyOperator(task_id='1', owner='test', dag=dag, retries=2)
-        ti0 = TI(task=task0, execution_date=DEFAULT_DATE)
-        ti1 = TI(task=task1, execution_date=DEFAULT_DATE)
-        session = settings.Session()
-        dr = dag.create_dagrun(
-            execution_date=ti0.execution_date,
+        ) as dag:
+            DummyOperator(task_id='0')
+            DummyOperator(task_id='1', retries=2)
+        dr = dag_maker.create_dagrun(
             state=State.RUNNING,
             run_type=DagRunType.SCHEDULED,
         )
+        ti0, ti1 = dr.task_instances
         dr.last_scheduling_decision = DEFAULT_DATE
-        session.add(dr)
-        session.commit()
+        ti0.state = TaskInstanceState.SUCCESS
+        ti1.state = TaskInstanceState.SUCCESS
+        session = dag_maker.session
+        session.flush()
 
-        ti0.run()
-        ti1.run()
         qry = session.query(TI).filter(TI.dag_id == dag.dag_id).all()
         clear_task_instances(qry, session, dag_run_state=state, dag=dag)
+        session.flush()
+
+        session.refresh(dr)
 
-        dr = ti0.get_dagrun()
         assert dr.state == state
         assert dr.start_date is None
         assert dr.last_scheduling_decision == last_scheduling
 
-    def test_clear_task_instances_without_task(self):
-        dag = DAG(
+    def test_clear_task_instances_without_task(self, dag_maker):
+        with dag_maker(
             'test_clear_task_instances_without_task',
             start_date=DEFAULT_DATE,
             end_date=DEFAULT_DATE + datetime.timedelta(days=10),
-        )
-        task0 = DummyOperator(task_id='task0', owner='test', dag=dag)
-        task1 = DummyOperator(task_id='task1', owner='test', dag=dag, retries=2)
-        ti0 = TI(task=task0, execution_date=DEFAULT_DATE)
-        ti1 = TI(task=task1, execution_date=DEFAULT_DATE)
+        ) as dag:
+            task0 = DummyOperator(task_id='task0')
+            task1 = DummyOperator(task_id='task1', retries=2)
 
-        dag.create_dagrun(
-            execution_date=ti0.execution_date,
+        dr = dag_maker.create_dagrun(
             state=State.RUNNING,
             run_type=DagRunType.SCHEDULED,
         )
 
+        ti0, ti1 = dr.task_instances
+        ti0.refresh_from_task(task0)
+        ti1.refresh_from_task(task1)
+
         ti0.run()
         ti1.run()
 
@@ -167,23 +173,24 @@ class TestClearTasks(unittest.TestCase):
         assert ti1.try_number == 2
         assert ti1.max_tries == 2
 
-    def test_clear_task_instances_without_dag(self):
-        dag = DAG(
+    def test_clear_task_instances_without_dag(self, dag_maker):
+        with dag_maker(
             'test_clear_task_instances_without_dag',
             start_date=DEFAULT_DATE,
             end_date=DEFAULT_DATE + datetime.timedelta(days=10),
-        )
-        task0 = DummyOperator(task_id='task_0', owner='test', dag=dag)
-        task1 = DummyOperator(task_id='task_1', owner='test', dag=dag, retries=2)
-        ti0 = TI(task=task0, execution_date=DEFAULT_DATE)
-        ti1 = TI(task=task1, execution_date=DEFAULT_DATE)
+        ) as dag:
+            task0 = DummyOperator(task_id='task0')
+            task1 = DummyOperator(task_id='task1', retries=2)
 
-        dag.create_dagrun(
-            execution_date=ti0.execution_date,
+        dr = dag_maker.create_dagrun(
             state=State.RUNNING,
             run_type=DagRunType.SCHEDULED,
         )
 
+        ti0, ti1 = dr.task_instances
+        ti0.refresh_from_task(task0)
+        ti1.refresh_from_task(task1)
+
         ti0.run()
         ti1.run()
 
@@ -200,10 +207,10 @@ class TestClearTasks(unittest.TestCase):
         assert ti1.try_number == 2
         assert ti1.max_tries == 2
 
-    def test_clear_task_instances_with_task_reschedule(self):
+    def test_clear_task_instances_with_task_reschedule(self, dag_maker):
         """Test that TaskReschedules are deleted correctly when TaskInstances are cleared"""
 
-        with DAG(
+        with dag_maker(
             'test_clear_task_instances_with_task_reschedule',
             start_date=DEFAULT_DATE,
             end_date=DEFAULT_DATE + datetime.timedelta(days=10),
@@ -211,15 +218,14 @@ class TestClearTasks(unittest.TestCase):
             task0 = PythonSensor(task_id='0', python_callable=lambda: False, mode="reschedule")
             task1 = PythonSensor(task_id='1', python_callable=lambda: False, mode="reschedule")
 
-        ti0 = TI(task=task0, execution_date=DEFAULT_DATE)
-        ti1 = TI(task=task1, execution_date=DEFAULT_DATE)
-
-        dag.create_dagrun(
-            execution_date=ti0.execution_date,
+        dr = dag_maker.create_dagrun(
             state=State.RUNNING,
             run_type=DagRunType.SCHEDULED,
         )
 
+        ti0, ti1 = dr.task_instances
+        ti0.refresh_from_task(task0)
+        ti1.refresh_from_task(task1)
         ti0.run()
         ti1.run()
 
@@ -231,7 +237,7 @@ class TestClearTasks(unittest.TestCase):
                     .filter(
                         TaskReschedule.dag_id == dag.dag_id,
                         TaskReschedule.task_id == task_id,
-                        TaskReschedule.execution_date == DEFAULT_DATE,
+                        TaskReschedule.run_id == dr.run_id,
                         TaskReschedule.try_number == 1,
                     )
                     .count()
@@ -244,22 +250,27 @@ class TestClearTasks(unittest.TestCase):
             assert count_task_reschedule(ti0.task_id) == 0
             assert count_task_reschedule(ti1.task_id) == 1
 
-    def test_dag_clear(self):
-        dag = DAG(
+    def test_dag_clear(self, dag_maker):
+        with dag_maker(
             'test_dag_clear', start_date=DEFAULT_DATE, end_date=DEFAULT_DATE + datetime.timedelta(days=10)
-        )
-        task0 = DummyOperator(task_id='test_dag_clear_task_0', owner='test', dag=dag)
-        ti0 = TI(task=task0, execution_date=DEFAULT_DATE)
+        ) as dag:
+            task0 = DummyOperator(task_id='test_dag_clear_task_0')
+            task1 = DummyOperator(task_id='test_dag_clear_task_1', retries=2)
 
-        dag.create_dagrun(
-            execution_date=ti0.execution_date,
+        dr = dag_maker.create_dagrun(
             state=State.RUNNING,
             run_type=DagRunType.SCHEDULED,
         )
+        session = dag_maker.session
+
+        ti0, ti1 = dr.task_instances
+        ti0.refresh_from_task(task0)
+        ti1.refresh_from_task(task1)
 
         # Next try to run will be try 1
         assert ti0.try_number == 1
         ti0.run()
+
         assert ti0.try_number == 2
         dag.clear()
         ti0.refresh_from_db()
@@ -267,12 +278,14 @@ class TestClearTasks(unittest.TestCase):
         assert ti0.state == State.NONE
         assert ti0.max_tries == 1
 
-        task1 = DummyOperator(task_id='test_dag_clear_task_1', owner='test', dag=dag, retries=2)
-        ti1 = TI(task=task1, execution_date=DEFAULT_DATE)
         assert ti1.max_tries == 2
         ti1.try_number = 1
+        session.merge(ti1)
+        session.commit()
+
         # Next try will be 2
         ti1.run()
+
         assert ti1.try_number == 3
         assert ti1.max_tries == 2
 
@@ -297,16 +310,16 @@ class TestClearTasks(unittest.TestCase):
                 start_date=DEFAULT_DATE,
                 end_date=DEFAULT_DATE + datetime.timedelta(days=10),
             )
-            ti = TI(
-                task=DummyOperator(task_id='test_task_clear_' + str(i), owner='test', dag=dag),
-                execution_date=DEFAULT_DATE,
-            )
+            task = DummyOperator(task_id='test_task_clear_' + str(i), owner='test', dag=dag)
 
-            dag.create_dagrun(
-                execution_date=ti.execution_date,
+            dr = dag.create_dagrun(
+                execution_date=DEFAULT_DATE,
                 state=State.RUNNING,
                 run_type=DagRunType.SCHEDULED,
+                session=session,
             )
+            ti = dr.task_instances[0]
+            ti.task = task
             dags.append(dag)
             tis.append(ti)
 
@@ -361,26 +374,25 @@ class TestClearTasks(unittest.TestCase):
                 assert tis[i].try_number == 3
                 assert tis[i].max_tries == 2
 
-    def test_operator_clear(self):
-        dag = DAG(
+    def test_operator_clear(self, dag_maker):
+        with dag_maker(
             'test_operator_clear',
             start_date=DEFAULT_DATE,
             end_date=DEFAULT_DATE + datetime.timedelta(days=10),
-        )
-        op1 = DummyOperator(task_id='bash_op', owner='test', dag=dag)
-        op2 = DummyOperator(task_id='dummy_op', owner='test', dag=dag, retries=1)
-
-        op2.set_upstream(op1)
+        ):
+            op1 = DummyOperator(task_id='bash_op')
+            op2 = DummyOperator(task_id='dummy_op', retries=1)
+            op1 >> op2
 
-        ti1 = TI(task=op1, execution_date=DEFAULT_DATE)
-        ti2 = TI(task=op2, execution_date=DEFAULT_DATE)
-
-        dag.create_dagrun(
-            execution_date=ti1.execution_date,
+        dr = dag_maker.create_dagrun(
             state=State.RUNNING,
             run_type=DagRunType.SCHEDULED,
         )
 
+        ti1, ti2 = dr.task_instances
+        ti1.task = op1
+        ti2.task = op2
+
         ti2.run()
         # Dependency not met
         assert ti2.try_number == 1
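
The hunks above converge on one pattern: a task instance is no longer constructed directly from an
execution_date; it is created together with its DagRun and fetched from it. A minimal sketch of that
pattern, assuming the pytest dag_maker fixture these tests already use (the dag id 'example' is
illustrative only):

    with dag_maker('example', start_date=DEFAULT_DATE):
        task = DummyOperator(task_id='task')

    dr = dag_maker.create_dagrun(state=State.RUNNING, run_type=DagRunType.SCHEDULED)
    ti = dr.task_instances[0]    # TIs are created alongside the DagRun and keyed by run_id
    ti.refresh_from_task(task)   # re-attach the task object before running
    ti.run()
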
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index 0ee0baa..c6d54ed 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -36,6 +36,7 @@ import pytest
 from dateutil.relativedelta import relativedelta
 from freezegun import freeze_time
 from parameterized import parameterized
+from sqlalchemy import inspect
 
 from airflow import models, settings
 from airflow.configuration import conf
@@ -66,6 +67,13 @@ from tests.test_utils.timetables import cron_timetable, delta_timetable
 TEST_DATE = datetime_tz(2015, 1, 2, 0, 0)
 
 
+@pytest.fixture
+def session():
+    with create_session() as session:
+        yield session
+        session.rollback()
+
+
 class TestDag(unittest.TestCase):
     def setUp(self) -> None:
         clear_db_runs()
@@ -448,13 +456,24 @@ class TestDag(unittest.TestCase):
         test_dag = DAG(dag_id=test_dag_id, start_date=DEFAULT_DATE)
         test_task = DummyOperator(task_id=test_task_id, dag=test_dag)
 
-        ti1 = TI(task=test_task, execution_date=DEFAULT_DATE)
+        dr1 = test_dag.create_dagrun(state=None, run_id="test1", execution_date=DEFAULT_DATE)
+        dr2 = test_dag.create_dagrun(
+            state=None, run_id="test2", execution_date=DEFAULT_DATE + datetime.timedelta(days=1)
+        )
+        dr3 = test_dag.create_dagrun(
+            state=None, run_id="test3", execution_date=DEFAULT_DATE + datetime.timedelta(days=2)
+        )
+        dr4 = test_dag.create_dagrun(
+            state=None, run_id="test4", execution_date=DEFAULT_DATE + datetime.timedelta(days=3)
+        )
+
+        ti1 = TI(task=test_task, run_id=dr1.run_id)
         ti1.state = None
-        ti2 = TI(task=test_task, execution_date=DEFAULT_DATE + datetime.timedelta(days=1))
+        ti2 = TI(task=test_task, run_id=dr2.run_id)
         ti2.state = State.RUNNING
-        ti3 = TI(task=test_task, execution_date=DEFAULT_DATE + datetime.timedelta(days=2))
+        ti3 = TI(task=test_task, run_id=dr3.run_id)
         ti3.state = State.QUEUED
-        ti4 = TI(task=test_task, execution_date=DEFAULT_DATE + datetime.timedelta(days=3))
+        ti4 = TI(task=test_task, run_id=dr4.run_id)
         ti4.state = State.RUNNING
         session = settings.Session()
         session.merge(ti1)
@@ -1094,10 +1113,12 @@ class TestDag(unittest.TestCase):
         when = TEST_DATE
         dag.add_task(BaseOperator(task_id="faketastic", owner='Also fake', start_date=when))
 
-        dag_run = dag.create_dagrun(State.RUNNING, when, run_type=DagRunType.MANUAL)
-        # should not raise any exception
-        dag.handle_callback(dag_run, success=False)
-        dag.handle_callback(dag_run, success=True)
+        with create_session() as session:
+            dag_run = dag.create_dagrun(State.RUNNING, when, run_type=DagRunType.MANUAL, session=session)
+
+            # should not raise any exception
+            dag.handle_callback(dag_run, success=False)
+            dag.handle_callback(dag_run, success=True)
 
         mock_stats.incr.assert_called_with("dag.callback_exceptions")
 
@@ -1970,11 +1991,11 @@ class TestDagDecorator(unittest.TestCase):
         assert dag.params['value'] == self.VALUE
 
 
-def test_set_task_instance_state():
+def test_set_task_instance_state(session, dag_maker):
     """Test that set_task_instance_state updates the TaskInstance state and clear downstream failed"""
 
     start_date = datetime_tz(2020, 1, 1)
-    with DAG("test_set_task_instance_state", start_date=start_date) as dag:
+    with dag_maker("test_set_task_instance_state", start_date=start_date, session=session) as dag:
         task_1 = DummyOperator(task_id="task_1")
         task_2 = DummyOperator(task_id="task_2")
         task_3 = DummyOperator(task_id="task_3")
@@ -1983,49 +2004,48 @@ def test_set_task_instance_state():
 
         task_1 >> [task_2, task_3, task_4, task_5]
 
-    dagrun = dag.create_dagrun(
-        start_date=start_date, execution_date=start_date, state=State.FAILED, run_type=DagRunType.SCHEDULED
-    )
+    dagrun = dag_maker.create_dagrun(state=State.FAILED, run_type=DagRunType.SCHEDULED)
 
-    def get_task_instance(session, task):
+    def get_ti_from_db(task):
         return (
             session.query(TI)
             .filter(
                 TI.dag_id == dag.dag_id,
                 TI.task_id == task.task_id,
-                TI.execution_date == start_date,
+                TI.run_id == dagrun.run_id,
             )
             .one()
         )
 
-    with create_session() as session:
-        get_task_instance(session, task_1).state = State.FAILED
-        get_task_instance(session, task_2).state = State.SUCCESS
-        get_task_instance(session, task_3).state = State.UPSTREAM_FAILED
-        get_task_instance(session, task_4).state = State.FAILED
-        get_task_instance(session, task_5).state = State.SKIPPED
+    get_ti_from_db(task_1).state = State.FAILED
+    get_ti_from_db(task_2).state = State.SUCCESS
+    get_ti_from_db(task_3).state = State.UPSTREAM_FAILED
+    get_ti_from_db(task_4).state = State.FAILED
+    get_ti_from_db(task_5).state = State.SKIPPED
 
-        session.commit()
+    session.flush()
 
     altered = dag.set_task_instance_state(
-        task_id=task_1.task_id, execution_date=start_date, state=State.SUCCESS
+        task_id=task_1.task_id, execution_date=start_date, state=State.SUCCESS, session=session
     )
 
-    with create_session() as session:
-        # After _mark_task_instance_state, task_1 is marked as SUCCESS
-        assert get_task_instance(session, task_1).state == State.SUCCESS
-        # task_2 remains as SUCCESS
-        assert get_task_instance(session, task_2).state == State.SUCCESS
-        # task_3 and task_4 are cleared because they were in FAILED/UPSTREAM_FAILED state
-        assert get_task_instance(session, task_3).state == State.NONE
-        assert get_task_instance(session, task_4).state == State.NONE
-        # task_5 remains as SKIPPED
-        assert get_task_instance(session, task_5).state == State.SKIPPED
-        dagrun.refresh_from_db(session=session)
-        # dagrun should be set to QUEUED
-        assert dagrun.get_state() == State.QUEUED
-
-    assert {t.key for t in altered} == {('test_set_task_instance_state', 'task_1', start_date, 1)}
+    # After _mark_task_instance_state, task_1 is marked as SUCCESS
+    ti1 = get_ti_from_db(task_1)
+    assert ti1.state == State.SUCCESS
+    # TIs should have DagRun pre-loaded
+    assert isinstance(inspect(ti1).attrs.dag_run.loaded_value, DagRun)
+    # task_2 remains as SUCCESS
+    assert get_ti_from_db(task_2).state == State.SUCCESS
+    # task_3 and task_4 are cleared because they were in FAILED/UPSTREAM_FAILED state
+    assert get_ti_from_db(task_3).state == State.NONE
+    assert get_ti_from_db(task_4).state == State.NONE
+    # task_5 remains as SKIPPED
+    assert get_ti_from_db(task_5).state == State.SKIPPED
+    dagrun.refresh_from_db(session=session)
+    # dagrun should be set to QUEUED
+    assert dagrun.get_state() == State.QUEUED
+
+    assert {t.key for t in altered} == {('test_set_task_instance_state', 'task_1', dagrun.run_id, 1)}
 
 
 @pytest.mark.parametrize(
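
Under the new schema a task instance is identified by (dag_id, task_id, run_id) rather than by
execution_date, which is why both the TI constructions and the queries in the test above change shape.
A hedged sketch of the two forms, reusing names from these tests:

    dr = test_dag.create_dagrun(state=None, run_id="test1", execution_date=DEFAULT_DATE)
    ti = TI(task=test_task, run_id=dr.run_id)  # no execution_date argument any more

    ti_from_db = (
        session.query(TI)
        .filter(TI.dag_id == test_dag.dag_id, TI.task_id == test_task.task_id, TI.run_id == dr.run_id)
        .one()
    )
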
diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py
index fb6a099..c9b9bca 100644
--- a/tests/models/test_dagrun.py
+++ b/tests/models/test_dagrun.py
@@ -24,10 +24,8 @@ from unittest.mock import call
 from parameterized import parameterized
 
 from airflow import models, settings
-from airflow.jobs.base_job import BaseJob
 from airflow.models import DAG, DagBag, DagModel, TaskInstance as TI, clear_task_instances
 from airflow.models.dagrun import DagRun
-from airflow.operators.bash import BashOperator
 from airflow.operators.dummy import DummyOperator
 from airflow.operators.python import ShortCircuitOperator
 from airflow.serialization.serialized_objects import SerializedDAG
@@ -39,7 +37,7 @@ from airflow.utils.state import State
 from airflow.utils.trigger_rule import TriggerRule
 from airflow.utils.types import DagRunType
 from tests.models import DEFAULT_DATE
-from tests.test_utils.db import clear_db_dags, clear_db_jobs, clear_db_pools, clear_db_runs
+from tests.test_utils.db import clear_db_dags, clear_db_pools, clear_db_runs
 
 
 class TestDagRun(unittest.TestCase):
@@ -87,8 +85,7 @@ class TestDagRun(unittest.TestCase):
             for task_id, task_state in task_states.items():
                 ti = dag_run.get_task_instance(task_id)
                 ti.set_state(task_state, session)
-            session.commit()
-            session.close()
+            session.flush()
 
         return dag_run
 
@@ -377,7 +374,7 @@ class TestDagRun(unittest.TestCase):
         assert callback == DagCallbackRequest(
             full_filepath=dag_run.dag.fileloc,
             dag_id="test_dagrun_update_state_with_handle_callback_success",
-            execution_date=dag_run.execution_date,
+            run_id=dag_run.run_id,
             is_failure_callback=False,
             msg="success",
         )
@@ -412,7 +409,7 @@ class TestDagRun(unittest.TestCase):
         assert callback == DagCallbackRequest(
             full_filepath=dag_run.dag.fileloc,
             dag_id="test_dagrun_update_state_with_handle_callback_failure",
-            execution_date=dag_run.execution_date,
+            run_id=dag_run.run_id,
             is_failure_callback=True,
             msg="task_failure",
         )
@@ -825,40 +822,3 @@ class TestDagRun(unittest.TestCase):
         ti_failed = dag_run.get_task_instance(dag_task_failed.task_id)
         assert ti_success.state in State.success_states
         assert ti_failed.state in State.failed_states
-
-    def test_delete_dag_run_and_task_instance_does_not_raise_error(self):
-        clear_db_jobs()
-        clear_db_runs()
-
-        job_id = 22
-        dag = DAG(dag_id='test_delete_dag_run', start_date=days_ago(1))
-        _ = BashOperator(task_id='task1', dag=dag, bash_command="echo hi")
-
-        # Simulate DagRun is created by a job inherited by BaseJob with an id
-        # This is so that same foreign key exists on DagRun.creating_job_id & BaseJob.id
-        dag_run = self.create_dag_run(dag=dag, creating_job_id=job_id)
-        assert dag_run is not None
-
-        session = settings.Session()
-
-        job = BaseJob(id=job_id)
-        session.add(job)
-
-        # Simulate TaskInstance is created by a job inherited by BaseJob with an id
-        # This is so that same foreign key exists on TaskInstance.queued_by_job_id & BaseJob.id
-        ti1 = dag_run.get_task_instance(task_id="task1")
-        ti1.queued_by_job_id = job_id
-        session.merge(ti1)
-        session.commit()
-
-        # Test Deleting DagRun does not raise an error
-        session.delete(dag_run)
-
-        # Test Deleting TaskInstance does not raise an error
-        ti1 = dag_run.get_task_instance(task_id="task1")
-        session.delete(ti1)
-        session.commit()
-
-        # CleanUp
-        clear_db_runs()
-        clear_db_jobs()
diff --git a/tests/models/test_renderedtifields.py b/tests/models/test_renderedtifields.py
index e2093f2..47e4956 100644
--- a/tests/models/test_renderedtifields.py
+++ b/tests/models/test_renderedtifields.py
@@ -23,12 +23,12 @@ from datetime import date, timedelta
 from unittest import mock
 
 import pytest
+from sqlalchemy.orm.session import make_transient
 
 from airflow import settings
 from airflow.configuration import TEST_DAGS_FOLDER
 from airflow.models import Variable
 from airflow.models.renderedtifields import RenderedTaskInstanceFields as RTIF
-from airflow.models.taskinstance import TaskInstance as TI
 from airflow.operators.bash import BashOperator
 from airflow.utils.session import create_session
 from airflow.utils.timezone import datetime
@@ -115,27 +115,32 @@ class TestRenderedTaskInstanceFields:
         Test that template_fields are rendered correctly, stored in the Database,
         and are correctly fetched using RTIF.get_templated_fields
         """
-        with dag_maker("test_serialized_rendered_fields") as dag:
+        with dag_maker("test_serialized_rendered_fields"):
             task = BashOperator(task_id="test", bash_command=templated_field)
-        dag_maker.create_dagrun()
-        ti = TI(task=task, execution_date=EXECUTION_DATE)
+            task_2 = BashOperator(task_id="test2", bash_command=templated_field)
+        dr = dag_maker.create_dagrun()
+
+        session = dag_maker.session
+
+        ti, ti2 = dr.task_instances
+        ti.task = task
+        ti2.task = task_2
         rtif = RTIF(ti=ti)
+
         assert ti.dag_id == rtif.dag_id
         assert ti.task_id == rtif.task_id
         assert ti.execution_date == rtif.execution_date
         assert expected_rendered_field == rtif.rendered_fields.get("bash_command")
 
-        with create_session() as session:
-            session.add(rtif)
-
-        assert {"bash_command": expected_rendered_field, "env": None} == RTIF.get_templated_fields(ti=ti)
+        session.add(rtif)
+        session.flush()
 
+        assert {"bash_command": expected_rendered_field, "env": None} == RTIF.get_templated_fields(
+            ti=ti, session=session
+        )
         # Test the else part of get_templated_fields
         # i.e. for the TIs that are not stored in RTIF table
         # Fetching them will return None
-        task_2 = BashOperator(task_id="test2", bash_command=templated_field, dag=dag)
-
-        ti2 = TI(task_2, EXECUTION_DATE)
         assert RTIF.get_templated_fields(ti=ti2) is None
 
     @pytest.mark.parametrize(
@@ -159,14 +164,15 @@ class TestRenderedTaskInstanceFields:
         session = settings.Session()
         with dag_maker("test_delete_old_records") as dag:
             task = BashOperator(task_id="test", bash_command="echo {{ ds }}")
-        dag_maker.create_dagrun()
-        rtif_list = [
-            RTIF(TI(task=task, execution_date=EXECUTION_DATE + timedelta(days=num)))
-            for num in range(rtif_num)
-        ]
+        rtif_list = []
+        for num in range(rtif_num):
+            dr = dag_maker.create_dagrun(run_id=str(num), execution_date=dag.start_date + timedelta(days=num))
+            ti = dr.task_instances[0]
+            ti.task = task
+            rtif_list.append(RTIF(ti))
 
         session.add_all(rtif_list)
-        session.commit()
+        session.flush()
 
         result = session.query(RTIF).filter(RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id).all()
 
@@ -201,7 +207,11 @@ class TestRenderedTaskInstanceFields:
         with dag_maker("test_write"):
             task = BashOperator(task_id="test", bash_command="echo {{ var.value.test_key }}")
 
-        rtif = RTIF(TI(task=task, execution_date=EXECUTION_DATE))
+        dr = dag_maker.create_dagrun()
+        ti = dr.task_instances[0]
+        ti.task = task
+
+        rtif = RTIF(ti)
         rtif.write()
         result = (
             session.query(RTIF.dag_id, RTIF.task_id, RTIF.rendered_fields)
@@ -220,8 +230,10 @@ class TestRenderedTaskInstanceFields:
         self.clean_db()
         with dag_maker("test_write"):
             updated_task = BashOperator(task_id="test", bash_command="echo {{ var.value.test_key }}")
-        dag_maker.create_dagrun()
-        rtif_updated = RTIF(TI(task=updated_task, execution_date=EXECUTION_DATE))
+        dr = dag_maker.create_dagrun()
+        ti = dr.task_instances[0]
+        ti.task = updated_task
+        rtif_updated = RTIF(ti)
         rtif_updated.write()
 
         result_updated = (
@@ -248,10 +260,11 @@ class TestRenderedTaskInstanceFields:
         """
         with dag_maker("test_get_k8s_pod_yaml") as dag:
             task = BashOperator(task_id="test", bash_command="echo hi")
-        dag_maker.create_dagrun()
+        dr = dag_maker.create_dagrun()
         dag.fileloc = TEST_DAGS_FOLDER + '/test_get_k8s_pod_yaml.py'
 
-        ti = TI(task=task, execution_date=EXECUTION_DATE)
+        ti = dr.task_instances[0]
+        ti.task = task
 
         render_k8s_pod_yaml = mock.patch.object(
             ti, 'render_k8s_pod_yaml', return_value={"I'm a": "pod"}
@@ -274,6 +287,8 @@ class TestRenderedTaskInstanceFields:
             session.flush()
 
             assert expected_pod_yaml == RTIF.get_k8s_pod_yaml(ti=ti, session=session)
+            make_transient(ti)
+            # "Delete" it from the DB
             session.rollback()
 
             # Test the else part of get_k8s_pod_yaml
@@ -290,13 +305,14 @@ class TestRenderedTaskInstanceFields:
                 bash_command="echo {{ var.value.api_key }}",
                 env={'foo': 'secret', 'other_api_key': 'masked based on key name'},
             )
-        dag_maker.create_dagrun()
+        dr = dag_maker.create_dagrun()
         redact.side_effect = [
             'val 1',
             'val 2',
         ]
 
-        ti = TI(task=task, execution_date=EXECUTION_DATE)
+        ti = dr.task_instances[0]
+        ti.task = task
         rtif = RTIF(ti=ti)
         assert rtif.rendered_fields == {
             'bash_command': 'val 1',
diff --git a/tests/models/test_skipmixin.py b/tests/models/test_skipmixin.py
index 8734cc7..21ddf4d 100644
--- a/tests/models/test_skipmixin.py
+++ b/tests/models/test_skipmixin.py
@@ -74,9 +74,10 @@ class TestSkipMixin:
         mock_now.return_value = now
         with dag_maker(
             'dag',
+            session=session,
         ):
             tasks = [DummyOperator(task_id='task')]
-        dag_maker.create_dagrun()
+        dag_maker.create_dagrun(execution_date=now)
         SkipMixin().skip(dag_run=None, execution_date=now, tasks=tasks, session=session)
 
         session.query(TI).filter(
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index ded8901..a2f2ec5 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -29,7 +29,6 @@ from unittest.mock import call, mock_open, patch
 import pendulum
 import pytest
 from freezegun import freeze_time
-from sqlalchemy.orm.session import Session
 
 from airflow import models, settings
 from airflow.exceptions import (
@@ -47,6 +46,7 @@ from airflow.models import (
     TaskInstance as TI,
     TaskReschedule,
     Variable,
+    XCom,
 )
 from airflow.models.taskinstance import load_error_file, set_error_file
 from airflow.operators.bash import BashOperator
@@ -70,7 +70,17 @@ from tests.models import DEFAULT_DATE, TEST_DAGS_FOLDER
 from tests.test_utils import db
 from tests.test_utils.asserts import assert_queries_count
 from tests.test_utils.config import conf_vars
-from tests.test_utils.db import clear_db_connections
+from tests.test_utils.db import clear_db_connections, clear_db_runs
+
+
+@pytest.fixture
+def test_pool():
+    with create_session() as session:
+        test_pool = Pool(pool='test_pool', slots=1)
+        session.add(test_pool)
+        session.flush()
+        yield test_pool
+        session.rollback()
 
 
 class CallbackWrapper:
@@ -112,10 +122,10 @@ class TestTaskInstance:
 
     def setup_method(self):
         self.clean_db()
-        with create_session() as session:
-            test_pool = Pool(pool='test_pool', slots=1)
-            session.add(test_pool)
-            session.commit()
+
+        # We don't want to store any code for (test) dags created in this file
+        with patch.object(settings, "STORE_DAG_CODE", False):
+            yield
 
     def teardown_method(self):
         self.clean_db()
@@ -167,40 +177,6 @@ class TestTaskInstance:
         assert op3.start_date == DEFAULT_DATE + datetime.timedelta(days=1)
         assert op3.end_date == DEFAULT_DATE + datetime.timedelta(days=9)
 
-    def test_timezone_awareness(self, dag_maker):
-        naive_datetime = DEFAULT_DATE.replace(tzinfo=None)
-
-        # check ti without dag (just for bw compat)
-        op_no_dag = DummyOperator(task_id='op_no_dag')
-        ti = TI(task=op_no_dag, execution_date=naive_datetime)
-
-        assert ti.execution_date == DEFAULT_DATE
-
-        # check with dag without localized execution_date
-        with dag_maker('dag'):
-            op1 = DummyOperator(task_id='op_1')
-        dag_maker.create_dagrun()
-        ti = TI(task=op1, execution_date=naive_datetime)
-
-        assert ti.execution_date == DEFAULT_DATE
-
-        # with dag and localized execution_date
-        tzinfo = pendulum.timezone("Europe/Amsterdam")
-        execution_date = timezone.datetime(2016, 1, 1, 1, 0, 0, tzinfo=tzinfo)
-        utc_date = timezone.convert_to_utc(execution_date)
-        ti = TI(task=op1, execution_date=execution_date)
-        assert ti.execution_date == utc_date
-
-    def test_task_naive_datetime(self):
-        naive_datetime = DEFAULT_DATE.replace(tzinfo=None)
-
-        op_no_dag = DummyOperator(
-            task_id='test_task_naive_datetime', start_date=naive_datetime, end_date=naive_datetime
-        )
-
-        assert op_no_dag.start_date.tzinfo
-        assert op_no_dag.end_date.tzinfo
-
     def test_set_dag(self, dag_maker):
         """
         Test assigning Operators to Dags, including deferred assignment
@@ -271,60 +247,47 @@ class TestTaskInstance:
         assert op2 in op3.downstream_list
 
     @patch.object(DAG, 'get_concurrency_reached')
-    def test_requeue_over_dag_concurrency(self, mock_concurrency_reached, create_dummy_dag):
+    def test_requeue_over_dag_concurrency(self, mock_concurrency_reached, create_task_instance):
         mock_concurrency_reached.return_value = True
 
-        _, task = create_dummy_dag(
+        ti = create_task_instance(
             dag_id='test_requeue_over_dag_concurrency',
             task_id='test_requeue_over_dag_concurrency_op',
             max_active_runs=1,
             max_active_tasks=2,
+            dagrun_state=State.QUEUED,
         )
-
-        ti = TI(task=task, execution_date=timezone.utcnow(), state=State.QUEUED)
-        # TI.run() will sync from DB before validating deps.
-        with create_session() as session:
-            session.add(ti)
-            session.commit()
         ti.run()
         assert ti.state == State.NONE
 
-    def test_requeue_over_max_active_tis_per_dag(self, create_dummy_dag):
-        _, task = create_dummy_dag(
+    def test_requeue_over_max_active_tis_per_dag(self, create_task_instance):
+        ti = create_task_instance(
             dag_id='test_requeue_over_max_active_tis_per_dag',
             task_id='test_requeue_over_max_active_tis_per_dag_op',
             max_active_tis_per_dag=0,
             max_active_runs=1,
             max_active_tasks=2,
+            dagrun_state=State.QUEUED,
         )
 
-        ti = TI(task=task, execution_date=timezone.utcnow(), state=State.QUEUED)
-        # TI.run() will sync from DB before validating deps.
-        with create_session() as session:
-            session.add(ti)
-            session.commit()
         ti.run()
         assert ti.state == State.NONE
 
-    def test_requeue_over_pool_concurrency(self, create_dummy_dag):
-        _, task = create_dummy_dag(
+    def test_requeue_over_pool_concurrency(self, create_task_instance, test_pool):
+        ti = create_task_instance(
             dag_id='test_requeue_over_pool_concurrency',
             task_id='test_requeue_over_pool_concurrency_op',
             max_active_tis_per_dag=0,
             max_active_runs=1,
             max_active_tasks=2,
         )
-
-        ti = TI(task=task, execution_date=timezone.utcnow(), state=State.QUEUED)
-        # TI.run() will sync from DB before validating deps.
         with create_session() as session:
-            pool = session.query(Pool).filter(Pool.pool == 'test_pool').one()
-            pool.slots = 0
-            session.add(ti)
-            session.commit()
-        ti.run()
-        assert ti.state == State.NONE
+            test_pool.slots = 0
+            session.flush()
+            ti.run()
+            assert ti.state == State.NONE
 
+    @pytest.mark.usefixtures('test_pool')
     def test_not_requeue_non_requeueable_task_instance(self, dag_maker):
         # Use BaseSensorOperator because sensor got
         # one additional DEP in BaseSensorOperator().deps
@@ -333,8 +296,9 @@ class TestTaskInstance:
                 task_id='test_not_requeue_non_requeueable_task_instance_op',
                 pool='test_pool',
             )
-        dag_maker.create_dagrun()
-        ti = TI(task=task, execution_date=timezone.utcnow(), state=State.QUEUED)
+        ti = dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances[0]
+        ti.task = task
+        ti.state = State.QUEUED
         with create_session() as session:
             session.add(ti)
             session.commit()
@@ -358,69 +322,61 @@ class TestTaskInstance:
         for (dep_patch, method_patch) in patch_dict.values():
             dep_patch.stop()
 
-    def test_mark_non_runnable_task_as_success(self, create_dummy_dag):
+    def test_mark_non_runnable_task_as_success(self, create_task_instance):
         """
         test that running a task with the mark_success param updates the task state
         to SUCCESS without running the task, even though it fails dependency checks.
         """
         non_runnable_state = (set(State.task_states) - RUNNABLE_STATES - set(State.SUCCESS)).pop()
-        _, task = create_dummy_dag(
+        ti = create_task_instance(
             dag_id='test_mark_non_runnable_task_as_success',
             task_id='test_mark_non_runnable_task_as_success_op',
+            dagrun_state=non_runnable_state,
         )
-        ti = TI(task=task, execution_date=timezone.utcnow(), state=non_runnable_state)
-        # TI.run() will sync from DB before validating deps.
-        with create_session() as session:
-            session.add(ti)
-            session.commit()
         ti.run(mark_success=True)
         assert ti.state == State.SUCCESS
 
-    def test_run_pooling_task(self, create_dummy_dag):
+    @pytest.mark.usefixtures('test_pool')
+    def test_run_pooling_task(self, create_task_instance):
         """
         test that running a task in an existing pool updates the task state to SUCCESS.
         """
-        _, task = create_dummy_dag(
+        ti = create_task_instance(
             dag_id='test_run_pooling_task',
             task_id='test_run_pooling_task_op',
             pool='test_pool',
         )
-        ti = TI(task=task, execution_date=timezone.utcnow())
 
         ti.run()
 
-        db.clear_db_pools()
         assert ti.state == State.SUCCESS
 
+    @pytest.mark.usefixtures('test_pool')
     def test_pool_slots_property(self):
         """
         test that trying to create a task with pool_slots less than 1 raises an exception
         """
 
-        def create_task_instance():
+        with pytest.raises(AirflowException):
             dag = models.DAG(dag_id='test_run_pooling_task')
-            task = DummyOperator(
+            DummyOperator(
                 task_id='test_run_pooling_task_op',
                 dag=dag,
                 pool='test_pool',
                 pool_slots=0,
             )
-            return TI(task=task, execution_date=timezone.utcnow())
-
-        with pytest.raises(AirflowException):
-            create_task_instance()
 
     @provide_session
-    def test_ti_updates_with_task(self, create_dummy_dag, session=None):
+    def test_ti_updates_with_task(self, create_task_instance, session=None):
         """
         test that updating the executor_config propagates to the TaskInstance DB
         """
-        dag, task = create_dummy_dag(
+        ti = create_task_instance(
             dag_id='test_run_pooling_task',
             task_id='test_run_pooling_task_op',
             executor_config={'foo': 'bar'},
         )
-        ti = TI(task=task, execution_date=timezone.utcnow())
+        dag = ti.task.dag
 
         ti.run(session=session)
         tis = dag.get_task_instances()
@@ -432,24 +388,27 @@ class TestTaskInstance:
             dag=dag,
         )
 
-        ti = TI(task=task2, execution_date=timezone.utcnow())
+        ti2 = TI(task=task2, run_id=ti.run_id)
+        session.add(ti2)
+        session.flush()
 
-        ti.run(session=session)
-        tis = dag.get_task_instances()
-        assert {'bar': 'baz'} == tis[1].executor_config
+        ti2.run(session=session)
+        # Ensure it's reloaded
+        ti2.executor_config = None
+        ti2.refresh_from_db(session)
+        assert {'bar': 'baz'} == ti2.executor_config
         session.rollback()
 
-    def test_run_pooling_task_with_mark_success(self, create_dummy_dag):
+    def test_run_pooling_task_with_mark_success(self, create_task_instance):
         """
         test that running a task in an existing pool with the mark_success param
         updates the task state to SUCCESS without running the task,
         even though it fails dependency checks.
         """
-        _, task = create_dummy_dag(
+        ti = create_task_instance(
             dag_id='test_run_pooling_task_with_mark_success',
             task_id='test_run_pooling_task_with_mark_success_op',
         )
-        ti = TI(task=task, execution_date=timezone.utcnow())
 
         ti.run(mark_success=True)
         assert ti.state == State.SUCCESS
@@ -468,7 +427,10 @@ class TestTaskInstance:
                 task_id='test_run_pooling_task_with_skip',
                 python_callable=raise_skip_exception,
             )
-        ti = TI(task=task, execution_date=timezone.utcnow())
+
+        dr = dag_maker.create_dagrun(execution_date=timezone.utcnow())
+        ti = dr.task_instances[0]
+        ti.task = task
         ti.run()
         assert State.SKIPPED == ti.state
 
@@ -489,9 +451,9 @@ class TestTaskInstance:
                 retry_delay=datetime.timedelta(seconds=2),
             )
 
-        dag_maker.create_dagrun()
-        ti = TI(task=task, execution_date=DEFAULT_DATE)
-        ti.refresh_from_db()
+        dr = dag_maker.create_dagrun()
+        ti = dr.task_instances[0]
+        ti.task = task
         with pytest.raises(AirflowException):
             ti.run()
         ti.refresh_from_db()
@@ -508,7 +470,6 @@ class TestTaskInstance:
                 retries=1,
                 retry_delay=datetime.timedelta(seconds=3),
             )
-        dag_maker.create_dagrun()
 
         def run_with_error(ti):
             try:
@@ -516,7 +477,8 @@ class TestTaskInstance:
             except AirflowException:
                 pass
 
-        ti = TI(task=task, execution_date=timezone.utcnow())
+        ti = dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances[0]
+        ti.task = task
 
         assert ti.try_number == 1
         # first run -- up for retry
@@ -553,7 +515,8 @@ class TestTaskInstance:
             except AirflowException:
                 pass
 
-        ti = TI(task=task, execution_date=timezone.utcnow())
+        ti = dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances[0]
+        ti.task = task
         assert ti.try_number == 1
 
         # first run -- up for retry
@@ -599,8 +562,8 @@ class TestTaskInstance:
                 retry_exponential_backoff=True,
                 max_retry_delay=max_delay,
             )
-        dag_maker.create_dagrun()
-        ti = TI(task=task, execution_date=DEFAULT_DATE)
+        ti = dag_maker.create_dagrun().task_instances[0]
+        ti.task = task
         ti.end_date = pendulum.instance(timezone.utcnow())
 
         date = ti.next_retry_datetime()
@@ -641,8 +604,8 @@ class TestTaskInstance:
                 retry_exponential_backoff=True,
                 max_retry_delay=max_delay,
             )
-        dag_maker.create_dagrun()
-        ti = TI(task=task, execution_date=DEFAULT_DATE)
+        ti = dag_maker.create_dagrun().task_instances[0]
+        ti.task = task
         ti.end_date = pendulum.instance(timezone.utcnow())
 
         date = ti.next_retry_datetime()
@@ -673,12 +636,11 @@ class TestTaskInstance:
                 retry_delay=datetime.timedelta(seconds=0),
             )
 
-        ti = TI(task=task, execution_date=timezone.utcnow())
+        ti = dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances[0]
+        ti.task = task
         assert ti._try_number == 0
         assert ti.try_number == 1
 
-        dag_maker.create_dagrun()
-
         def run_ti_and_assert(
             run_date,
             expected_start_date,
@@ -749,6 +711,7 @@ class TestTaskInstance:
         done, fail = True, False
         run_ti_and_assert(date4, date3, date4, 60, State.SUCCESS, 3, 0)
 
+    @pytest.mark.usefixtures('test_pool')
     def test_reschedule_handling_clear_reschedules(self, dag_maker):
         """
         Test that clearing task reschedules is handled properly
@@ -772,8 +735,8 @@ class TestTaskInstance:
                 retry_delay=datetime.timedelta(seconds=0),
                 pool='test_pool',
             )
-        dag_maker.create_dagrun()
-        ti = TI(task=task, execution_date=timezone.utcnow())
+        ti = dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances[0]
+        ti.task = task
         assert ti._try_number == 0
         assert ti.try_number == 1
 
@@ -829,12 +792,13 @@ class TestTaskInstance:
 
         run_date = task.start_date + datetime.timedelta(days=5)
 
-        dag_maker.create_dagrun(
+        dr = dag_maker.create_dagrun(
             execution_date=run_date,
             run_type=DagRunType.SCHEDULED,
         )
 
-        ti = TI(task, run_date)
+        ti = dr.task_instances[0]
+        ti.task = task
 
         # depends_on_past prevents the run
         task.run(start_date=run_date, end_date=run_date, ignore_first_depends_on_past=False)
@@ -911,15 +875,17 @@ class TestTaskInstance:
         flag_upstream_failed: bool,
         expect_state: State,
         expect_completed: bool,
-        create_dummy_dag,
+        dag_maker,
     ):
-        dag, downstream = create_dummy_dag('test-dag', task_id='downstream', trigger_rule=trigger_rule)
-        for i in range(5):
-            task = DummyOperator(task_id=f'runme_{i}', dag=dag)
-            task.set_downstream(downstream)
+        with dag_maker() as dag:
+            downstream = DummyOperator(task_id="downstream", trigger_rule=trigger_rule)
+            for i in range(5):
+                task = DummyOperator(task_id=f'runme_{i}', dag=dag)
+                task.set_downstream(downstream)
         run_date = task.start_date + datetime.timedelta(days=5)
 
-        ti = TI(downstream, run_date)
+        ti = dag_maker.create_dagrun(execution_date=run_date).get_task_instance(downstream.task_id)
+        ti.task = downstream
         dep_results = TriggerRuleDep()._evaluate_trigger_rule(
             ti=ti,
             successes=successes,
@@ -934,9 +900,8 @@ class TestTaskInstance:
         assert completed == expect_completed
         assert ti.state == expect_state
 
-    def test_respects_prev_dagrun_dep(self, create_dummy_dag):
-        _, task = create_dummy_dag(dag_id='test_dag')
-        ti = TI(task, DEFAULT_DATE)
+    def test_respects_prev_dagrun_dep(self, create_task_instance):
+        ti = create_task_instance()
         failing_status = [TIDepStatus('test fail status name', False, 'test fail reason')]
         passing_status = [TIDepStatus('test pass status name', True, 'test passing reason')]
         with patch(
@@ -958,38 +923,42 @@ class TestTaskInstance:
             (State.NONE, False),
         ],
     )
-    def test_are_dependents_done(self, downstream_ti_state, expected_are_dependents_done, create_dummy_dag):
-        dag, task = create_dummy_dag()
+    @provide_session
+    def test_are_dependents_done(
+        self, downstream_ti_state, expected_are_dependents_done, create_task_instance, session=None
+    ):
+        ti = create_task_instance(session=session)
+        dag = ti.task.dag
         downstream_task = DummyOperator(task_id='downstream_task', dag=dag)
-        task >> downstream_task
+        ti.task >> downstream_task
 
-        ti = TI(task, DEFAULT_DATE)
-        downstream_ti = TI(downstream_task, DEFAULT_DATE)
+        downstream_ti = TI(downstream_task, run_id=ti.run_id)
 
-        downstream_ti.set_state(downstream_ti_state)
-        assert ti.are_dependents_done() == expected_are_dependents_done
+        downstream_ti.set_state(downstream_ti_state, session)
+        session.flush()
+        assert ti.are_dependents_done(session) == expected_are_dependents_done
 
-    def test_xcom_pull(self, create_dummy_dag):
+    def test_xcom_pull(self, create_task_instance):
         """
         Test xcom_pull, using different filtering methods.
         """
-        dag, task1 = create_dummy_dag(
+        ti1 = create_task_instance(
             dag_id='test_xcom',
             task_id='test_xcom_1',
-            schedule_interval='@monthly',
             start_date=timezone.datetime(2016, 6, 1, 0, 0, 0),
         )
 
-        exec_date = DEFAULT_DATE
-
         # Push a value
-        ti1 = TI(task=task1, execution_date=exec_date)
         ti1.xcom_push(key='foo', value='bar')
 
         # Push another value with the same key (but by a different task)
-        task2 = DummyOperator(task_id='test_xcom_2', dag=dag)
-        ti2 = TI(task=task2, execution_date=exec_date)
-        ti2.xcom_push(key='foo', value='baz')
+        XCom.set(
+            key='foo',
+            value='baz',
+            task_id='test_xcom_2',
+            dag_id=ti1.dag_id,
+            execution_date=ti1.execution_date,
+        )
 
         # Pull with no arguments
         result = ti1.xcom_pull()
@@ -1007,21 +976,19 @@ class TestTaskInstance:
         result = ti1.xcom_pull(task_ids=['test_xcom_1', 'test_xcom_2'], key='foo')
         assert result == ['bar', 'baz']
 
-    def test_xcom_pull_after_success(self, create_dummy_dag):
+    def test_xcom_pull_after_success(self, create_task_instance):
         """
         tests xcom set/clear relative to a task in a 'success' rerun scenario
         """
         key = 'xcom_key'
         value = 'xcom_value'
 
-        _, task = create_dummy_dag(
+        ti = create_task_instance(
             dag_id='test_xcom',
             schedule_interval='@monthly',
             task_id='test_xcom',
             pool='test_xcom',
         )
-        exec_date = DEFAULT_DATE
-        ti = TI(task=task, execution_date=exec_date)
 
         ti.run(mark_success=True)
         ti.xcom_push(key=key, value=value)
@@ -1039,7 +1006,7 @@ class TestTaskInstance:
         ti.run(ignore_all_deps=True)
         assert ti.xcom_pull(task_ids='test_xcom', key=key) is None
 
-    def test_xcom_pull_different_execution_date(self, create_dummy_dag):
+    def test_xcom_pull_different_execution_date(self, create_task_instance):
         """
         tests xcom fetch behavior with different execution dates, using
         both xcom_pull with "include_prior_dates" and without
@@ -1047,21 +1014,21 @@ class TestTaskInstance:
         key = 'xcom_key'
         value = 'xcom_value'
 
-        dag, task = create_dummy_dag(
+        ti = create_task_instance(
             dag_id='test_xcom',
             schedule_interval='@monthly',
             task_id='test_xcom',
             pool='test_xcom',
         )
-        exec_date = DEFAULT_DATE
-        ti = TI(task=task, execution_date=exec_date)
+        exec_date = ti.dag_run.execution_date
 
         ti.run(mark_success=True)
         ti.xcom_push(key=key, value=value)
         assert ti.xcom_pull(task_ids='test_xcom', key=key) == value
         ti.run()
         exec_date += datetime.timedelta(days=1)
-        ti = TI(task=task, execution_date=exec_date)
+        dr = ti.task.dag.create_dagrun(run_id="test2", execution_date=exec_date, state=None)
+        ti = TI(task=ti.task, run_id=dr.run_id)
         ti.run()
         # We have set a new execution date (and did not pass in
         # 'include_prior_dates'), which means this task should now have a cleared
@@ -1084,8 +1051,8 @@ class TestTaskInstance:
                 python_callable=lambda: value,
                 do_xcom_push=False,
             )
-        ti = TI(task=task, execution_date=DEFAULT_DATE)
-        dag_maker.create_dagrun()
+        ti = dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances[0]
+        ti.task = task
         ti.run()
         assert ti.xcom_pull(task_ids=task_id, key=models.XCOM_RETURN_KEY) is None
 
@@ -1108,33 +1075,31 @@ class TestTaskInstance:
                 task_id='test_operator',
                 python_callable=lambda: 'error',
             )
-        ti = TI(task=task, execution_date=DEFAULT_DATE)
-        dag_maker.create_dagrun()
+        ti = dag_maker.create_dagrun(execution_date=DEFAULT_DATE).task_instances[0]
+        ti.task = task
         with pytest.raises(TestError):
             ti.run()
 
-    def test_check_and_change_state_before_execution(self, create_dummy_dag):
-        _, task = create_dummy_dag(dag_id='test_check_and_change_state_before_execution')
-        ti = TI(task=task, execution_date=DEFAULT_DATE)
+    def test_check_and_change_state_before_execution(self, create_task_instance):
+        ti = create_task_instance(dag_id='test_check_and_change_state_before_execution')
         assert ti._try_number == 0
         assert ti.check_and_change_state_before_execution()
         # State should be running, and try_number column should be incremented
         assert ti.state == State.RUNNING
         assert ti._try_number == 1
 
-    def test_check_and_change_state_before_execution_dep_not_met(self, create_dummy_dag):
-        dag, task = create_dummy_dag(dag_id='test_check_and_change_state_before_execution')
-        task2 = DummyOperator(task_id='task2', dag=dag, start_date=DEFAULT_DATE)
-        task >> task2
-        ti = TI(task=task2, execution_date=timezone.utcnow())
+    def test_check_and_change_state_before_execution_dep_not_met(self, create_task_instance):
+        ti = create_task_instance(dag_id='test_check_and_change_state_before_execution')
+        task2 = DummyOperator(task_id='task2', dag=ti.task.dag, start_date=DEFAULT_DATE)
+        ti.task >> task2
+        ti = TI(task=task2, run_id=ti.run_id)
         assert not ti.check_and_change_state_before_execution()
 
-    def test_try_number(self, create_dummy_dag):
+    def test_try_number(self, create_task_instance):
         """
         Test the try_number accessor behaves in various running states
         """
-        _, task = create_dummy_dag(dag_id='test_check_and_change_state_before_execution')
-        ti = TI(task=task, execution_date=timezone.utcnow())
+        ti = create_task_instance(dag_id='test_check_and_change_state_before_execution')
         assert 1 == ti.try_number
         ti.try_number = 2
         ti.state = State.RUNNING
@@ -1142,20 +1107,33 @@ class TestTaskInstance:
         ti.state = State.SUCCESS
         assert 3 == ti.try_number
 
-    def test_get_num_running_task_instances(self, create_dummy_dag):
+    def test_get_num_running_task_instances(self, create_task_instance):
         session = settings.Session()
 
-        _, task = create_dummy_dag(dag_id='test_get_num_running_task_instances', task_id='task1')
-        _, task2 = create_dummy_dag(dag_id='test_get_num_running_task_instances_dummy', task_id='task2')
-        ti1 = TI(task=task, execution_date=DEFAULT_DATE)
-        ti2 = TI(task=task, execution_date=DEFAULT_DATE + datetime.timedelta(days=1))
-        ti3 = TI(task=task2, execution_date=DEFAULT_DATE)
+        ti1 = create_task_instance(
+            dag_id='test_get_num_running_task_instances', task_id='task1', session=session
+        )
+
+        dr = ti1.task.dag.create_dagrun(
+            execution_date=DEFAULT_DATE + datetime.timedelta(days=1),
+            state=None,
+            run_id='2',
+            session=session,
+        )
+        assert ti1 in session
+        ti2 = dr.task_instances[0]
+        ti2.task = ti1.task
+
+        ti3 = create_task_instance(
+            dag_id='test_get_num_running_task_instances_dummy', task_id='task2', session=session
+        )
+        assert ti3 in session
+        assert ti1 in session
+
         ti1.state = State.RUNNING
         ti2.state = State.QUEUED
         ti3.state = State.RUNNING
-        session.merge(ti1)
-        session.merge(ti2)
-        session.merge(ti3)
+        assert ti3 in session
         session.commit()
 
         assert 1 == ti1.get_num_running_task_instances(session=session)
@@ -1174,9 +1152,8 @@ class TestTaskInstance:
     #     self.assertEqual(d['task_id'][0], 'op')
     #     self.assertEqual(pendulum.parse(d['execution_date'][0]), now)
 
-    def test_log_url(self, create_dummy_dag):
-        _, task = create_dummy_dag('dag', task_id='op')
-        ti = TI(task=task, execution_date=datetime.datetime(2018, 1, 1))
+    def test_log_url(self, create_task_instance):
+        ti = create_task_instance(dag_id='dag', task_id='op', execution_date=timezone.datetime(2018, 1, 1))
 
         expected_url = (
             'http://localhost:8080/log?'
@@ -1186,10 +1163,9 @@ class TestTaskInstance:
         )
         assert ti.log_url == expected_url
 
-    def test_mark_success_url(self, create_dummy_dag):
+    def test_mark_success_url(self, create_task_instance):
         now = pendulum.now('Europe/Brussels')
-        _, task = create_dummy_dag('dag', task_id='op')
-        ti = TI(task=task, execution_date=now)
+        ti = create_task_instance(dag_id='dag', task_id='op', execution_date=now)
         query = urllib.parse.parse_qs(
             urllib.parse.urlparse(ti.mark_success_url).query, keep_blank_values=True, strict_parsing=True
         )
@@ -1197,10 +1173,9 @@ class TestTaskInstance:
         assert query['task_id'][0] == 'op'
         assert pendulum.parse(query['execution_date'][0]) == now
 
-    def test_overwrite_params_with_dag_run_conf(self):
-        task = DummyOperator(task_id='op')
-        ti = TI(task=task, execution_date=datetime.datetime.now())
-        dag_run = DagRun()
+    def test_overwrite_params_with_dag_run_conf(self, create_task_instance):
+        ti = create_task_instance()
+        dag_run = ti.dag_run
         dag_run.conf = {"override": True}
         params = {"override": False}
 
@@ -1208,20 +1183,18 @@ class TestTaskInstance:
 
         assert params["override"] is True
 
-    def test_overwrite_params_with_dag_run_none(self):
-        task = DummyOperator(task_id='op')
-        ti = TI(task=task, execution_date=datetime.datetime.now())
+    def test_overwrite_params_with_dag_run_none(self, create_task_instance):
+        ti = create_task_instance()
         params = {"override": False}
 
         ti.overwrite_params_with_dag_run_conf(params, None)
 
         assert params["override"] is False
 
-    def test_overwrite_params_with_dag_run_conf_none(self):
-        task = DummyOperator(task_id='op')
-        ti = TI(task=task, execution_date=datetime.datetime.now())
+    def test_overwrite_params_with_dag_run_conf_none(self, create_task_instance):
+        ti = create_task_instance()
         params = {"override": False}
-        dag_run = DagRun()
+        dag_run = ti.dag_run
 
         ti.overwrite_params_with_dag_run_conf(params, dag_run)
 
@@ -1231,8 +1204,8 @@ class TestTaskInstance:
     def test_email_alert(self, mock_send_email, dag_maker):
         with dag_maker(dag_id='test_failure_email'):
             task = BashOperator(task_id='test_email_alert', bash_command='exit 1', email='to')
-        dag_maker.create_dagrun()
-        ti = TI(task=task, execution_date=timezone.utcnow())
+        ti = dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances[0]
+        ti.task = task
 
         try:
             ti.run()
@@ -1259,8 +1232,8 @@ class TestTaskInstance:
                 bash_command='exit 1',
                 email='to',
             )
-        dag_maker.create_dagrun()
-        ti = TI(task=task, execution_date=timezone.utcnow())
+        ti = dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances[0]
+        ti.task = task
 
         opener = mock_open(read_data='template: {{ti.task_id}}')
         with patch('airflow.models.taskinstance.open', opener, create=True):
@@ -1276,10 +1249,7 @@ class TestTaskInstance:
 
     def test_set_duration(self):
         task = DummyOperator(task_id='op', email='test@test.test')
-        ti = TI(
-            task=task,
-            execution_date=datetime.datetime.now(),
-        )
+        ti = TI(task=task)
         ti.start_date = datetime.datetime(2018, 10, 1, 1)
         ti.end_date = datetime.datetime(2018, 10, 1, 2)
         ti.set_duration()
@@ -1287,20 +1257,19 @@ class TestTaskInstance:
 
     def test_set_duration_empty_dates(self):
         task = DummyOperator(task_id='op', email='test@test.test')
-        ti = TI(task=task, execution_date=datetime.datetime.now())
+        ti = TI(task=task)
         ti.set_duration()
         assert ti.duration is None
 
-    def test_success_callback_no_race_condition(self, create_dummy_dag):
+    def test_success_callback_no_race_condition(self, create_task_instance):
         callback_wrapper = CallbackWrapper()
-        _, task = create_dummy_dag(
-            'test_success_callback_no_race_condition',
+        ti = create_task_instance(
             on_success_callback=callback_wrapper.success_handler,
             end_date=DEFAULT_DATE + datetime.timedelta(days=10),
+            execution_date=timezone.utcnow(),
+            state=State.RUNNING,
         )
 
-        ti = TI(task=task, execution_date=datetime.datetime.now())
-        ti.state = State.RUNNING
         session = settings.Session()
         session.merge(ti)
         session.commit()
@@ -1321,30 +1290,29 @@ class TestTaskInstance:
         with dag_maker(dag_id=dag_id, schedule_interval=schedule_interval, catchup=catchup):
             task = DummyOperator(task_id='task')
 
-        def get_test_ti(session, execution_date: pendulum.DateTime, state: str) -> TI:
-            dag_maker.create_dagrun(
+        def get_test_ti(execution_date: pendulum.DateTime, state: str) -> TI:
+            dr = dag_maker.create_dagrun(
+                run_id=f'test__{execution_date.isoformat()}',
                 run_type=DagRunType.SCHEDULED,
                 state=state,
                 execution_date=execution_date,
                 start_date=pendulum.now('UTC'),
-                session=session,
             )
-            ti = TI(task=task, execution_date=execution_date)
-            ti.set_state(state=State.SUCCESS, session=session)
+            ti = dr.task_instances[0]
+            ti.task = task
+            ti.set_state(state=State.SUCCESS, session=dag_maker.session)
             return ti
 
-        with create_session() as session:  # type: Session
+        date = cast(pendulum.DateTime, pendulum.parse('2019-01-01T00:00:00+00:00'))
 
-            date = cast(pendulum.DateTime, pendulum.parse('2019-01-01T00:00:00+00:00'))
+        ret = []
 
-            ret = []
+        for idx, state in enumerate(scenario):
+            new_date = date.add(days=idx)
+            ti = get_test_ti(new_date, state)
+            ret.append(ti)
 
-            for idx, state in enumerate(scenario):
-                new_date = date.add(days=idx)
-                ti = get_test_ti(session, new_date, state)
-                ret.append(ti)
-
-            return ret
+        return ret
 
     _prev_dates_param_list = [
         pytest.param('0 0 * * * ', True, id='cron/catchup'),
@@ -1364,9 +1332,9 @@ class TestTaskInstance:
 
         assert ti_list[0].get_previous_ti() is None
 
-        assert ti_list[2].get_previous_ti().execution_date == ti_list[1].execution_date
+        assert ti_list[2].get_previous_ti().run_id == ti_list[1].run_id
 
-        assert ti_list[2].get_previous_ti().execution_date != ti_list[0].execution_date
+        assert ti_list[2].get_previous_ti().run_id != ti_list[0].run_id
 
     @pytest.mark.parametrize("schedule_interval, catchup", _prev_dates_param_list)
     def test_previous_ti_success(self, schedule_interval, catchup, dag_maker) -> None:
@@ -1378,9 +1346,9 @@ class TestTaskInstance:
         assert ti_list[0].get_previous_ti(state=State.SUCCESS) is None
         assert ti_list[1].get_previous_ti(state=State.SUCCESS) is None
 
-        assert ti_list[3].get_previous_ti(state=State.SUCCESS).execution_date == ti_list[1].execution_date
+        assert ti_list[3].get_previous_ti(state=State.SUCCESS).run_id == ti_list[1].run_id
 
-        assert ti_list[3].get_previous_ti(state=State.SUCCESS).execution_date != ti_list[2].execution_date
+        assert ti_list[3].get_previous_ti(state=State.SUCCESS).run_id != ti_list[2].run_id
 
     @pytest.mark.parametrize("schedule_interval, catchup", _prev_dates_param_list)
     def test_previous_execution_date_success(self, schedule_interval, catchup, dag_maker) -> None:
@@ -1388,6 +1356,9 @@ class TestTaskInstance:
         scenario = [State.FAILED, State.SUCCESS, State.FAILED, State.SUCCESS]
 
         ti_list = self._test_previous_dates_setup(schedule_interval, catchup, scenario, dag_maker)
+        # "vivify": touch execution_date so it is loaded from the related DagRun before the assertions below
+        for ti in ti_list:
+            ti.execution_date
 
         assert ti_list[0].get_previous_execution_date(state=State.SUCCESS) is None
         assert ti_list[1].get_previous_execution_date(state=State.SUCCESS) is None
@@ -1439,23 +1410,13 @@ class TestTaskInstance:
         assert ti_2.get_previous_start_date() == ti_1.start_date
         assert ti_1.start_date is None
 
-    def test_pendulum_template_dates(self, create_dummy_dag):
-        dag, task = create_dummy_dag(
+    def test_pendulum_template_dates(self, create_task_instance):
+        ti = create_task_instance(
             dag_id='test_pendulum_template_dates',
             task_id='test_pendulum_template_dates_task',
             schedule_interval='0 12 * * *',
         )
 
-        execution_date = timezone.utcnow()
-
-        dag.create_dagrun(
-            execution_date=execution_date,
-            state=State.RUNNING,
-            run_type=DagRunType.MANUAL,
-        )
-
-        ti = TI(task=task, execution_date=execution_date)
-
         template_context = ti.get_template_context()
 
         assert isinstance(template_context["data_interval_start"], pendulum.DateTime)
@@ -1474,7 +1435,7 @@ class TestTaskInstance:
             ('{{ conn.a_connection.extra_dejson.extra__asana__workspace }}', 'extra1'),
         ],
     )
-    def test_template_with_connection(self, content, expected_output, create_dummy_dag):
+    def test_template_with_connection(self, content, expected_output, create_task_instance):
         """
         Test the availability of variables in templates
         """
@@ -1496,11 +1457,10 @@ class TestTaskInstance:
                 session,
             )
 
-        _, task = create_dummy_dag()
+        ti = create_task_instance()
 
-        ti = TI(task=task, execution_date=DEFAULT_DATE)
         context = ti.get_template_context()
-        result = task.render_template(content, context)
+        result = ti.task.render_template(content, context)
         assert result == expected_output
 
     @pytest.mark.parametrize(
@@ -1512,29 +1472,25 @@ class TestTaskInstance:
             ('{{ var.value.get("missing_variable", "fallback") }}', 'fallback'),
         ],
     )
-    def test_template_with_variable(self, content, expected_output, create_dummy_dag):
+    def test_template_with_variable(self, content, expected_output, create_task_instance):
         """
         Test the availability of variables in templates
         """
         Variable.set('a_variable', 'a test value')
 
-        _, task = create_dummy_dag()
-
-        ti = TI(task=task, execution_date=DEFAULT_DATE)
+        ti = create_task_instance()
         context = ti.get_template_context()
-        result = task.render_template(content, context)
+        result = ti.task.render_template(content, context)
         assert result == expected_output
 
-    def test_template_with_variable_missing(self, create_dummy_dag):
+    def test_template_with_variable_missing(self, create_task_instance):
         """
         Test the availability of variables in templates
         """
-        _, task = create_dummy_dag()
-
-        ti = TI(task=task, execution_date=DEFAULT_DATE)
+        ti = create_task_instance()
         context = ti.get_template_context()
         with pytest.raises(KeyError):
-            task.render_template('{{ var.value.get("missing_variable") }}', context)
+            ti.task.render_template('{{ var.value.get("missing_variable") }}', context)
 
     @pytest.mark.parametrize(
         "content, expected_output",
@@ -1546,28 +1502,24 @@ class TestTaskInstance:
             ('{{ var.json.get("missing_variable", {"a": {"test": "fallback"}})["a"]["test"] }}', 'fallback'),
         ],
     )
-    def test_template_with_json_variable(self, content, expected_output, create_dummy_dag):
+    def test_template_with_json_variable(self, content, expected_output, create_task_instance):
         """
         Test the availability of variables in templates
         """
         Variable.set('a_variable', {'a': {'test': 'value'}}, serialize_json=True)
 
-        _, task = create_dummy_dag()
-
-        ti = TI(task=task, execution_date=DEFAULT_DATE)
+        ti = create_task_instance()
         context = ti.get_template_context()
-        result = task.render_template(content, context)
+        result = ti.task.render_template(content, context)
         assert result == expected_output
 
-    def test_template_with_json_variable_missing(self, create_dummy_dag):
-        _, task = create_dummy_dag()
-
-        ti = TI(task=task, execution_date=DEFAULT_DATE)
+    def test_template_with_json_variable_missing(self, create_task_instance):
+        ti = create_task_instance()
         context = ti.get_template_context()
         with pytest.raises(KeyError):
-            task.render_template('{{ var.json.get("missing_variable") }}', context)
+            ti.task.render_template('{{ var.json.get("missing_variable") }}', context)
 
-    def test_execute_callback(self, create_dummy_dag):
+    def test_execute_callback(self, create_task_instance):
         called = False
 
         def on_execute_callable(context):
@@ -1575,14 +1527,12 @@ class TestTaskInstance:
             called = True
             assert context['dag_run'].dag_id == 'test_dagrun_execute_callback'
 
-        _, task = create_dummy_dag(
-            'test_execute_callback',
+        ti = create_task_instance(
+            dag_id='test_dagrun_execute_callback',
             on_execute_callback=on_execute_callable,
-            end_date=DEFAULT_DATE + datetime.timedelta(days=10),
+            state=State.RUNNING,
         )
 
-        ti = TI(task=task, execution_date=datetime.datetime.now())
-        ti.state = State.RUNNING
         session = settings.Session()
 
         session.merge(ti)
@@ -1601,7 +1551,9 @@ class TestTaskInstance:
             (State.FAILED, "Error when executing on_failure_callback"),
         ],
     )
-    def test_finished_callbacks_handle_and_log_exception(self, finished_state, expected_message, dag_maker):
+    def test_finished_callbacks_handle_and_log_exception(
+        self, finished_state, expected_message, create_task_instance
+    ):
         called = completed = False
 
         def on_finish_callable(context):
@@ -1610,28 +1562,24 @@ class TestTaskInstance:
             raise KeyError
             completed = True
 
-        with dag_maker(
-            'test_success_callback_handles_exception',
+        ti = create_task_instance(
             end_date=DEFAULT_DATE + datetime.timedelta(days=10),
-        ):
-            task = DummyOperator(
-                task_id='op',
-                on_success_callback=on_finish_callable,
-                on_retry_callback=on_finish_callable,
-                on_failure_callback=on_finish_callable,
-            )
-        dag_maker.create_dagrun()
-        ti = TI(task=task, execution_date=datetime.datetime.now())
+            on_success_callback=on_finish_callable,
+            on_retry_callback=on_finish_callable,
+            on_failure_callback=on_finish_callable,
+            state=finished_state,
+        )
         ti._log = mock.Mock()
-        ti.state = finished_state
         ti._run_finished_callback()
 
         assert called
         assert not completed
         ti.log.exception.assert_called_once_with(expected_message)
 
-    def test_handle_failure(self, create_dummy_dag):
+    @provide_session
+    def test_handle_failure(self, create_dummy_dag, session=None):
         start_date = timezone.datetime(2016, 6, 1)
+        clear_db_runs()
 
         mock_on_failure_1 = mock.MagicMock()
         mock_on_retry_1 = mock.MagicMock()
@@ -1642,8 +1590,12 @@ class TestTaskInstance:
             task_id="test_handle_failure_on_failure",
             on_failure_callback=mock_on_failure_1,
             on_retry_callback=mock_on_retry_1,
+            session=session,
         )
-        ti1 = TI(task=task1, execution_date=start_date)
+        dr = dag.create_dagrun(run_id="test2", execution_date=timezone.utcnow(), state=None, session=session)
+
+        ti1 = dr.get_task_instance(task1.task_id, session=session)
+        ti1.task = task1
         ti1.state = State.FAILED
         ti1.handle_failure("test failure handling")
         ti1._run_finished_callback()
@@ -1661,8 +1613,10 @@ class TestTaskInstance:
             retries=1,
             dag=dag,
         )
-        ti2 = TI(task=task2, execution_date=start_date)
+        ti2 = TI(task=task2, run_id=dr.run_id)
         ti2.state = State.FAILED
+        session.add(ti2)
+        session.flush()
         ti2.handle_failure("test retry handling")
         ti2._run_finished_callback()
 
@@ -1681,7 +1635,9 @@ class TestTaskInstance:
             retries=1,
             dag=dag,
         )
-        ti3 = TI(task=task3, execution_date=start_date)
+        ti3 = TI(task=task3, run_id=dr.run_id)
+        session.add(ti3)
+        session.flush()
         ti3.state = State.FAILED
         ti3.handle_failure("test force_fail handling", force_fail=True)
         ti3._run_finished_callback()
@@ -1700,7 +1656,8 @@ class TestTaskInstance:
                 python_callable=fail,
                 retries=1,
             )
-        ti = TI(task=task, execution_date=timezone.utcnow())
+        ti = dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances[0]
+        ti.task = task
         try:
             ti.run()
         except AirflowFailException:
@@ -1717,7 +1674,8 @@ class TestTaskInstance:
                 python_callable=fail,
                 retries=1,
             )
-        ti = TI(task=task, execution_date=timezone.utcnow())
+        ti = dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances[0]
+        ti.task = task
         try:
             ti.run()
         except AirflowException:
@@ -1737,11 +1695,11 @@ class TestTaskInstance:
             end_date=DEFAULT_DATE + datetime.timedelta(days=10),
         ):
             op = PythonOperator(task_id='hive_in_python_op', python_callable=self._env_var_check_callback)
-        dag_maker.create_dagrun(
+        dr = dag_maker.create_dagrun(
             run_type=DagRunType.MANUAL,
             external_trigger=False,
         )
-        ti = TI(task=op, execution_date=DEFAULT_DATE)
+        ti = TI(task=op, run_id=dr.run_id)
         ti.state = State.RUNNING
         session = settings.Session()
         session.merge(ti)
@@ -1751,38 +1709,33 @@ class TestTaskInstance:
         assert ti.state == State.SUCCESS
 
     @patch.object(Stats, 'incr')
-    def test_task_stats(self, stats_mock, create_dummy_dag):
-        dag, op = create_dummy_dag(
-            'test_task_start_end_stats',
+    def test_task_stats(self, stats_mock, create_task_instance):
+        ti = create_task_instance(
+            dag_id='test_task_start_end_stats',
             end_date=DEFAULT_DATE + datetime.timedelta(days=10),
+            state=State.RUNNING,
         )
+        stats_mock.reset_mock()
 
-        ti = TI(task=op, execution_date=DEFAULT_DATE)
-        ti.state = State.RUNNING
         session = settings.Session()
         session.merge(ti)
         session.commit()
         ti._run_raw_task()
         ti.refresh_from_db()
-        stats_mock.assert_called_with(f'ti.finish.{dag.dag_id}.{op.task_id}.{ti.state}')
-        assert call(f'ti.start.{dag.dag_id}.{op.task_id}') in stats_mock.mock_calls
-        assert stats_mock.call_count == 5
+        stats_mock.assert_called_with(f'ti.finish.{ti.dag_id}.{ti.task_id}.{ti.state}')
+        assert call(f'ti.start.{ti.dag_id}.{ti.task_id}') in stats_mock.mock_calls
+        assert stats_mock.call_count == 4
 
-    def test_command_as_list(self, dag_maker):
-        with dag_maker(
-            'test_dag',
-            end_date=DEFAULT_DATE + datetime.timedelta(days=10),
-        ) as dag:
-            op = DummyOperator(task_id='dummy_op', dag=dag)
-        dag.fileloc = os.path.join(TEST_DAGS_FOLDER, 'x.py')
-        ti = TI(task=op, execution_date=DEFAULT_DATE)
+    def test_command_as_list(self, create_task_instance):
+        ti = create_task_instance()
+        ti.task.dag.fileloc = os.path.join(TEST_DAGS_FOLDER, 'x.py')
         assert ti.command_as_list() == [
             'airflow',
             'tasks',
             'run',
-            dag.dag_id,
-            op.task_id,
-            DEFAULT_DATE.isoformat(),
+            ti.dag_id,
+            ti.task_id,
+            ti.run_id,
             '--subdir',
             'DAGS_FOLDER/x.py',
         ]
@@ -1811,22 +1764,23 @@ class TestTaskInstance:
         )
         assert assert_command == generate_command
 
-    def test_get_rendered_template_fields(self, dag_maker):
+    @provide_session
+    def test_get_rendered_template_fields(self, dag_maker, session=None):
 
-        with dag_maker('test-dag') as dag:
+        with dag_maker('test-dag', session=session) as dag:
             task = BashOperator(task_id='op1', bash_command="{{ task.task_id }}")
         dag.fileloc = TEST_DAGS_FOLDER + '/test_get_k8s_pod_yaml.py'
+        ti = dag_maker.create_dagrun().task_instances[0]
+        ti.task = task
 
-        ti = TI(task=task, execution_date=DEFAULT_DATE)
-
-        with create_session() as session:
-            session.add(RenderedTaskInstanceFields(ti))
+        session.add(RenderedTaskInstanceFields(ti))
+        session.flush()
 
         # Create new TI for the same Task
         new_task = BashOperator(task_id='op12', bash_command="{{ task.task_id }}", dag=dag)
 
-        new_ti = TI(task=new_task, execution_date=DEFAULT_DATE)
-        new_ti.get_rendered_template_fields()
+        new_ti = TI(task=new_task, run_id=ti.run_id)
+        new_ti.get_rendered_template_fields(session=session)
 
         assert "op1" == ti.task.bash_command
 
@@ -1836,17 +1790,18 @@ class TestTaskInstance:
 
     @mock.patch.dict(os.environ, {"AIRFLOW_IS_K8S_EXECUTOR_POD": "True"})
     @mock.patch("airflow.settings.pod_mutation_hook")
-    def test_render_k8s_pod_yaml(self, pod_mutation_hook, dag_maker):
-        with dag_maker('test_get_rendered_k8s_spec'):
-            task = BashOperator(task_id='op1', bash_command="{{ task.task_id }}")
-        dr = dag_maker.create_dagrun(run_id='test_run_id')
-        ti = dr.get_task_instance(task.task_id)
-        ti.task = task
+    def test_render_k8s_pod_yaml(self, pod_mutation_hook, create_task_instance):
+        ti = create_task_instance(
+            dag_id='test_render_k8s_pod_yaml',
+            run_id='test_run_id',
+            task_id='op1',
+            execution_date=DEFAULT_DATE,
+        )
 
         expected_pod_spec = {
             'metadata': {
                 'annotations': {
-                    'dag_id': 'test_get_rendered_k8s_spec',
+                    'dag_id': 'test_render_k8s_pod_yaml',
                     'execution_date': '2016-01-01T00:00:00+00:00',
                     'task_id': 'op1',
                     'try_number': '1',
@@ -1854,7 +1809,7 @@ class TestTaskInstance:
                 'labels': {
                     'airflow-worker': 'worker-config',
                     'airflow_version': version,
-                    'dag_id': 'test_get_rendered_k8s_spec',
+                    'dag_id': 'test_render_k8s_pod_yaml',
                     'execution_date': '2016-01-01T00_00_00_plus_00_00',
                     'kubernetes_executor': 'True',
                     'task_id': 'op1',
@@ -1870,7 +1825,7 @@ class TestTaskInstance:
                             'airflow',
                             'tasks',
                             'run',
-                            'test_get_rendered_k8s_spec',
+                            'test_render_k8s_pod_yaml',
                             'op1',
                             'test_run_id',
                             '--subdir',
@@ -1889,12 +1844,9 @@ class TestTaskInstance:
 
     @mock.patch.dict(os.environ, {"AIRFLOW_IS_K8S_EXECUTOR_POD": "True"})
     @mock.patch.object(RenderedTaskInstanceFields, 'get_k8s_pod_yaml')
-    def test_get_rendered_k8s_spec(self, rtif_get_k8s_pod_yaml, dag_maker):
+    def test_get_rendered_k8s_spec(self, rtif_get_k8s_pod_yaml, create_task_instance):
         # Create new TI for the same Task
-        with dag_maker('test_get_rendered_k8s_spec'):
-            task = BashOperator(task_id='op1', bash_command="{{ task.task_id }}")
-
-        ti = TI(task=task, execution_date=DEFAULT_DATE)
+        ti = create_task_instance()
 
         patcher = mock.patch.object(ti, 'render_k8s_pod_yaml', autospec=True)
 
@@ -1917,10 +1869,9 @@ class TestTaskInstance:
 
             render_k8s_pod_yaml.assert_called_once()
 
-    def test_set_state_up_for_retry(self, create_dummy_dag):
-        dag, op1 = create_dummy_dag('dag')
+    def test_set_state_up_for_retry(self, create_task_instance):
+        ti = create_task_instance(state=State.RUNNING)
 
-        ti = TI(task=op1, execution_date=timezone.utcnow(), state=State.RUNNING)
         start_date = timezone.utcnow()
         ti.start_date = start_date
 
@@ -1930,13 +1881,13 @@ class TestTaskInstance:
         assert ti.start_date < ti.end_date
         assert ti.duration > 0
 
-    def test_refresh_from_db(self):
+    def test_refresh_from_db(self, create_task_instance):
         run_date = timezone.utcnow()
 
         expected_values = {
             "task_id": "test_refresh_from_db_task",
             "dag_id": "test_refresh_from_db_dag",
-            "execution_date": run_date,
+            "run_id": "test",
             "start_date": run_date + datetime.timedelta(days=1),
             "end_date": run_date + datetime.timedelta(days=1, seconds=1, milliseconds=234),
             "duration": 1.234,
@@ -1968,8 +1919,7 @@ class TestTaskInstance:
             "This prevents refresh_from_db() from missing a field."
         )
 
-        operator = DummyOperator(task_id=expected_values['task_id'])
-        ti = TI(task=operator, execution_date=expected_values['execution_date'])
+        ti = create_task_instance(task_id=expected_values['task_id'], dag_id=expected_values['dag_id'])
         for key, expected_value in expected_values.items():
             setattr(ti, key, expected_value)
         with create_session() as session:
@@ -1980,7 +1930,7 @@ class TestTaskInstance:
         mock_task.task_id = expected_values["task_id"]
         mock_task.dag_id = expected_values["dag_id"]
 
-        ti = TI(task=mock_task, execution_date=run_date)
... 7687 lines suppressed ...