Posted to commits@dolphinscheduler.apache.org by zh...@apache.org on 2022/09/18 08:50:56 UTC

[dolphinscheduler] 04/05: [feat][python] Support MLflow task in python api (#11962)

This is an automated email from the ASF dual-hosted git repository.

zhongjiajie pushed a commit to branch 3.1.0-prepare
in repository https://gitbox.apache.org/repos/asf/dolphinscheduler.git

commit cb063732d7b89a9029c74019de06e4aa1e62adb9
Author: JieguangZhou <ji...@163.com>
AuthorDate: Sun Sep 18 16:28:18 2022 +0800

    [feat][python] Support MLflow task in python api (#11962)
    
    (cherry picked from commit ad683c3c428876916d9550111fc9bbb77afd9e9f)
---
 .../pydolphinscheduler/docs/source/tasks/index.rst |   1 +
 .../docs/source/tasks/{index.rst => mlflow.rst}    |  54 ++---
 .../examples/yaml_define/mlflow.yaml               |  78 ++++++
 .../src/pydolphinscheduler/constants.py            |   1 +
 .../examples/task_mlflow_example.py                | 104 ++++++++
 .../src/pydolphinscheduler/tasks/__init__.py       |  10 +
 .../src/pydolphinscheduler/tasks/mlflow.py         | 265 +++++++++++++++++++++
 .../pydolphinscheduler/tests/tasks/test_mlflow.py  | 211 ++++++++++++++++
 8 files changed, 695 insertions(+), 29 deletions(-)

diff --git a/dolphinscheduler-python/pydolphinscheduler/docs/source/tasks/index.rst b/dolphinscheduler-python/pydolphinscheduler/docs/source/tasks/index.rst
index 5b9c165700..0cee3f1d6b 100644
--- a/dolphinscheduler-python/pydolphinscheduler/docs/source/tasks/index.rst
+++ b/dolphinscheduler-python/pydolphinscheduler/docs/source/tasks/index.rst
@@ -42,5 +42,6 @@ In this section
    sub_process
 
    sagemaker
+   mlflow
    openmldb
    pytorch
diff --git a/dolphinscheduler-python/pydolphinscheduler/docs/source/tasks/index.rst b/dolphinscheduler-python/pydolphinscheduler/docs/source/tasks/mlflow.rst
similarity index 61%
copy from dolphinscheduler-python/pydolphinscheduler/docs/source/tasks/index.rst
copy to dolphinscheduler-python/pydolphinscheduler/docs/source/tasks/mlflow.rst
index 5b9c165700..b83903c26f 100644
--- a/dolphinscheduler-python/pydolphinscheduler/docs/source/tasks/index.rst
+++ b/dolphinscheduler-python/pydolphinscheduler/docs/source/tasks/mlflow.rst
@@ -15,32 +15,28 @@
    specific language governing permissions and limitations
    under the License.
 
-Tasks
-=====
-
-In this section 
-
-.. toctree::
-   :maxdepth: 1
-   
-   func_wrap
-   shell
-   sql
-   python
-   http
-
-   switch
-   condition
-   dependent
-
-   spark
-   flink
-   map_reduce
-   procedure
-
-   datax
-   sub_process
-
-   sagemaker
-   openmldb
-   pytorch
+MLflow
+=========
+
+
+An MLflow task type's example and dive into information of **PyDolphinScheduler**.
+
+Example
+-------
+
+.. literalinclude:: ../../../src/pydolphinscheduler/examples/task_mlflow_example.py
+   :start-after: [start workflow_declare]
+   :end-before: [end workflow_declare]
+
+Dive Into
+---------
+
+.. automodule:: pydolphinscheduler.tasks.mlflow
+
+
+YAML file example
+-----------------
+
+.. literalinclude:: ../../../examples/yaml_define/mlflow.yaml
+   :start-after: # under the License.
+   :language: yaml
diff --git a/dolphinscheduler-python/pydolphinscheduler/examples/yaml_define/mlflow.yaml b/dolphinscheduler-python/pydolphinscheduler/examples/yaml_define/mlflow.yaml
new file mode 100644
index 0000000000..232442a186
--- /dev/null
+++ b/dolphinscheduler-python/pydolphinscheduler/examples/yaml_define/mlflow.yaml
@@ -0,0 +1,78 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+# Define variable `mlflow_tracking_uri`
+mlflow_tracking_uri: &mlflow_tracking_uri "http://127.0.0.1:5000" 
+
+# Define the workflow
+workflow:
+  name: "MLflow"
+
+# Define the tasks under the workflow
+tasks:
+  - name: train_xgboost_native
+    task_type: MLFlowProjectsCustom 
+    repository: https://github.com/mlflow/mlflow#examples/xgboost/xgboost_native
+    mlflow_tracking_uri: *mlflow_tracking_uri
+    parameters: -P learning_rate=0.2 -P colsample_bytree=0.8 -P subsample=0.9
+    experiment_name: xgboost
+
+
+  - name: deploy_mlflow
+    deps: [train_xgboost_native]
+    task_type: MLflowModels 
+    model_uri: models:/xgboost_native/Production
+    mlflow_tracking_uri: *mlflow_tracking_uri
+    deploy_mode: MLFLOW
+    port: 7001
+
+  - name: train_automl
+    task_type: MLFlowProjectsAutoML 
+    mlflow_tracking_uri: *mlflow_tracking_uri
+    parameters: time_budget=30;estimator_list=['lgbm']
+    experiment_name: automl_iris
+    model_name: iris_A
+    automl_tool: flaml
+    data_path: /data/examples/iris
+
+  - name: deploy_docker
+    task_type: MLflowModels 
+    deps: [train_automl]
+    model_uri: models:/iris_A/Production
+    mlflow_tracking_uri: *mlflow_tracking_uri
+    deploy_mode: DOCKER
+    port: 7002
+
+  - name: train_basic_algorithm
+    task_type: MLFlowProjectsBasicAlgorithm 
+    mlflow_tracking_uri: *mlflow_tracking_uri
+    parameters: n_estimators=200;learning_rate=0.2
+    experiment_name: basic_algorithm_iris
+    model_name: iris_B
+    algorithm: lightgbm
+    data_path: /data/examples/iris
+    search_params: max_depth=[5, 10];n_estimators=[100, 200]
+
+
+  - name: deploy_docker_compose
+    task_type: MLflowModels 
+    deps: [train_basic_algorithm]
+    model_uri: models:/iris_B/Production
+    mlflow_tracking_uri: *mlflow_tracking_uri
+    deploy_mode: DOCKER COMPOSE
+    port: 7003
diff --git a/dolphinscheduler-python/pydolphinscheduler/src/pydolphinscheduler/constants.py b/dolphinscheduler-python/pydolphinscheduler/src/pydolphinscheduler/constants.py
index d8d2febfeb..729d48a9ad 100644
--- a/dolphinscheduler-python/pydolphinscheduler/src/pydolphinscheduler/constants.py
+++ b/dolphinscheduler-python/pydolphinscheduler/src/pydolphinscheduler/constants.py
@@ -58,6 +58,7 @@ class TaskType(str):
     SPARK = "SPARK"
     MR = "MR"
     SAGEMAKER = "SAGEMAKER"
+    MLFLOW = "MLFLOW"
     OPENMLDB = "OPENMLDB"
     PYTORCH = "PYTORCH"
 
diff --git a/dolphinscheduler-python/pydolphinscheduler/src/pydolphinscheduler/examples/task_mlflow_example.py b/dolphinscheduler-python/pydolphinscheduler/src/pydolphinscheduler/examples/task_mlflow_example.py
new file mode 100644
index 0000000000..328688e646
--- /dev/null
+++ b/dolphinscheduler-python/pydolphinscheduler/src/pydolphinscheduler/examples/task_mlflow_example.py
@@ -0,0 +1,104 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# [start workflow_declare]
+"""A example workflow for task mlflow."""
+
+from pydolphinscheduler.core.process_definition import ProcessDefinition
+from pydolphinscheduler.tasks.mlflow import (
+    MLflowDeployType,
+    MLflowModels,
+    MLFlowProjectsAutoML,
+    MLFlowProjectsBasicAlgorithm,
+    MLFlowProjectsCustom,
+)
+
+mlflow_tracking_uri = "http://127.0.0.1:5000"
+
+with ProcessDefinition(
+    name="task_mlflow_example",
+    tenant="tenant_exists",
+) as pd:
+
+    # run custom mlflow project to train model
+    train_custom = MLFlowProjectsCustom(
+        name="train_xgboost_native",
+        repository="https://github.com/mlflow/mlflow#examples/xgboost/xgboost_native",
+        mlflow_tracking_uri=mlflow_tracking_uri,
+        parameters="-P learning_rate=0.2 -P colsample_bytree=0.8 -P subsample=0.9",
+        experiment_name="xgboost",
+    )
+
+    # Using MLFLOW to deploy model from custom mlflow project
+    deploy_mlflow = MLflowModels(
+        name="deploy_mlflow",
+        model_uri="models:/xgboost_native/Production",
+        mlflow_tracking_uri=mlflow_tracking_uri,
+        deploy_mode=MLflowDeployType.MLFLOW,
+        port=7001,
+    )
+
+    train_custom >> deploy_mlflow
+
+    # run automl to train model
+    train_automl = MLFlowProjectsAutoML(
+        name="train_automl",
+        mlflow_tracking_uri=mlflow_tracking_uri,
+        parameters="time_budget=30;estimator_list=['lgbm']",
+        experiment_name="automl_iris",
+        model_name="iris_A",
+        automl_tool="flaml",
+        data_path="/data/examples/iris",
+    )
+
+    # Using DOCKER to deploy model from train_automl
+    deploy_docker = MLflowModels(
+        name="deploy_docker",
+        model_uri="models:/iris_A/Production",
+        mlflow_tracking_uri=mlflow_tracking_uri,
+        deploy_mode=MLflowDeployType.DOCKER,
+        port=7002,
+    )
+
+    train_automl >> deploy_docker
+
+    # run lightgbm to train model
+    train_basic_algorithm = MLFlowProjectsBasicAlgorithm(
+        name="train_basic_algorithm",
+        mlflow_tracking_uri=mlflow_tracking_uri,
+        parameters="n_estimators=200;learning_rate=0.2",
+        experiment_name="basic_algorithm_iris",
+        model_name="iris_B",
+        algorithm="lightgbm",
+        data_path="/data/examples/iris",
+        search_params="max_depth=[5, 10];n_estimators=[100, 200]",
+    )
+
+    # Using DOCKER COMPOSE to deploy model from train_basic_algorithm
+    deploy_docker_compose = MLflowModels(
+        name="deploy_docker_compose",
+        model_uri="models:/iris_B/Production",
+        mlflow_tracking_uri=mlflow_tracking_uri,
+        deploy_mode=MLflowDeployType.DOCKER_COMPOSE,
+        port=7003,
+    )
+
+    train_basic_algorithm >> deploy_docker_compose
+
+    pd.submit()
+
+# [end workflow_declare]
diff --git a/dolphinscheduler-python/pydolphinscheduler/src/pydolphinscheduler/tasks/__init__.py b/dolphinscheduler-python/pydolphinscheduler/src/pydolphinscheduler/tasks/__init__.py
index e5b263c7c2..f3fae203d3 100644
--- a/dolphinscheduler-python/pydolphinscheduler/src/pydolphinscheduler/tasks/__init__.py
+++ b/dolphinscheduler-python/pydolphinscheduler/src/pydolphinscheduler/tasks/__init__.py
@@ -23,6 +23,12 @@ from pydolphinscheduler.tasks.dependent import Dependent
 from pydolphinscheduler.tasks.flink import Flink
 from pydolphinscheduler.tasks.http import Http
 from pydolphinscheduler.tasks.map_reduce import MR
+from pydolphinscheduler.tasks.mlflow import (
+    MLflowModels,
+    MLFlowProjectsAutoML,
+    MLFlowProjectsBasicAlgorithm,
+    MLFlowProjectsCustom,
+)
 from pydolphinscheduler.tasks.openmldb import OpenMLDB
 from pydolphinscheduler.tasks.procedure import Procedure
 from pydolphinscheduler.tasks.python import Python
@@ -43,6 +49,10 @@ __all__ = [
     "Http",
     "MR",
     "OpenMLDB",
+    "MLFlowProjectsBasicAlgorithm",
+    "MLFlowProjectsCustom",
+    "MLFlowProjectsAutoML",
+    "MLflowModels",
     "Procedure",
     "Python",
     "Pytorch",
diff --git a/dolphinscheduler-python/pydolphinscheduler/src/pydolphinscheduler/tasks/mlflow.py b/dolphinscheduler-python/pydolphinscheduler/src/pydolphinscheduler/tasks/mlflow.py
new file mode 100644
index 0000000000..44e6634822
--- /dev/null
+++ b/dolphinscheduler-python/pydolphinscheduler/src/pydolphinscheduler/tasks/mlflow.py
@@ -0,0 +1,265 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Task mlflow."""
+from copy import deepcopy
+from typing import Dict, Optional
+
+from pydolphinscheduler.constants import TaskType
+from pydolphinscheduler.core.task import Task
+
+
+class MLflowTaskType(str):
+    """MLflow task type."""
+
+    MLFLOW_PROJECTS = "MLflow Projects"
+    MLFLOW_MODELS = "MLflow Models"
+
+
+class MLflowJobType(str):
+    """MLflow job type."""
+
+    AUTOML = "AutoML"
+    BASIC_ALGORITHM = "BasicAlgorithm"
+    CUSTOM_PROJECT = "CustomProject"
+
+
+class MLflowDeployType(str):
+    """MLflow deploy type."""
+
+    MLFLOW = "MLFLOW"
+    DOCKER = "DOCKER"
+    DOCKER_COMPOSE = "DOCKER COMPOSE"
+
+
+DEFAULT_MLFLOW_TRACKING_URI = "http://127.0.0.1:5000"
+DEFAULT_VERSION = "master"
+
+
+class BaseMLflow(Task):
+    """Base MLflow task."""
+
+    mlflow_task_type = None
+
+    _task_custom_attr = {
+        "mlflow_tracking_uri",
+        "mlflow_task_type",
+    }
+
+    _child_task_mlflow_attr = set()
+
+    def __init__(self, name: str, mlflow_tracking_uri: str, *args, **kwargs):
+        super().__init__(name, TaskType.MLFLOW, *args, **kwargs)
+        self.mlflow_tracking_uri = mlflow_tracking_uri
+
+    @property
+    def task_params(self) -> Dict:
+        """Return task params."""
+        self._task_custom_attr = deepcopy(self._task_custom_attr)
+        self._task_custom_attr.update(self._child_task_mlflow_attr)
+        return super().task_params
+
+
+class MLflowModels(BaseMLflow):
+    """Task MLflow models object, declare behavior for MLflow models task to dolphinscheduler.
+
+    Deploy machine learning models in diverse serving environments.
+
+    :param name: task name
+    :param model_uri: Model URI of MLflow, supports the models:/<model_name>/suffix format and the runs:/ format.
+        See https://mlflow.org/docs/latest/tracking.html#artifact-stores
+    :param mlflow_tracking_uri: MLflow tracking server uri, default is http://127.0.0.1:5000
+    :param deploy_mode: MLflow deploy mode, supports MLFLOW, DOCKER and DOCKER COMPOSE, default is DOCKER
+    :param port: deploy port, default is 7000
+    :param cpu_limit: cpu limit, default is 1.0
+    :param memory_limit: memory limit, default is 500M
+    """
+
+    mlflow_task_type = MLflowTaskType.MLFLOW_MODELS
+
+    _child_task_mlflow_attr = {
+        "deploy_type",
+        "deploy_model_key",
+        "deploy_port",
+        "cpu_limit",
+        "memory_limit",
+    }
+
+    def __init__(
+        self,
+        name: str,
+        model_uri: str,
+        mlflow_tracking_uri: Optional[str] = DEFAULT_MLFLOW_TRACKING_URI,
+        deploy_mode: Optional[str] = MLflowDeployType.DOCKER,
+        port: Optional[int] = 7000,
+        cpu_limit: Optional[float] = 1.0,
+        memory_limit: Optional[str] = "500M",
+        *args,
+        **kwargs
+    ):
+        """Init mlflow models task."""
+        super().__init__(name, mlflow_tracking_uri, *args, **kwargs)
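+        # Upper-case the deploy mode so lowercase values (e.g. "docker") still
+        # match the MLflowDeployType constants.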
+        self.deploy_type = deploy_mode.upper()
+        self.deploy_model_key = model_uri
+        self.deploy_port = port
+        self.cpu_limit = cpu_limit
+        self.memory_limit = memory_limit
+
+
+class MLFlowProjectsCustom(BaseMLflow):
+    """Task MLflow projects object, declare behavior for MLflow Custom projects task to dolphinscheduler.
+
+    :param name: task name
+    :param repository: Repository URL of the MLflow Project, supports a git address or a directory on the worker.
+        If the project is in a subdirectory, append # to the URL (same as mlflow run),
+        for example https://github.com/mlflow/mlflow#examples/xgboost/xgboost_native.
+    :param mlflow_tracking_uri: MLflow tracking server uri, default is http://127.0.0.1:5000
+    :param experiment_name: MLflow experiment name, default is empty
+    :param parameters: MLflow project parameters, default is empty
+    :param version: MLflow project version, default is master
+
+    """
+
+    mlflow_task_type = MLflowTaskType.MLFLOW_PROJECTS
+    mlflow_job_type = MLflowJobType.CUSTOM_PROJECT
+
+    _child_task_mlflow_attr = {
+        "mlflow_job_type",
+        "experiment_name",
+        "params",
+        "mlflow_project_repository",
+        "mlflow_project_version",
+    }
+
+    def __init__(
+        self,
+        name: str,
+        repository: str,
+        mlflow_tracking_uri: Optional[str] = DEFAULT_MLFLOW_TRACKING_URI,
+        experiment_name: Optional[str] = "",
+        parameters: Optional[str] = "",
+        version: Optional[str] = "master",
+        *args,
+        **kwargs
+    ):
+        """Init mlflow projects task."""
+        super().__init__(name, mlflow_tracking_uri, *args, **kwargs)
+        self.mlflow_project_repository = repository
+        self.experiment_name = experiment_name
+        self.params = parameters
+        self.mlflow_project_version = version
+
+
+class MLFlowProjectsAutoML(BaseMLflow):
+    """Task MLflow projects object, declare behavior for AutoML task to dolphinscheduler.
+
+    :param name: task name
+    :param data_path: Data path of the MLflow Project, supports a git address or a directory on the worker.
+    :param automl_tool: The AutoML tool used, currently supports autosklearn and flaml.
+    :param mlflow_tracking_uri: MLflow tracking server uri, default is http://127.0.0.1:5000
+    :param experiment_name: MLflow experiment name, default is empty
+    :param model_name: MLflow model name, default is empty
+    :param parameters: MLflow project parameters, default is empty
+
+    """
+
+    mlflow_task_type = MLflowTaskType.MLFLOW_PROJECTS
+    mlflow_job_type = MLflowJobType.AUTOML
+
+    _child_task_mlflow_attr = {
+        "mlflow_job_type",
+        "experiment_name",
+        "model_name",
+        "register_model",
+        "data_path",
+        "params",
+        "automl_tool",
+    }
+
+    def __init__(
+        self,
+        name: str,
+        data_path: str,
+        automl_tool: Optional[str] = "flaml",
+        mlflow_tracking_uri: Optional[str] = DEFAULT_MLFLOW_TRACKING_URI,
+        experiment_name: Optional[str] = "",
+        model_name: Optional[str] = "",
+        parameters: Optional[str] = "",
+        *args,
+        **kwargs
+    ):
+        """Init mlflow projects task."""
+        super().__init__(name, mlflow_tracking_uri, *args, **kwargs)
+        self.data_path = data_path
+        self.experiment_name = experiment_name
+        self.model_name = model_name
+        self.params = parameters
+        self.automl_tool = automl_tool.lower()
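+        # Register the trained model in MLflow only when a model name is given.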
+        self.register_model = bool(model_name)
+
+
+class MLFlowProjectsBasicAlgorithm(BaseMLflow):
+    """Task MLflow projects object, declare behavior for BasicAlgorithm task to dolphinscheduler.
+
+    :param name: task name
+    :param data_path: Data path of the MLflow Project, supports a git address or a directory on the worker.
+    :param algorithm: The selected algorithm, currently supports LR, SVM, LightGBM and XGBoost
+        in scikit-learn form.
+    :param mlflow_tracking_uri: MLflow tracking server uri, default is http://127.0.0.1:5000
+    :param experiment_name: MLflow experiment name, default is empty
+    :param model_name: MLflow model name, default is empty
+    :param parameters: MLflow project parameters, default is empty
+    :param search_params: The parameter search space used to tune the algorithm, default is empty
+
+    """
+
+    mlflow_job_type = MLflowJobType.BASIC_ALGORITHM
+    mlflow_task_type = MLflowTaskType.MLFLOW_PROJECTS
+
+    _child_task_mlflow_attr = {
+        "mlflow_job_type",
+        "experiment_name",
+        "model_name",
+        "register_model",
+        "data_path",
+        "params",
+        "algorithm",
+        "search_params",
+    }
+
+    def __init__(
+        self,
+        name: str,
+        data_path: str,
+        algorithm: Optional[str] = "lightgbm",
+        mlflow_tracking_uri: Optional[str] = DEFAULT_MLFLOW_TRACKING_URI,
+        experiment_name: Optional[str] = "",
+        model_name: Optional[str] = "",
+        parameters: Optional[str] = "",
+        search_params: Optional[str] = "",
+        *args,
+        **kwargs
+    ):
+        """Init mlflow projects task."""
+        super().__init__(name, mlflow_tracking_uri, *args, **kwargs)
+        self.data_path = data_path
+        self.experiment_name = experiment_name
+        self.model_name = model_name
+        self.params = parameters
+        self.algorithm = algorithm.lower()
+        self.search_params = search_params
+        self.register_model = bool(model_name)
diff --git a/dolphinscheduler-python/pydolphinscheduler/tests/tasks/test_mlflow.py b/dolphinscheduler-python/pydolphinscheduler/tests/tasks/test_mlflow.py
new file mode 100644
index 0000000000..2159b6c77e
--- /dev/null
+++ b/dolphinscheduler-python/pydolphinscheduler/tests/tasks/test_mlflow.py
@@ -0,0 +1,211 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Test Task MLflow."""
+from copy import deepcopy
+from unittest.mock import patch
+
+from pydolphinscheduler.tasks.mlflow import (
+    MLflowDeployType,
+    MLflowJobType,
+    MLflowModels,
+    MLFlowProjectsAutoML,
+    MLFlowProjectsBasicAlgorithm,
+    MLFlowProjectsCustom,
+    MLflowTaskType,
+)
+
+CODE = 123
+VERSION = 1
+MLFLOW_TRACKING_URI = "http://127.0.0.1:5000"
+
+EXPECT = {
+    "code": CODE,
+    "version": VERSION,
+    "description": None,
+    "delayTime": 0,
+    "taskType": "MLFLOW",
+    "taskParams": {
+        "resourceList": [],
+        "localParams": [],
+        "dependence": {},
+        "conditionResult": {"successNode": [""], "failedNode": [""]},
+        "waitStartTimeout": {},
+    },
+    "flag": "YES",
+    "taskPriority": "MEDIUM",
+    "workerGroup": "default",
+    "environmentCode": None,
+    "failRetryTimes": 0,
+    "failRetryInterval": 1,
+    "timeoutFlag": "CLOSE",
+    "timeoutNotifyStrategy": None,
+    "timeout": 0,
+}
+
+
+def test_mlflow_models_get_define():
+    """Test task mlflow models function get_define."""
+    name = "mlflow_models"
+    model_uri = "models:/xgboost_native/Production"
+    port = 7001
+    cpu_limit = 2.0
+    memory_limit = "600M"
+
+    expect = deepcopy(EXPECT)
+    expect["name"] = name
+    task_params = expect["taskParams"]
+    task_params["mlflowTrackingUri"] = MLFLOW_TRACKING_URI
+    task_params["mlflowTaskType"] = MLflowTaskType.MLFLOW_MODELS
+    task_params["deployType"] = MLflowDeployType.DOCKER_COMPOSE
+    task_params["deployModelKey"] = model_uri
+    task_params["deployPort"] = port
+    task_params["cpuLimit"] = cpu_limit
+    task_params["memoryLimit"] = memory_limit
+
+    with patch(
+        "pydolphinscheduler.core.task.Task.gen_code_and_version",
+        return_value=(CODE, VERSION),
+    ):
+        task = MLflowModels(
+            name=name,
+            model_uri=model_uri,
+            mlflow_tracking_uri=MLFLOW_TRACKING_URI,
+            deploy_mode=MLflowDeployType.DOCKER_COMPOSE,
+            port=port,
+            cpu_limit=cpu_limit,
+            memory_limit=memory_limit,
+        )
+        assert task.get_define() == expect
+
+
+def test_mlflow_project_custom_get_define():
+    """Test task mlflow project custom function get_define."""
+    name = ("train_xgboost_native",)
+    repository = "https://github.com/mlflow/mlflow#examples/xgboost/xgboost_native"
+    mlflow_tracking_uri = MLFLOW_TRACKING_URI
+    parameters = "-P learning_rate=0.2 -P colsample_bytree=0.8 -P subsample=0.9"
+    experiment_name = "xgboost"
+
+    expect = deepcopy(EXPECT)
+    expect["name"] = name
+    task_params = expect["taskParams"]
+
+    task_params["mlflowTrackingUri"] = MLFLOW_TRACKING_URI
+    task_params["mlflowTaskType"] = MLflowTaskType.MLFLOW_PROJECTS
+    task_params["mlflowJobType"] = MLflowJobType.CUSTOM_PROJECT
+    task_params["experimentName"] = experiment_name
+    task_params["params"] = parameters
+    task_params["mlflowProjectRepository"] = repository
+    task_params["mlflowProjectVersion"] = "dev"
+
+    with patch(
+        "pydolphinscheduler.core.task.Task.gen_code_and_version",
+        return_value=(CODE, VERSION),
+    ):
+        task = MLFlowProjectsCustom(
+            name=name,
+            repository=repository,
+            mlflow_tracking_uri=mlflow_tracking_uri,
+            parameters=parameters,
+            experiment_name=experiment_name,
+            version="dev",
+        )
+        assert task.get_define() == expect
+
+
+def test_mlflow_project_automl_get_define():
+    """Test task mlflow project automl function get_define."""
+    name = ("train_automl",)
+    mlflow_tracking_uri = MLFLOW_TRACKING_URI
+    parameters = "time_budget=30;estimator_list=['lgbm']"
+    experiment_name = "automl_iris"
+    model_name = "iris_A"
+    automl_tool = "flaml"
+    data_path = "/data/examples/iris"
+
+    expect = deepcopy(EXPECT)
+    expect["name"] = name
+    task_params = expect["taskParams"]
+
+    task_params["mlflowTrackingUri"] = MLFLOW_TRACKING_URI
+    task_params["mlflowTaskType"] = MLflowTaskType.MLFLOW_PROJECTS
+    task_params["mlflowJobType"] = MLflowJobType.AUTOML
+    task_params["experimentName"] = experiment_name
+    task_params["modelName"] = model_name
+    task_params["registerModel"] = bool(model_name)
+    task_params["dataPath"] = data_path
+    task_params["params"] = parameters
+    task_params["automlTool"] = automl_tool
+
+    with patch(
+        "pydolphinscheduler.core.task.Task.gen_code_and_version",
+        return_value=(CODE, VERSION),
+    ):
+        task = MLFlowProjectsAutoML(
+            name=name,
+            mlflow_tracking_uri=mlflow_tracking_uri,
+            parameters=parameters,
+            experiment_name=experiment_name,
+            model_name=model_name,
+            automl_tool=automl_tool,
+            data_path=data_path,
+        )
+    assert task.get_define() == expect
+
+
+def test_mlflow_project_basic_algorithm_get_define():
+    """Test task mlflow project BasicAlgorithm function get_define."""
+    name = "train_basic_algorithm"
+    mlflow_tracking_uri = MLFLOW_TRACKING_URI
+    parameters = "n_estimators=200;learning_rate=0.2"
+    experiment_name = "basic_algorithm_iris"
+    model_name = "iris_B"
+    algorithm = "lightgbm"
+    data_path = "/data/examples/iris"
+    search_params = "max_depth=[5, 10];n_estimators=[100, 200]"
+
+    expect = deepcopy(EXPECT)
+    expect["name"] = name
+    task_params = expect["taskParams"]
+
+    task_params["mlflowTrackingUri"] = MLFLOW_TRACKING_URI
+    task_params["mlflowTaskType"] = MLflowTaskType.MLFLOW_PROJECTS
+    task_params["mlflowJobType"] = MLflowJobType.BASIC_ALGORITHM
+    task_params["experimentName"] = experiment_name
+    task_params["modelName"] = model_name
+    task_params["registerModel"] = bool(model_name)
+    task_params["dataPath"] = data_path
+    task_params["params"] = parameters
+    task_params["algorithm"] = algorithm
+    task_params["searchParams"] = search_params
+
+    with patch(
+        "pydolphinscheduler.core.task.Task.gen_code_and_version",
+        return_value=(CODE, VERSION),
+    ):
+        task = MLFlowProjectsBasicAlgorithm(
+            name=name,
+            mlflow_tracking_uri=mlflow_tracking_uri,
+            parameters=parameters,
+            experiment_name=experiment_name,
+            model_name=model_name,
+            algorithm=algorithm,
+            data_path=data_path,
+            search_params=search_params,
+        )
+    assert task.get_define() == expect