You are viewing a plain-text version of this content; the link to its canonical version was not preserved in this rendering.
Posted to commits@dolphinscheduler.apache.org by zh...@apache.org on 2022/10/24 07:28:00 UTC

[dolphinscheduler-mlflow] branch main updated: Update the model registration policy (f1-score) (#5)

This is an automated email from the ASF dual-hosted git repository.

zhoujieguang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/dolphinscheduler-mlflow.git


The following commit(s) were added to refs/heads/main by this push:
     new c49ec23  Update the model registration policy (f1-score) (#5)
c49ec23 is described below

commit c49ec23769353f3fb986c475dfaab78ce882491c
Author: JieguangZhou <ji...@163.com>
AuthorDate: Mon Oct 24 15:27:56 2022 +0800

    Update the model registration policy (f1-score) (#5)
---
 Project-AutoML/train.py                          | 37 +++++++++++++++----
 Project-BasicAlgorithm/core/training/lightgbm.py |  1 -
 Project-BasicAlgorithm/core/training/lr.py       |  1 -
 Project-BasicAlgorithm/core/training/svm.py      |  1 -
 Project-BasicAlgorithm/core/training/xgboost.py  |  1 -
 Project-BasicAlgorithm/core/utils.py             |  6 ++-
 Project-BasicAlgorithm/train.py                  | 47 ++++++++++++++++++++----
 7 files changed, 73 insertions(+), 21 deletions(-)

diff --git a/Project-AutoML/train.py b/Project-AutoML/train.py
index 269f496..e55fc7f 100644
--- a/Project-AutoML/train.py
+++ b/Project-AutoML/train.py
@@ -49,7 +49,7 @@ def get_tool(tool_name):
     return Tool
 
 
-def create_model_version(model_name, run_id=None, auto_replace=True):
+def create_model_version(model_name, key_metrics=None, run_id=None, auto_replace=True):
     client = mlflow.tracking.MlflowClient()
     filter_string = "name='{}'".format(model_name)
     versions = client.search_model_versions(filter_string)
@@ -57,18 +57,38 @@ def create_model_version(model_name, run_id=None, auto_replace=True):
     if not versions:
         client.create_registered_model(model_name)
 
-    # TODO: 根据与上一个version对比来判断是否更换Production的模型版本
     for version in versions:
         if version.current_stage == "Production":
             client.transition_model_version_stage(
                 model_name, version=version.version, stage="Archived"
             )
 
-    uri = f"runs:/{run_id}/{ARTIFACT_TAG}"
-    mv = mlflow.register_model(uri, model_name)
-    client.transition_model_version_stage(
-        model_name, version=mv.version, stage="Production"
-    )
+    if run_id:
+        uri = f"runs:/{run_id}/{ARTIFACT_TAG}"
+        mv = mlflow.register_model(uri, model_name)
+
+        if not key_metrics:
+            client.transition_model_version_stage(
+                model_name, version=mv.version, stage="Production"
+            )
+            logger.info("register last version to Production")
+
+    if key_metrics:
+        version2metrics = []
+        versions = client.search_model_versions(filter_string)
+        for version in versions:
+            metrics = client.get_run(version.run_id).data.metrics[key_metrics]
+            version2metrics.append((version.version, metrics))
+
+        logger.info(f"version2metrics({key_metrics}): {version2metrics}")
+
+        best_version = max(version2metrics, key=lambda x: x[1])[0]
+
+        logger.info("register version: %s to Production", best_version)
+        client.transition_model_version_stage(
+            model_name, version=best_version, stage="Production"
+        )
+
     return versions
 
 
@@ -114,7 +134,8 @@ def main(
         code_path=["automl/", "predictor.py"],
     )
     if model_name:
-        create_model_version(model_name, run_id=model_info.run_id)
+        create_model_version(
+            model_name, key_metrics='f1-score', run_id=model_info.run_id)
 
 
 if __name__ == "__main__":
diff --git a/Project-BasicAlgorithm/core/training/lightgbm.py b/Project-BasicAlgorithm/core/training/lightgbm.py
index b0141a0..94cfd94 100644
--- a/Project-BasicAlgorithm/core/training/lightgbm.py
+++ b/Project-BasicAlgorithm/core/training/lightgbm.py
@@ -29,7 +29,6 @@ def train_lightgbm(
     train_x, train_y, test_x, test_y, param_file=None, params=None, search_params=None
 ):
     pipeline_mods = []
-    mlflow.autolog()
 
     pipeline_mods.append(("oridinal_encoder", get_oridinal_encoder()))
     pipeline = Pipeline(steps=pipeline_mods)
diff --git a/Project-BasicAlgorithm/core/training/lr.py b/Project-BasicAlgorithm/core/training/lr.py
index b5babdf..5b8abfa 100644
--- a/Project-BasicAlgorithm/core/training/lr.py
+++ b/Project-BasicAlgorithm/core/training/lr.py
@@ -29,7 +29,6 @@ def train_lr(
     train_x, train_y, test_x, test_y, param_file=None, params=None, search_params=None
 ):
     pipeline_mods = []
-    mlflow.autolog()
     pipeline_mods.append(("onehot_encoder", get_onehot_encoder()))
 
     pipeline = Pipeline(steps=pipeline_mods)
diff --git a/Project-BasicAlgorithm/core/training/svm.py b/Project-BasicAlgorithm/core/training/svm.py
index 626a08d..4cf7c82 100644
--- a/Project-BasicAlgorithm/core/training/svm.py
+++ b/Project-BasicAlgorithm/core/training/svm.py
@@ -29,7 +29,6 @@ def train_svc(
     train_x, train_y, test_x, test_y, param_file=None, params=None, search_params=None
 ):
     pipeline_mods = []
-    mlflow.autolog()
     pipeline_mods.append(("onehot_encoder", get_onehot_encoder()))
 
     pipeline = Pipeline(steps=pipeline_mods)
diff --git a/Project-BasicAlgorithm/core/training/xgboost.py b/Project-BasicAlgorithm/core/training/xgboost.py
index 77db830..69f7f97 100644
--- a/Project-BasicAlgorithm/core/training/xgboost.py
+++ b/Project-BasicAlgorithm/core/training/xgboost.py
@@ -30,7 +30,6 @@ def train_xgboost(
     train_x, train_y, test_x, test_y, param_file=None, params=None, search_params=None
 ):
     pipeline_mods = []
-    mlflow.autolog()
     pipeline_mods.append(("oridinal_encoder", get_oridinal_encoder()))
     pipeline = Pipeline(steps=pipeline_mods)
 
diff --git a/Project-BasicAlgorithm/core/utils.py b/Project-BasicAlgorithm/core/utils.py
index a240e23..bf013d7 100644
--- a/Project-BasicAlgorithm/core/utils.py
+++ b/Project-BasicAlgorithm/core/utils.py
@@ -16,6 +16,7 @@
 # under the License.
 
 import numpy as np
+import pandas as pd
 from sklearn.model_selection import GridSearchCV
 from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder
 
@@ -39,7 +40,10 @@ def train_model(model_cls, params, train_x, train_y):
         optimized_model = GridSearchCV(estimator=model, param_grid=params.search_params)
         optimized_model.fit(train_x, train_y)
         model = optimized_model.best_estimator_
-        print(optimized_model.cv_results_)
+        params = optimized_model.cv_results_['params']
+        mean_test_score = optimized_model.cv_results_['mean_test_score']
+        for param, score in zip(params, mean_test_score):
+            print(param, score)
     else:
         model.fit(train_x, train_y)
     return model
diff --git a/Project-BasicAlgorithm/train.py b/Project-BasicAlgorithm/train.py
index a5b99d8..4202747 100644
--- a/Project-BasicAlgorithm/train.py
+++ b/Project-BasicAlgorithm/train.py
@@ -15,12 +15,22 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import logging
+
 import click
 import mlflow
 import mlflow.sklearn
 
 from core.data import load_data
 
+logging.basicConfig(
+    level=logging.INFO,
+    format="[%(asctime)s] %(name)s %(levelname)s %(message)s",
+    datefmt="%Y-%m-%d %H:%M:%S",
+)
+
+logger = logging.getLogger(__name__)
+
 
 def get_training_func(algorithm):
     if algorithm == "svm":
@@ -41,7 +51,7 @@ def get_training_func(algorithm):
     return training_func
 
 
-def create_model_version(model_name, run_id=None, auto_replace=True):
+def create_model_version(model_name, key_metrics=None, run_id=None, auto_replace=True):
     client = mlflow.tracking.MlflowClient()
     filter_string = "name='{}'".format(model_name)
     versions = client.search_model_versions(filter_string)
@@ -49,18 +59,38 @@ def create_model_version(model_name, run_id=None, auto_replace=True):
     if not versions:
         client.create_registered_model(model_name)
 
-    # TODO: 根据与上一个version对比来判断是否更换Production的模型版本
     for version in versions:
         if version.current_stage == "Production":
             client.transition_model_version_stage(
                 model_name, version=version.version, stage="Archived"
             )
 
-    uri = f"runs:/{run_id}/sklearn_model"
-    mv = mlflow.register_model(uri, model_name)
-    client.transition_model_version_stage(
-        model_name, version=mv.version, stage="Production"
-    )
+    if run_id:
+        uri = f"runs:/{run_id}/sklearn_model"
+        mv = mlflow.register_model(uri, model_name)
+
+        if not key_metrics:
+            client.transition_model_version_stage(
+                model_name, version=mv.version, stage="Production"
+            )
+            logger.info("register last version to Production")
+
+    if key_metrics:
+        version2metrics = []
+        versions = client.search_model_versions(filter_string)
+        for version in versions:
+            metrics = client.get_run(version.run_id).data.metrics[key_metrics]
+            version2metrics.append((version.version, metrics))
+
+        logger.info(f"version2metrics({key_metrics}): {version2metrics}")
+
+        best_version = max(version2metrics, key=lambda x: x[1])[0]
+
+        logger.info("register version: %s to Production", best_version)
+        client.transition_model_version_stage(
+            model_name, version=best_version, stage="Production"
+        )
+
     return versions
 
 
@@ -94,7 +124,8 @@ def main(algorithm, data_path, label_column, model_name, random_state, param_fil
         mlflow.sklearn.log_model(model, artifact_path="sklearn_model")
 
     if model_name:
-        create_model_version(model_name, run_id=run.info.run_id)
+        create_model_version(
+            model_name, key_metrics='f1-score', run_id=run.info.run_id)
 
 
 if __name__ == "__main__":