You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@dolphinscheduler.apache.org by zh...@apache.org on 2022/05/17 06:49:18 UTC
[dolphinscheduler-mlflow] branch main updated: Add BasicAlgorithm and AutoML (#1)
This is an automated email from the ASF dual-hosted git repository.
zhongjiajie pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/dolphinscheduler-mlflow.git
The following commit(s) were added to refs/heads/main by this push:
new d4f0f8b Add BasicAlgorithm and AutoML (#1)
d4f0f8b is described below
commit d4f0f8bcecf3114da06f9f1949908c43bde4cd1d
Author: JieguangZhou <ji...@163.com>
AuthorDate: Tue May 17 14:49:14 2022 +0800
Add BasicAlgorithm and AutoML (#1)
* Add BasicAlgorithm and AutoML
* Add LICENSE
---
.gitignore | 129 +++++++++++++++
LICENSE | 201 +++++++++++++++++++++++
Project-AutoML/MLproject | 19 +++
Project-AutoML/README.md | 1 +
Project-AutoML/automl/__init__.py | 8 +
Project-AutoML/automl/data.py | 56 +++++++
Project-AutoML/automl/metrics.py | 10 ++
Project-AutoML/automl/mod/__init__.py | 0
Project-AutoML/automl/mod/mod_autosklearn.py | 108 ++++++++++++
Project-AutoML/automl/mod/mod_flaml.py | 87 ++++++++++
Project-AutoML/automl/mod/tool.py | 43 +++++
Project-AutoML/automl/params.py | 62 +++++++
Project-AutoML/conda.yaml | 15 ++
Project-AutoML/predictor.py | 30 ++++
Project-AutoML/train.py | 104 ++++++++++++
Project-BasicAlgorithm/MLproject | 23 +++
Project-BasicAlgorithm/README.md | 3 +
Project-BasicAlgorithm/conda.yaml | 18 ++
Project-BasicAlgorithm/core/__init__.py | 0
Project-BasicAlgorithm/core/data.py | 53 ++++++
Project-BasicAlgorithm/core/metrics.py | 8 +
Project-BasicAlgorithm/core/training/__init__.py | 0
Project-BasicAlgorithm/core/training/lightgbm.py | 35 ++++
Project-BasicAlgorithm/core/training/lr.py | 35 ++++
Project-BasicAlgorithm/core/training/params.py | 157 ++++++++++++++++++
Project-BasicAlgorithm/core/training/svm.py | 32 ++++
Project-BasicAlgorithm/core/training/xgboost.py | 36 ++++
Project-BasicAlgorithm/core/utils.py | 28 ++++
Project-BasicAlgorithm/train.py | 84 ++++++++++
29 files changed, 1385 insertions(+)
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..b6e4761
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,129 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..261eeb9
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/Project-AutoML/MLproject b/Project-AutoML/MLproject
new file mode 100644
index 0000000..571c8c2
--- /dev/null
+++ b/Project-AutoML/MLproject
@@ -0,0 +1,19 @@
+name: MLflow-AutoML
+
+conda_env: conda.yaml
+
+entry_points:
+ main:
+ parameters:
+ tool: {type: str, default: autosklearn}
+ data_path: str
+ label_column: {type: str, default: label}
+ model_name: str
+ params: {type: str, default: ""}
+ command: "python train.py \
+ --tool {tool} \
+ --data_path {data_path} \
+ --label_column {label_column} \
+ --model_name {model_name} \
+ --params {params} "
+
diff --git a/Project-AutoML/README.md b/Project-AutoML/README.md
new file mode 100644
index 0000000..491d732
--- /dev/null
+++ b/Project-AutoML/README.md
@@ -0,0 +1 @@
+# MLflow-AutoML
\ No newline at end of file
diff --git a/Project-AutoML/automl/__init__.py b/Project-AutoML/automl/__init__.py
new file mode 100644
index 0000000..ec115d8
--- /dev/null
+++ b/Project-AutoML/automl/__init__.py
@@ -0,0 +1,8 @@
import logging
import warnings

# The AutoML backends emit a flood of FutureWarnings; silence them so the
# training logs stay readable.
warnings.simplefilter(action="ignore", category=FutureWarning)

# Package-wide logger, set to INFO so training progress is visible by default.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
diff --git a/Project-AutoML/automl/data.py b/Project-AutoML/automl/data.py
new file mode 100644
index 0000000..e38c153
--- /dev/null
+++ b/Project-AutoML/automl/data.py
@@ -0,0 +1,56 @@
+import os
+from logging import getLogger
+
+import pandas as pd
+from sklearn.model_selection import train_test_split
+
+logger = getLogger(__name__)
+
+PATH_ERROR_MESSAGE = (
+ "data_path only support csv data or directory contained train.csv and test.csv"
+)
+
+
def load_data(data_path, label_column, test_size=0.25, random_state=1):
    """Load train/test features and labels from *data_path*.

    *data_path* may be either:
      * a directory containing ``train.csv`` and ``test.csv``, or
      * a single ``.csv`` file, which is split into train/test sets.

    Returns ``(train_x, train_y, test_x, test_y)``.
    Raises ``FileNotFoundError`` when the expected CSVs are missing and
    ``Exception`` for any other unsupported path.
    """
    if os.path.isdir(data_path):
        train_path = os.path.join(data_path, "train.csv")
        test_path = os.path.join(data_path, "test.csv")
        # Fixed: an `assert` here is stripped under `python -O`, silently
        # skipping validation; raise explicitly instead.
        if not (os.path.exists(train_path) and os.path.exists(test_path)):
            raise FileNotFoundError(PATH_ERROR_MESSAGE)

        logger.info(f"load train data from {train_path}")
        logger.info(f"load test data from {test_path}")
        train_x, train_y = load_csv_data(train_path, label_column)
        test_x, test_y = load_csv_data(test_path, label_column)

    elif data_path.endswith(".csv"):
        logger.info(f"load data from {data_path}")
        logger.info("split data to train set and test set")
        train_x, train_y, test_x, test_y = load_split_csv_data(
            data_path, label_column, test_size=test_size, random_state=random_state
        )

    else:
        raise Exception(PATH_ERROR_MESSAGE)

    return train_x, train_y, test_x, test_y
+
+
def load_split_csv_data(data_path, label_column, test_size=0.25, random_state=1):
    """Read a single CSV and split it into train/test feature/label frames.

    Returns ``(train_x, train_y, test_x, test_y)``; labels are kept as
    single-column DataFrames.
    """
    frame = pd.read_csv(data_path)
    train_part, test_part = train_test_split(
        frame, test_size=test_size, random_state=random_state
    )
    return (
        train_part.drop([label_column], axis=1),
        train_part[[label_column]],
        test_part.drop([label_column], axis=1),
        test_part[[label_column]],
    )
+
+
def load_csv_data(data_path, label_column):
    """Read a CSV and return ``(features, labels)``.

    The label is returned as a single-column DataFrame; every other column
    becomes a feature.
    """
    frame = pd.read_csv(data_path)
    labels = frame[[label_column]]
    features = frame.drop([label_column], axis=1)
    return features, labels
diff --git a/Project-AutoML/automl/metrics.py b/Project-AutoML/automl/metrics.py
new file mode 100644
index 0000000..b032a00
--- /dev/null
+++ b/Project-AutoML/automl/metrics.py
@@ -0,0 +1,10 @@
+import pickle
+
+from sklearn.metrics import classification_report
+
+
def eval_classification_metrics(y_true, y_pred):
    """Return weighted-average precision/recall/f1 plus accuracy as one flat dict."""
    report: dict = classification_report(y_true, y_pred, output_dict=True)
    metrics = dict(report["weighted avg"], accuracy=report["accuracy"])
    return metrics
diff --git a/Project-AutoML/automl/mod/__init__.py b/Project-AutoML/automl/mod/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/Project-AutoML/automl/mod/mod_autosklearn.py b/Project-AutoML/automl/mod/mod_autosklearn.py
new file mode 100644
index 0000000..a523405
--- /dev/null
+++ b/Project-AutoML/automl/mod/mod_autosklearn.py
@@ -0,0 +1,108 @@
+import pickle
+
+import numpy as np
+from autosklearn.classification import AutoSklearnClassifier
+from sklearn.pipeline import Pipeline
+from sklearn.preprocessing import OrdinalEncoder
+
+from automl.metrics import eval_classification_metrics
+from automl.mod.tool import BasePredictor, Tool
+from automl.params import Params
+
+
class AutoSklearn(Tool):
    """auto-sklearn backend.

    Trains an ``AutoSklearnClassifier`` behind an ``OrdinalEncoder`` step so
    categorical columns are handled uniformly; the trained artifact is a
    two-step sklearn ``Pipeline``.
    """

    # On-disk name of the pickled pipeline artifact.
    model_path = "autosklearn.pkl"

    # Conda environment logged with the MLflow model so it can be re-served.
    conda_env = {
        "channels": ["defaults", "conda-forge"],
        "dependencies": [
            "python=3.8.2",
            {
                "pip": [
                    "mlflow",
                    "scikit-learn==0.24.2",
                    "boto3==1.22.2",
                    "pandas==1.3.5",
                    "setuptools<59.6.0",
                    "auto-sklearn==0.14.6",
                ],
            },
        ],
        "name": "mlflow-env",
    }

    @staticmethod
    def train_automl(train_x, train_y, other_params=None, **kwargs):
        """Fit auto-sklearn and return a Pipeline of (encoder, classifier)."""
        params = Params(param_str=other_params, **kwargs)
        print(params)
        pipeline_mods = []

        # Encode categorical features; unseen categories become NaN, which
        # auto-sklearn treats as missing values.
        pipeline_mods.append(
            (
                "oridinal_encoder",
                OrdinalEncoder(
                    unknown_value=np.nan, handle_unknown="use_encoded_value"
                ),
            )
        )
        pipeline = Pipeline(steps=pipeline_mods)
        feat_type = [
            "Categorical" if x.name in {"object", "category"} else "Numerical"
            for x in train_x.dtypes
        ]
        train_x = pipeline.fit_transform(train_x)
        classifier = AutoSklearnClassifier(**params.input_params)
        classifier.fit(train_x, train_y, feat_type=feat_type)

        pipeline.steps.append(("classifier", classifier))
        return pipeline

    @staticmethod
    def eval_automl(pipeline: Pipeline, test_x, test_y, task="classification"):
        """Evaluate the trained pipeline on the held-out set.

        Fixed: this was named ``eval`` and therefore never overrode
        ``Tool.eval_automl`` (train.py always fell through to the base
        implementation, making the classification metrics dead code).
        """
        oridinal_encoder = pipeline.steps[0][1]
        classifier = pipeline.steps[1][1]
        test_x = oridinal_encoder.transform(test_x)
        y_pred = classifier.predict(test_x)
        if task == "classification":
            metrics = eval_classification_metrics(test_y, y_pred)
        else:
            # Fixed: previously referenced the undefined name `automl` and
            # called `super()` without arguments from a staticmethod — both
            # raise at runtime. Score the classifier on the already-encoded
            # features instead.
            metrics = Tool.eval_automl(classifier, test_x, test_y)

        return metrics

    @staticmethod
    def save_automl(classifier: AutoSklearnClassifier, save_path: str):
        """Pickle the trained pipeline to *save_path*."""
        with open(save_path, "wb") as w_f:
            pickle.dump(classifier, w_f)
+
+
class Predictor(BasePredictor):
    """Serving-time wrapper around a pickled auto-sklearn pipeline."""

    def load_automl(self, model_path):
        """Load the (encoder, classifier) pipeline pickled by AutoSklearn."""
        with open(model_path, "rb") as r_f:
            self.pipeline: AutoSklearnClassifier = pickle.load(r_f)
        self.oridinal_encoder = self.pipeline.steps[0][1]
        self.automl = self.pipeline.steps[1][1]

    def predict(self, inputs):
        """Classification models get label+confidence dicts; others raw predictions."""
        if isinstance(self.automl, AutoSklearnClassifier):
            result = self.predict_classification(inputs)
        else:
            result = self.automl.predict(inputs)
        return result

    def predict_classification(self, inputs):
        """Return ``[{"label": ..., "confidence": ...}]`` per input row."""
        inputs = self.oridinal_encoder.transform(inputs)

        # Fixed: `load_automl` stores the model as `self.automl`; the original
        # code read the never-assigned attribute `self.classifier` and raised
        # AttributeError on every call.
        pred_proba = self.automl.predict_proba(inputs)
        label_indexes = pred_proba.argmax(axis=1)
        probs = pred_proba[np.arange(pred_proba.shape[0]), label_indexes]
        labels = self.automl.automl_.InputValidator.target_validator.inverse_transform(
            label_indexes
        )
        result = []
        for label, pro in zip(labels, probs):
            result.append({"label": label, "confidence": float(pro)})
        return result
diff --git a/Project-AutoML/automl/mod/mod_flaml.py b/Project-AutoML/automl/mod/mod_flaml.py
new file mode 100644
index 0000000..819f878
--- /dev/null
+++ b/Project-AutoML/automl/mod/mod_flaml.py
@@ -0,0 +1,87 @@
+import pickle
+
+import numpy as np
+import pandas as pd
+from flaml import AutoML
+
+from automl.metrics import eval_classification_metrics
+from automl.mod.tool import BasePredictor, Tool
+from automl.params import Params
+
+
def convert_y(y):
    """Flatten a single-column DataFrame into a 1-D array.

    Anything that is not a DataFrame is passed through unchanged, so the
    helper is safe to call on already-flat label arrays.
    """
    if not isinstance(y, pd.DataFrame):
        return y
    return y.to_numpy().reshape(-1)
+
+
class FLAML(Tool):
    """FLAML backend: trains a ``flaml.AutoML`` estimator directly."""

    # On-disk name of the pickled model artifact.
    model_path = "flaml.pkl"

    # Conda environment logged with the MLflow model so it can be re-served.
    conda_env = {
        "channels": ["defaults", "conda-forge"],
        "dependencies": [
            "python=3.8.2",
            {
                "pip": [
                    "mlflow",
                    "scikit-learn==0.24.2",
                    "boto3==1.22.2",
                    "pandas==1.3.5",
                    "setuptools<59.6.0",
                    "flaml==1.0.1",
                ],
            },
        ],
        "name": "mlflow-env",
    }

    @staticmethod
    def train_automl(train_x, train_y, other_params=None, **kwargs):
        """Fit a FLAML AutoML model and return it.

        Fixed: removed the stray statement ``automl.predict`` — a bare
        attribute access with no call, a no-op left over from development.
        """
        params = Params(param_str=other_params, **kwargs)
        automl = AutoML(**params.input_params)
        # FLAML expects a 1-D label array, not a single-column DataFrame.
        train_y = convert_y(train_y)
        automl.fit(train_x, train_y)

        return automl

    @staticmethod
    def eval_automl(automl: AutoML, test_x, test_y, task="classification"):
        """Evaluate on the held-out set; classification gets a full report."""
        y_pred = automl.predict(test_x)

        test_y = convert_y(test_y)
        if task == "classification":
            metrics = eval_classification_metrics(test_y, y_pred)
        else:
            # Fall back to the estimator's own score metric.
            metrics = Tool.eval_automl(automl, test_x, test_y)
        return metrics

    @staticmethod
    def save_automl(automl: AutoML, save_path: str):
        """Persist the trained model using FLAML's own pickle helper."""
        automl.pickle(save_path)
+
+
class Predictor(BasePredictor):
    """Serving-time wrapper around a pickled FLAML AutoML model."""

    def load_automl(self, model_path):
        """Unpickle the model saved by ``FLAML.save_automl``."""
        with open(model_path, "rb") as r_f:
            self.automl: AutoML = pickle.load(r_f)

    def predict(self, inputs):
        """Dispatch on task type; classification returns label+confidence."""
        # NOTE(review): relies on FLAML's private `_settings` dict to recover
        # the configured task — confirm this survives FLAML upgrades.
        if self.automl._settings.get("task") != "classification":
            return self.automl.predict(inputs)
        return self.predict_classification(inputs)

    def predict_classification(self, inputs):
        """Return ``[{"label": ..., "confidence": ...}]`` per input row."""
        pred_proba = self.automl.predict_proba(inputs)
        label_indexes = pred_proba.argmax(axis=1)
        probs = pred_proba[np.arange(pred_proba.shape[0]), label_indexes]
        # NOTE(review): `_label_transformer` (private) maps class indexes back
        # to the original label values.
        labels = self.automl._label_transformer.inverse_transform(
            pd.Series(label_indexes.astype(int))
        )
        return [
            {"label": label, "confidence": float(prob)}
            for label, prob in zip(labels, probs)
        ]
diff --git a/Project-AutoML/automl/mod/tool.py b/Project-AutoML/automl/mod/tool.py
new file mode 100644
index 0000000..6248483
--- /dev/null
+++ b/Project-AutoML/automl/mod/tool.py
@@ -0,0 +1,43 @@
class Tool:
    """Common interface for an AutoML backend.

    Subclasses (see ``mod_autosklearn`` / ``mod_flaml``) must provide
    ``train_automl`` and ``save_automl`` and usually override ``eval_automl``.
    ``model_path`` names the on-disk artifact and ``conda_env`` is logged with
    the MLflow model so it can be re-served.
    """

    # Subclasses set this to their artifact filename.
    model_path = None

    conda_env = {
        "channels": ["defaults", "conda-forge"],
        "dependencies": [
            "python=3.8.2",
            {
                "pip": [
                    "mlflow",
                    "scikit-learn==0.24.2",
                    "boto3==1.22.2",
                    "pandas==1.3.5",
                    "setuptools<59.6.0",
                ],
            },
        ],
        "name": "mlflow-env",
    }

    @staticmethod
    def train_automl(train_x, train_y, other_params=None, **kwargs):
        """Train a model on the given data; provided by subclasses."""
        raise NotImplementedError

    @staticmethod
    def eval_automl(automl, test_x, test_y):
        """Default evaluation: the estimator's own ``score`` metric."""
        return {"score": automl.score(test_x, test_y)}

    @staticmethod
    def save_automl(automl, save_path: str):
        """Persist the trained model; provided by subclasses."""
        raise NotImplementedError
+
+
class BasePredictor:
    """Base class for serving-time predictors.

    Subclasses override ``load_automl`` and ``predict``; the constructor
    eagerly loads the model so an instance is ready to serve.
    """

    def __init__(self, automl_path=None):
        self.load_automl(automl_path)

    def predict(self, inputs):
        """Default prediction: an empty result."""
        return {}

    def load_automl(self, path):
        """Hook for subclasses; the base class loads nothing."""
diff --git a/Project-AutoML/automl/params.py b/Project-AutoML/automl/params.py
new file mode 100644
index 0000000..829dcdb
--- /dev/null
+++ b/Project-AutoML/automl/params.py
@@ -0,0 +1,62 @@
+import json
+import warnings
+from logging import getLogger
+
+logger = getLogger(__name__)
+
+
class Params:
    """Hyper-parameter container.

    Merges three sources, later ones winning: a JSON file (*param_file*),
    a ``key=value;key2=value2`` string (*param_str*) and ``**kwargs``.
    String values that look like Python literals ("1", "True", "[1, 2]")
    are converted to the corresponding Python objects.
    """

    def __init__(self, param_file=None, param_str=None, **kwargs):
        input_params = self.parse_file(param_file)
        str_input_params = self.parse_param_str(param_str)
        input_params.update(str_input_params)
        input_params.update(kwargs)

        self.input_params = self.check_input_params(input_params)
        logger.info(f"params : {self}")

    def check_input_params(self, input_params):
        """Coerce literal-looking string values to Python objects in place.

        Anything that is not a string, or not a valid Python literal, is
        left untouched.
        """
        import ast

        for key, value in input_params.items():
            if not isinstance(value, str):
                continue
            try:
                # Fixed: this previously used bare `eval`, which executes
                # arbitrary expressions from user-supplied parameter strings.
                # `ast.literal_eval` only accepts Python literals and is safe
                # on untrusted input.
                input_params[key] = ast.literal_eval(value)
            except (ValueError, SyntaxError):
                pass  # not a literal; keep the raw string
        return input_params

    @staticmethod
    def parse_file(path):
        """Load parameters from a JSON file; an empty/None path yields {}."""
        path = path or ""
        if not path.strip():
            return {}
        with open(path, "r") as r_f:
            params = json.load(r_f)
        return params

    @staticmethod
    def parse_param_str(param_str):
        """Parse ``k=v;k2=v2`` into a dict, warning on malformed pairs."""
        param_str = param_str or ""
        if not param_str.strip():
            return {}
        pairs = []
        for name_value_pair in param_str.split(";"):
            if not name_value_pair.strip():
                continue
            k_v = name_value_pair.split("=")
            if len(k_v) != 2:
                warnings.warn(f"{name_value_pair} error, will be ignored")
                continue
            # Fixed: reuse the already-split pair instead of splitting twice.
            key, value = k_v
            pairs.append((key.strip(), value.strip()))
        params = dict(pairs)
        return params

    def __str__(self):
        input_params_message = str(self.input_params)
        return f"input_params: {input_params_message}"

    __repr__ = __str__
diff --git a/Project-AutoML/conda.yaml b/Project-AutoML/conda.yaml
new file mode 100644
index 0000000..f74bba4
--- /dev/null
+++ b/Project-AutoML/conda.yaml
@@ -0,0 +1,15 @@
+name: MLflow-AutoML
+channels:
+ - https://mirrors.ustc.edu.cn/anaconda/pkgs/free/
+dependencies:
+ - python=3.8.2
+ - pip
+ - pip:
+ - mlflow
+ - click==8.0.3
+ - scikit-learn==0.24.2
+ - boto3==1.22.2
+ - pandas>=1.0.0
+ - setuptools<59.6.0
+ - auto-sklearn==0.14.6
+ - flaml==1.0.1
diff --git a/Project-AutoML/predictor.py b/Project-AutoML/predictor.py
new file mode 100644
index 0000000..2351e4f
--- /dev/null
+++ b/Project-AutoML/predictor.py
@@ -0,0 +1,30 @@
+from mlflow.pyfunc import PythonModel
+
+
class PredictorWrapper(PythonModel):
    """MLflow pyfunc wrapper that picks the backend predictor implementation
    from the saved artifact's filename."""

    def load_context(self, context):
        """Instantiate the right predictor for the logged model artifact."""
        model_path = context.artifacts["model_path"]
        if "autosklearn" in model_path:
            self.predictor = self.load_autosklearn_predictor(model_path)

        elif "flaml" in model_path:
            self.predictor = self.load_flaml_predictor(model_path)

        else:
            # Fixed: this was `assert f"..."`, which asserts a non-empty
            # string and therefore never fails — unknown artifacts silently
            # left `self.predictor` unset and every later predict() crashed
            # with AttributeError.
            raise ValueError(f"can not load model from path {model_path}")

    def predict(self, context, model_input):
        """Delegate to the loaded predictor and wrap output for MLflow."""
        results = self.predictor.predict(model_input)
        return {"results": results}

    def load_autosklearn_predictor(self, path):
        # Imported lazily so a flaml-only environment need not install
        # auto-sklearn (and vice versa).
        from automl.mod.mod_autosklearn import Predictor

        predictor = Predictor(path)
        return predictor

    def load_flaml_predictor(self, path):
        from automl.mod.mod_flaml import Predictor

        predictor = Predictor(path)
        return predictor
diff --git a/Project-AutoML/train.py b/Project-AutoML/train.py
new file mode 100644
index 0000000..10d13a4
--- /dev/null
+++ b/Project-AutoML/train.py
@@ -0,0 +1,104 @@
+import logging
+
+import click
+import mlflow
+import mlflow.sklearn
+
+from automl.data import load_data
+
+
+logging.basicConfig(
+ level=logging.INFO,
+ format="[%(asctime)s] %(name)s %(levelname)s %(message)s",
+ datefmt="%Y-%m-%d %H:%M:%S",
+)
+
+logger = logging.getLogger(__name__)
+
+ARTIFACT_TAG = "artifact"
+
+
def get_tool(tool_name):
    """Map a tool name ("autosklearn" or "flaml", case-insensitive) to its
    backend class.

    Imports are lazy so only the chosen backend's dependencies are needed.
    Raises ``Exception`` for unsupported names.
    """
    tool_name = tool_name.lower()
    # Fixed: removed the redundant `assert tool_name in {...}` guard — it
    # duplicated the `else` branch below and disappears under `python -O`;
    # also stopped re-lowering the already-lowered name in each branch.
    if tool_name == "autosklearn":
        from automl.mod.mod_autosklearn import AutoSklearn as Tool

    elif tool_name == "flaml":
        from automl.mod.mod_flaml import FLAML as Tool

    else:
        raise Exception(f"Does not support {tool_name}")
    return Tool
+
+
def create_model_version(model_name, run_id=None, auto_replace=True):
    """Register the run's logged model under *model_name* and promote it to
    the "Production" stage, archiving any previous Production version."""
    client = mlflow.tracking.MlflowClient()
    existing_versions = client.search_model_versions(f"name='{model_name}'")

    if not existing_versions:
        client.create_registered_model(model_name)

    # TODO: compare against the previous version before deciding whether to
    # replace the Production model.
    for version in existing_versions:
        if version.current_stage == "Production":
            client.transition_model_version_stage(
                model_name, version=version.version, stage="Archived"
            )

    model_uri = f"runs:/{run_id}/{ARTIFACT_TAG}"
    new_version = mlflow.register_model(model_uri, model_name)
    client.transition_model_version_stage(
        model_name, version=new_version.version, stage="Production"
    )
    return existing_versions
+
+
@click.command()
@click.option("--tool")
@click.option("--data_path")
@click.option("--label_column", default="label")
@click.option("--model_name", default=None)
@click.option("--random_state", default=0)
@click.option("--params", default=None)
def main(
    tool,
    data_path,
    label_column,
    model_name,
    random_state,
    params,
):
    """Train an AutoML model, log metrics and the model to MLflow and,
    when --model_name is given, register it in the model registry."""
    Tool = get_tool(tool)

    train_x, train_y, test_x, test_y = load_data(
        data_path, label_column, random_state=random_state
    )

    automl = Tool.train_automl(train_x, train_y, other_params=params)

    metrics = Tool.eval_automl(automl, test_x, test_y)
    logger.info(f"metrics: {metrics}")
    mlflow.log_metrics(metrics)

    Tool.save_automl(automl, Tool.model_path)

    # Imported here because the wrapper lazily imports backend modules itself.
    from predictor import PredictorWrapper

    model_info = mlflow.pyfunc.log_model(
        artifact_path=ARTIFACT_TAG,
        python_model=PredictorWrapper(),
        artifacts={"model_path": Tool.model_path},
        conda_env=Tool.conda_env,
        code_path=["automl/", "predictor.py"],
    )
    if model_name:
        create_model_version(model_name, run_id=model_info.run_id)


if __name__ == "__main__":
    main()
diff --git a/Project-BasicAlgorithm/MLproject b/Project-BasicAlgorithm/MLproject
new file mode 100644
index 0000000..3c7c946
--- /dev/null
+++ b/Project-BasicAlgorithm/MLproject
@@ -0,0 +1,23 @@
+name: sklearn
+
+conda_env: conda.yaml
+
+entry_points:
+ main:
+ parameters:
+ algorithm: {type: str, default: lightgbm}
+ data_path: str
+ label_column: {type: str, default: label}
+ model_name: str
+ param_file: {type: str, default: ""}
+ params: {type: str, default: ""}
+ search_params: {type: str, default: ""}
+ command: "python train.py \
+ --algorithm {algorithm} \
+ --data_path {data_path} \
+ --label_column {label_column} \
+ --model_name {model_name} \
+ --param_file {param_file} \
+ --params {params} \
+ --search_params {search_params}"
+
diff --git a/Project-BasicAlgorithm/README.md b/Project-BasicAlgorithm/README.md
new file mode 100644
index 0000000..27eaffd
--- /dev/null
+++ b/Project-BasicAlgorithm/README.md
@@ -0,0 +1,3 @@
+# mlflow_sklearn_gallery
+
+An example of building MLOps infrastructure using DolphinScheduler.
diff --git a/Project-BasicAlgorithm/conda.yaml b/Project-BasicAlgorithm/conda.yaml
new file mode 100644
index 0000000..1764a19
--- /dev/null
+++ b/Project-BasicAlgorithm/conda.yaml
@@ -0,0 +1,18 @@
+name: sklearn
+channels:
+ - https://mirrors.ustc.edu.cn/anaconda/pkgs/free/
+dependencies:
+
+ - numpy>=1.14.3
+ - pandas>=1.0.0
+ - scikit-learn=1.0.2
+ - pip
+ - pip:
+ - mlflow
+ - click
+ - lightgbm
+ - Pillow
+ - xgboost==1.5.2
+ - boto3
+
+
diff --git a/Project-BasicAlgorithm/core/__init__.py b/Project-BasicAlgorithm/core/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/Project-BasicAlgorithm/core/data.py b/Project-BasicAlgorithm/core/data.py
new file mode 100644
index 0000000..6e59963
--- /dev/null
+++ b/Project-BasicAlgorithm/core/data.py
@@ -0,0 +1,53 @@
+import os
+
+import pandas as pd
+from sklearn.model_selection import train_test_split
+
+PATH_ERROR_MESSAGE = (
+ "data_path only support csv data or directory contained train.csv and test.csv"
+)
+
+
def load_data(data_path, label_column, test_size=0.25, random_state=1):
    """Load a train/test dataset split from *data_path*.

    ``data_path`` may be either a directory containing ``train.csv`` and
    ``test.csv``, or a single ``.csv`` file that is split into train and
    test portions here.

    Returns:
        (train_x, train_y, test_x, test_y) as pandas DataFrames.
    """
    if os.path.isdir(data_path):
        train_path = os.path.join(data_path, "train.csv")
        test_path = os.path.join(data_path, "test.csv")
        assert os.path.exists(train_path) and os.path.exists(
            test_path
        ), PATH_ERROR_MESSAGE

        print(f"load train data from {train_path}")
        print(f"load test data from {test_path}")
        train_x, train_y = load_csv_data(train_path, label_column)
        test_x, test_y = load_csv_data(test_path, label_column)
        return train_x, train_y, test_x, test_y

    if data_path.endswith(".csv"):
        print(f"load data from {data_path}")
        print("split data to train set and test set")
        return load_split_csv_data(
            data_path, label_column, test_size=test_size, random_state=random_state
        )

    raise Exception(PATH_ERROR_MESSAGE)
+
+
def load_split_csv_data(data_path, label_column, test_size=0.25, random_state=1):
    """Read *data_path* as one CSV and split it into train/test frames.

    The labels keep their DataFrame shape (single-column frames, not
    Series), matching what downstream training code expects.

    Returns:
        (train_x, train_y, test_x, test_y)
    """
    frame = pd.read_csv(data_path)
    train_part, test_part = train_test_split(
        frame, test_size=test_size, random_state=random_state
    )
    return (
        train_part.drop([label_column], axis=1),
        train_part[[label_column]],
        test_part.drop([label_column], axis=1),
        test_part[[label_column]],
    )
+
+
def load_csv_data(data_path, label_column):
    """Read *data_path* as CSV and split off the label column.

    Returns:
        (x, y) where ``x`` holds every column except *label_column* and
        ``y`` is a single-column DataFrame with the labels.
    """
    frame = pd.read_csv(data_path)
    labels = frame[[label_column]]
    features = frame.drop([label_column], axis=1)
    return features, labels
diff --git a/Project-BasicAlgorithm/core/metrics.py b/Project-BasicAlgorithm/core/metrics.py
new file mode 100644
index 0000000..37a442c
--- /dev/null
+++ b/Project-BasicAlgorithm/core/metrics.py
@@ -0,0 +1,8 @@
+from sklearn.metrics import classification_report
+
+
def eval_classification_metrics(y_true, y_pred):
    """Summarize classification quality as a flat metrics dict.

    Returns the weighted-average precision/recall/f1/support block from
    sklearn's classification report, with overall accuracy folded in
    under the key ``"accuracy"``.
    """
    report: dict = classification_report(y_true, y_pred, output_dict=True)
    summary = report["weighted avg"]
    summary["accuracy"] = report["accuracy"]
    return summary
diff --git a/Project-BasicAlgorithm/core/training/__init__.py b/Project-BasicAlgorithm/core/training/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/Project-BasicAlgorithm/core/training/lightgbm.py b/Project-BasicAlgorithm/core/training/lightgbm.py
new file mode 100644
index 0000000..3c24acd
--- /dev/null
+++ b/Project-BasicAlgorithm/core/training/lightgbm.py
@@ -0,0 +1,35 @@
+import mlflow
+from lightgbm import LGBMClassifier
+from sklearn.pipeline import Pipeline
+
+from core.metrics import eval_classification_metrics
+from core.utils import get_oridinal_encoder, train_model
+
+from .params import LightGBMParams
+
+
def train_lightgbm(
    train_x, train_y, test_x, test_y, param_file=None, params=None, search_params=None
):
    """Train a LightGBM classifier behind an ordinal-encoding pipeline.

    Parameters may come from a JSON file (*param_file*), a ``"k=v;k=v"``
    string (*params*), or a grid-search spec (*search_params*).

    Returns:
        (pipeline, metrics) — the fitted encoder+model pipeline and the
        classification metrics computed on the test set.
    """
    mlflow.autolog()

    # Fit the preprocessing step first; the model is trained on the
    # transformed features and appended to the pipeline afterwards.
    pipeline = Pipeline(steps=[("oridinal_encoder", get_oridinal_encoder())])
    encoded_train_x = pipeline.fit_transform(train_x)

    resolved_params = LightGBMParams(
        LGBMClassifier,
        param_file=param_file,
        param_str=params,
        search_params=search_params,
    )
    classifier = train_model(LGBMClassifier, resolved_params, encoded_train_x, train_y)
    pipeline.steps.append(("model", classifier))

    predictions = pipeline.predict(test_x)
    return pipeline, eval_classification_metrics(test_y, predictions)
diff --git a/Project-BasicAlgorithm/core/training/lr.py b/Project-BasicAlgorithm/core/training/lr.py
new file mode 100644
index 0000000..cd70b8d
--- /dev/null
+++ b/Project-BasicAlgorithm/core/training/lr.py
@@ -0,0 +1,35 @@
+import mlflow
+from sklearn.linear_model import LogisticRegression
+from sklearn.pipeline import Pipeline
+
+from core.metrics import eval_classification_metrics
+from core.utils import get_onehot_encoder, train_model
+
+from .params import LrParams
+
+
def train_lr(
    train_x, train_y, test_x, test_y, param_file=None, params=None, search_params=None
):
    """Train a logistic-regression classifier behind a one-hot pipeline.

    Parameters may come from a JSON file (*param_file*), a ``"k=v;k=v"``
    string (*params*), or a grid-search spec (*search_params*).

    Returns:
        (pipeline, metrics) — the fitted encoder+model pipeline and the
        classification metrics computed on the test set.
    """
    mlflow.autolog()

    # Fit the encoder first; the model is trained on the transformed
    # features and appended to the pipeline afterwards.
    pipeline = Pipeline(steps=[("onehot_encoder", get_onehot_encoder())])
    encoded_train_x = pipeline.fit_transform(train_x)

    resolved_params = LrParams(
        LogisticRegression,
        param_file=param_file,
        param_str=params,
        search_params=search_params,
    )
    classifier = train_model(
        LogisticRegression, resolved_params, encoded_train_x, train_y
    )
    pipeline.steps.append(("model", classifier))

    predictions = pipeline.predict(test_x)
    return pipeline, eval_classification_metrics(test_y, predictions)
diff --git a/Project-BasicAlgorithm/core/training/params.py b/Project-BasicAlgorithm/core/training/params.py
new file mode 100644
index 0000000..a8cbb5a
--- /dev/null
+++ b/Project-BasicAlgorithm/core/training/params.py
@@ -0,0 +1,157 @@
+import inspect
+import json
+import warnings
+from copy import deepcopy
+
+
class Params:
    """Resolve and type-check estimator constructor parameters.

    Parameters are merged in increasing priority: JSON file
    (*param_file*), then a ``"key=value;key=value"`` string
    (*param_str*), then explicit ``**kwargs``.  Each value is coerced to
    the type of the estimator's declared default where one is known.

    Attributes:
        input_params: dict of fixed constructor parameters.
        search_params: dict of name -> list of candidate values for
            grid search.
    """

    def __init__(
        self, cls, param_file=None, param_str=None, search_params=None, **kwargs
    ):
        input_params = self.parse_file(param_file)
        str_input_params = self.parse_param_str(param_str)
        input_params.update(str_input_params)
        input_params.update(kwargs)

        self.input_params = self.check_input_params(cls, input_params)
        self.search_params = self.check_search_params(cls, search_params)

    def check_input_params(self, cls, input_params):
        """Coerce each supplied parameter to a usable type.

        A method on this class with the same name as a parameter (e.g.
        ``class_weight``) acts as a custom parser; otherwise the value
        is cast to the type of the estimator's default.
        """
        # NOTE(review): getattr(self, key) will also match unrelated
        # methods (e.g. a parameter literally named "parse_file") —
        # confirm parameter names never collide with Params' API.
        default_params = self.load_cls_default_params(cls)
        for key, value in input_params.items():
            parse_func = getattr(self, key, None)
            if parse_func:
                value = parse_func(value)
            elif key in default_params:
                value = self.match_type(default_params[key], value)
            input_params[key] = value
        return input_params

    def check_search_params(self, cls, search_params):
        """Parse the grid-search spec and type-coerce each candidate list."""
        search_params = self.parse_param_str(search_params)
        default_params = self.load_cls_default_params(cls)
        for key, values in search_params.items():
            try:
                # SECURITY: eval() on user-supplied text.  Acceptable for a
                # trusted CLI, but consider ast.literal_eval for safety.
                values = eval(values)
            except Exception:
                warnings.warn(f"value : {values} error, is must be list of something")
                continue
            parse_func = getattr(self, key, None)

            new_values = []
            for value in values:
                if parse_func:
                    value = parse_func(value)
                elif key in default_params:
                    value = self.match_type(default_params[key], value)

                new_values.append(value)
            search_params[key] = new_values
        return search_params

    @staticmethod
    def load_cls_default_params(cls):
        """Return ``{arg_name: default}`` for ``cls.__init__``'s arguments.

        Assumes every positional argument after ``self`` has a default
        (true for LGBMClassifier and friends); the assert guards that.
        """
        default_values = deepcopy(cls.__init__.__defaults__)
        # BUG FIX: co_varnames lists local variables *after* the declared
        # arguments; slicing by co_argcount keeps only real arguments so
        # an __init__ with locals no longer breaks the zip/assert below.
        arg_count = cls.__init__.__code__.co_argcount
        var_names = cls.__init__.__code__.co_varnames[:arg_count]
        var_names = [name for name in var_names if name not in {"self", "kwargs"}]
        assert len(default_values) == len(var_names)
        return dict(zip(var_names, default_values))

    @staticmethod
    def match_type(refer_var, input_var):
        """Cast *input_var* (usually a string) to the type of *refer_var*.

        Booleans accept "true"/"false" case-insensitively; any other
        text warns and yields False.  A ``None`` reference leaves the
        value unchanged.
        """
        if isinstance(refer_var, bool):
            tag = str(input_var).lower()
            if tag not in {"true", "false"}:
                warnings.warn(f"value : {input_var} error, set to False")

            input_var = tag == "true"

        elif refer_var is not None:
            refer_type = type(refer_var)
            input_var = refer_type(input_var)
        return input_var

    @staticmethod
    def parse_file(path):
        """Load a JSON parameter file; empty/None path yields {}."""
        path = path or ""
        if not path.strip():
            return {}
        with open(path, "r") as r_f:
            params = json.load(r_f)
        return params

    @staticmethod
    def parse_param_str(param_str):
        """Parse ``"k1=v1;k2=v2"`` into a dict of stripped strings.

        Malformed fragments (no "=" or more than one) are warned about
        and skipped; empty/None input yields {}.
        """
        param_str = param_str or ""
        if not param_str.strip():
            return {}
        name_value_pairs = param_str.split(";")
        pairs = []
        for name_value_pair in name_value_pairs:
            if not name_value_pair.strip():
                continue
            k_v = name_value_pair.split("=")
            if len(k_v) != 2:
                warnings.warn(f"{name_value_pair} error, will be ignore")
                continue
            # Reuse the split result instead of splitting a second time.
            key, value = k_v
            pairs.append((key.strip(), value.strip()))
        params = dict(pairs)
        return params

    @staticmethod
    def class_weight(value):
        """Parse class_weight: a dict-literal string becomes a dict,
        anything else (e.g. "balanced") is kept as-is."""
        try:
            # SECURITY: eval() on user input — consider ast.literal_eval.
            value = eval(value)
        except Exception:  # BUG FIX: was a bare except (caught SystemExit etc.)
            pass
        return value

    def __str__(self):
        input_params_message = str(self.input_params)
        search_params_message = str(self.search_params)
        message = f"input_params: {input_params_message}\nsearch_params: {search_params_message}"
        return message

    __repr__ = __str__
+
+
class LightGBMParams(Params):
    """Parameter resolution for lightgbm.LGBMClassifier.

    No overrides needed: LGBMClassifier declares all of its constructor
    arguments with defaults, so the generic ``Params`` introspection
    works unchanged.
    """

    ...
+
+
class XGBoostParams(Params):
    """Parameter resolution for XGBoost classifiers."""

    @staticmethod
    def load_cls_default_params(cls):
        # ``cls`` is deliberately ignored: XGBClassifier's __init__ mostly
        # forwards **kwargs, so the declared defaults live on the base
        # XGBModel and are introspected from there instead.
        from xgboost import XGBModel

        params = Params.load_cls_default_params(XGBModel)

        # Pin reference values so string inputs for these keys get
        # coerced to the right types by Params.match_type.
        params["objective"] = "binary:logistic"
        params["use_label_encoder"] = True
        return params
+
+
class LrParams(Params):
    """Parameter resolution for sklearn.linear_model.LogisticRegression."""

    @staticmethod
    def load_cls_default_params(cls):
        # LogisticRegression declares its arguments keyword-only, so the
        # generic positional-defaults introspection does not apply; read
        # kwonlydefaults via inspect instead.
        from sklearn.linear_model import LogisticRegression

        spec = inspect.getfullargspec(LogisticRegression)
        defaults = deepcopy(spec.kwonlydefaults)

        # ``penalty`` is positional and thus missing from kwonlydefaults;
        # add its documented default explicitly.
        defaults["penalty"] = "l2"

        return defaults
+
+
class SVMParams(Params):
    """Parameter resolution for sklearn.svm.SVC."""

    @staticmethod
    def load_cls_default_params(cls):
        # SVC declares its arguments keyword-only, so read the defaults
        # from kwonlydefaults rather than positional __defaults__.
        from sklearn.svm import SVC

        return deepcopy(inspect.getfullargspec(SVC).kwonlydefaults)
diff --git a/Project-BasicAlgorithm/core/training/svm.py b/Project-BasicAlgorithm/core/training/svm.py
new file mode 100644
index 0000000..8a4aafa
--- /dev/null
+++ b/Project-BasicAlgorithm/core/training/svm.py
@@ -0,0 +1,32 @@
+import mlflow
+from sklearn.pipeline import Pipeline
+from sklearn.svm import SVC
+
+from core.metrics import eval_classification_metrics
+from core.utils import get_onehot_encoder, train_model
+
+from .params import SVMParams
+
+
def train_svc(
    train_x, train_y, test_x, test_y, param_file=None, params=None, search_params=None
):
    """Train an SVC classifier behind a one-hot-encoding pipeline.

    Parameters may come from a JSON file (*param_file*), a ``"k=v;k=v"``
    string (*params*), or a grid-search spec (*search_params*).

    Returns:
        (pipeline, metrics) — the fitted encoder+model pipeline and the
        classification metrics computed on the test set.
    """
    mlflow.autolog()

    # Fit the encoder first; the model is trained on the transformed
    # features and appended to the pipeline afterwards.
    pipeline = Pipeline(steps=[("onehot_encoder", get_onehot_encoder())])
    encoded_train_x = pipeline.fit_transform(train_x)

    resolved_params = SVMParams(
        SVC, param_file=param_file, param_str=params, search_params=search_params
    )
    classifier = train_model(SVC, resolved_params, encoded_train_x, train_y)
    pipeline.steps.append(("model", classifier))

    predictions = pipeline.predict(test_x)
    return pipeline, eval_classification_metrics(test_y, predictions)
diff --git a/Project-BasicAlgorithm/core/training/xgboost.py b/Project-BasicAlgorithm/core/training/xgboost.py
new file mode 100644
index 0000000..b8f9cd7
--- /dev/null
+++ b/Project-BasicAlgorithm/core/training/xgboost.py
@@ -0,0 +1,36 @@
+import mlflow
+from sklearn.model_selection import GridSearchCV
+from sklearn.pipeline import Pipeline
+from xgboost import XGBClassifier
+
+from core.metrics import eval_classification_metrics
+from core.utils import get_oridinal_encoder, train_model
+
+from .params import XGBoostParams
+
+
def train_xgboost(
    train_x, train_y, test_x, test_y, param_file=None, params=None, search_params=None
):
    """Train an XGBoost classifier behind an ordinal-encoding pipeline.

    Parameters may come from a JSON file (*param_file*), a ``"k=v;k=v"``
    string (*params*), or a grid-search spec (*search_params*).

    Returns:
        (pipeline, metrics) — the fitted encoder+model pipeline and the
        classification metrics computed on the test set.
    """
    mlflow.autolog()

    # Fit the encoder first; the model is trained on the transformed
    # features and appended to the pipeline afterwards.
    pipeline = Pipeline(steps=[("oridinal_encoder", get_oridinal_encoder())])
    encoded_train_x = pipeline.fit_transform(train_x)

    resolved_params = XGBoostParams(
        XGBClassifier,
        param_file=param_file,
        param_str=params,
        use_label_encoder=True,
        search_params=search_params,
    )
    classifier = train_model(XGBClassifier, resolved_params, encoded_train_x, train_y)
    pipeline.steps.append(("model", classifier))

    predictions = pipeline.predict(test_x)
    return pipeline, eval_classification_metrics(test_y, predictions)
diff --git a/Project-BasicAlgorithm/core/utils.py b/Project-BasicAlgorithm/core/utils.py
new file mode 100644
index 0000000..cf1e935
--- /dev/null
+++ b/Project-BasicAlgorithm/core/utils.py
@@ -0,0 +1,28 @@
+import numpy as np
+from sklearn.model_selection import GridSearchCV
+from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder
+
+
def get_onehot_encoder(sparse=False, handle_unknown="ignore"):
    # One-hot encoder with dense output (sparse=False) so downstream
    # estimators receive a plain array; unseen categories at transform
    # time are ignored rather than raising.
    # NOTE(review): ``sparse`` was renamed ``sparse_output`` in
    # scikit-learn 1.2; conda.yaml pins scikit-learn 1.0.2, where this
    # spelling is correct.
    return OneHotEncoder(sparse=sparse, handle_unknown=handle_unknown)
+
+
def get_oridinal_encoder(unknown_value=np.nan, handle_unknown="use_encoded_value"):
    # Ordinal encoder that maps categories unseen during fit to
    # ``unknown_value`` (NaN by default) instead of raising.
    # NOTE(review): "oridinal" is a typo for "ordinal", kept because the
    # training modules import this exact name.
    return OrdinalEncoder(unknown_value=unknown_value, handle_unknown=handle_unknown)
+
+
def train_model(model_cls, params, train_x, train_y):
    """Fit ``model_cls`` on the training data.

    When ``params.search_params`` is non-empty, run a GridSearchCV over
    those candidates and return the best estimator; otherwise fit a
    single model built from ``params.input_params``.
    """
    estimator = model_cls(**params.input_params)

    if not params.search_params:
        estimator.fit(train_x, train_y)
        return estimator

    searcher = GridSearchCV(estimator=estimator, param_grid=params.search_params)
    searcher.fit(train_x, train_y)
    best = searcher.best_estimator_
    print(searcher.cv_results_)
    return best
diff --git a/Project-BasicAlgorithm/train.py b/Project-BasicAlgorithm/train.py
new file mode 100644
index 0000000..31aae37
--- /dev/null
+++ b/Project-BasicAlgorithm/train.py
@@ -0,0 +1,84 @@
+import click
+import mlflow
+import mlflow.sklearn
+
+from core.data import load_data
+
+
def get_training_func(algorithm):
    """Return the training function for *algorithm*.

    Supported values: "svm", "lightgbm", "xgboost", "lr".  Imports are
    local so only the selected algorithm's dependencies are loaded.

    Raises:
        ValueError: if *algorithm* is not supported.
    """
    if algorithm == "svm":
        from core.training.svm import train_svc as training_func

    elif algorithm == "lightgbm":
        from core.training.lightgbm import train_lightgbm as training_func

    elif algorithm == "xgboost":
        from core.training.xgboost import train_xgboost as training_func

    elif algorithm == "lr":
        from core.training.lr import train_lr as training_func

    else:
        # BUG FIX: the original did ``assert f"{algorithm} not supported"``,
        # which always passes (a non-empty string is truthy) and then hit
        # an UnboundLocalError on the return below.  Raise a clear error.
        raise ValueError(f"{algorithm} not supported")

    return training_func
+
+
def create_model_version(model_name, run_id=None, auto_replace=True):
    """Register the run's sklearn model as the new "Production" version.

    Creates the registered model if it does not exist yet, archives any
    version currently in "Production", registers the artifact logged at
    ``runs:/<run_id>/sklearn_model`` and promotes it to "Production".

    Returns the model versions that existed *before* this registration.
    """
    # NOTE(review): ``auto_replace`` is accepted but never used — confirm
    # whether the archive-and-promote step should be conditional on it.
    client = mlflow.tracking.MlflowClient()
    filter_string = "name='{}'".format(model_name)
    versions = client.search_model_versions(filter_string)

    if not versions:
        client.create_registered_model(model_name)

    # TODO: decide whether to replace the "Production" model version by
    # comparing metrics against the previous version, instead of always
    # archiving it unconditionally.
    for version in versions:
        if version.current_stage == "Production":
            client.transition_model_version_stage(
                model_name, version=version.version, stage="Archived"
            )

    uri = f"runs:/{run_id}/sklearn_model"
    mv = mlflow.register_model(uri, model_name)
    client.transition_model_version_stage(
        model_name, version=mv.version, stage="Production"
    )
    return versions
+
+
@click.command()
@click.option("--algorithm")
@click.option("--data_path")
@click.option("--label_column", default="label")
@click.option("--model_name", default=None)
@click.option("--random_state", default=0)
@click.option("--param_file", default=None)
@click.option("--params", default=None)
@click.option("--search_params", default=None)
def main(algorithm, data_path, label_column, model_name, random_state, param_file, params, search_params):
    """Train the selected algorithm, log metrics and the fitted model to
    MLflow, and optionally register it under *model_name*."""
    training_func = get_training_func(algorithm)
    train_x, train_y, test_x, test_y = load_data(
        data_path, label_column, random_state=random_state
    )

    with mlflow.start_run() as run:
        model, metrics = training_func(
            train_x,
            train_y,
            test_x,
            test_y,
            param_file=param_file,
            params=params,
            search_params=search_params,
        )
        print(metrics)
        mlflow.log_metrics(metrics)
        mlflow.sklearn.log_model(model, artifact_path="sklearn_model")

        if model_name:
            create_model_version(model_name, run_id=run.info.run_id)


if __name__ == "__main__":
    main()