You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by hx...@apache.org on 2022/02/24 09:26:40 UTC
[flink-ml] branch master updated: [FLINK-26267][ml][python] Add common params interface in ML Python API
This is an automated email from the ASF dual-hosted git repository.
hxb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git
The following commit(s) were added to refs/heads/master by this push:
new d14e1f4 [FLINK-26267][ml][python] Add common params interface in ML Python API
d14e1f4 is described below
commit d14e1f481ec3e8b154a46a003e3a4fd59614f34a
Author: huangxingbo <hx...@gmail.com>
AuthorDate: Wed Feb 23 11:24:49 2022 +0800
[FLINK-26267][ml][python] Add common params interface in ML Python API
This closes #65.
---
flink-ml-python/pyflink/ml/core/api.py | 4 +-
flink-ml-python/pyflink/ml/core/param.py | 18 +-
flink-ml-python/pyflink/ml/core/wrapper.py | 208 ++++++++++++++
flink-ml-python/pyflink/ml/lib/param.py | 371 +++++++++++++++++++++++++
flink-ml-python/pyflink/ml/tests/test_param.py | 201 ++++++++++++++
5 files changed, 796 insertions(+), 6 deletions(-)
diff --git a/flink-ml-python/pyflink/ml/core/api.py b/flink-ml-python/pyflink/ml/core/api.py
index 541dffc..9358303 100644
--- a/flink-ml-python/pyflink/ml/core/api.py
+++ b/flink-ml-python/pyflink/ml/core/api.py
@@ -87,10 +87,10 @@ class Model(Transformer[T], ABC):
the extra APIs to set and get model data.
"""
- def set_model_data(self, *inputs: Table) -> None:
+ def set_model_data(self, *inputs: Table) -> 'Model':
raise Exception("This operation is not supported.")
- def get_model_data(self) -> None:
+ def get_model_data(self) -> List[Table]:
"""
Gets a list of tables representing the model data. Each table could be an unbounded stream
of model data changes.
diff --git a/flink-ml-python/pyflink/ml/core/param.py b/flink-ml-python/pyflink/ml/core/param.py
index c9da75a..d6df601 100644
--- a/flink-ml-python/pyflink/ml/core/param.py
+++ b/flink-ml-python/pyflink/ml/core/param.py
@@ -216,6 +216,18 @@ class ParamValidators(object):
return NotNull()
+ @staticmethod
+ def non_empty_array() -> ParamValidator[List[T]]:
+ """
+ Checks if the parameter value is not empty array.
+ """
+
+ class NonEmptyArray(ParamValidator[List[T]]):
+ def validate(self, value: List[T]) -> bool:
+ return value is not None and len(value) > 0
+
+ return NonEmptyArray()
+
class Param(Generic[T]):
"""
@@ -233,8 +245,7 @@ class Param(Generic[T]):
if default_value is not None and not validator.validate(default_value):
raise ValueError(f"Parameter {name} is given an invalid value {default_value}")
- @staticmethod
- def json_encode(value: T) -> str:
+ def json_encode(self, value: T) -> str:
"""
Encodes the given object into a json-formatted string.
@@ -243,8 +254,7 @@ class Param(Generic[T]):
"""
return str(jsonpickle.encode(value, keys=True))
- @staticmethod
- def json_decode(json: str) -> T:
+ def json_decode(self, json: str) -> T:
"""
Decodes the given string into an object of class type T.
diff --git a/flink-ml-python/pyflink/ml/core/wrapper.py b/flink-ml-python/pyflink/ml/core/wrapper.py
new file mode 100644
index 0000000..789e652
--- /dev/null
+++ b/flink-ml-python/pyflink/ml/core/wrapper.py
@@ -0,0 +1,208 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+from abc import ABC, abstractmethod
+from typing import List, Dict, Any
+
+from pyflink.datastream import StreamExecutionEnvironment
+from pyflink.java_gateway import get_gateway
+from pyflink.table import Table
+from pyflink.util.java_utils import to_jarray
+
+from pyflink.ml.core.api import Model, Transformer, AlgoOperator, Stage, Estimator
+from pyflink.ml.core.param import Param, WithParams
+
+
class JavaWrapper(ABC):
    """
    Base wrapper holding a reference to a Java object.

    The wrapped object is kept in ``self._java_obj``; the wrapper subclasses forward their
    operations to it.
    """

    def __init__(self, java_obj):
        # The Java-side object that this Python wrapper delegates to.
        self._java_obj = java_obj
+
+
class JavaWithParams(WithParams, JavaWrapper):
    """
    Wrapper class for a Java WithParams.

    Param access is forwarded to the wrapped Java object. A Python param name such as
    ``features_col`` is translated through ``PYTHON_PARAM_NAME_TO_JAVA_PARM_NAME`` to its Java
    name (``featuresCol``), and the Java ``setFeaturesCol`` / ``getFeaturesCol`` accessor is
    looked up by name and invoked.
    """
    # Python (snake_case) param name -> Java (camelCase) param name. The "PARM" spelling of the
    # attribute is kept unchanged because subclasses may reference or override it by name.
    PYTHON_PARAM_NAME_TO_JAVA_PARM_NAME = {
        'distance_measure': 'distanceMeasure',
        'features_col': 'featuresCol',
        'global_batch_size': 'globalBatchSize',
        'handle_invalid': 'handleInvalid',
        'input_cols': 'inputCols',
        'label_col': 'labelCol',
        'learning_rate': 'learningRate',
        'max_iter': 'maxIter',
        'multi_class': 'multiClass',
        'output_cols': 'outputCols',
        'prediction_col': 'predictionCol',
        'raw_prediction_col': 'rawPredictionCol',
        'reg': 'reg',
        'seed': 'seed',
        'tol': 'tol',
        'weight_col': 'weightCol'
    }

    def __init__(self, java_params):
        super(JavaWithParams, self).__init__(java_params)

    def set(self, param: Param, value) -> WithParams:
        """
        Sets ``param`` to ``value`` by calling the Java ``set<JavaName>`` accessor.

        :return: this wrapper, so calls can be chained.
        :raises ValueError: if ``param`` has no known Java counterpart.
        """
        getattr(self._java_obj, self._java_accessor_name('set', param))(value)
        return self

    def get(self, param: Param):
        """
        Returns the value of ``param`` as reported by the Java ``get<JavaName>`` accessor.

        :raises ValueError: if ``param`` has no known Java counterpart.
        """
        return getattr(self._java_obj, self._java_accessor_name('get', param))()

    def get_param_map(self) -> Dict[Param, Any]:
        """
        Returns the result of ``getParamMap()`` on the wrapped Java object.

        NOTE(review): this is the object handed back by the Java call (presumably a Py4J proxy
        of the Java map keyed by Java params), not a Python dict keyed by Python
        :class:`Param` objects as the annotation suggests - TODO confirm with callers.
        """
        return self._java_obj.getParamMap()

    def _java_accessor_name(self, prefix: str, param: Param) -> str:
        """
        Builds a Java bean-style accessor name, e.g. ``('set', features_col)`` ->
        ``'setFeaturesCol'``.
        """
        java_param_name = self._to_java_param_name(param.name)
        return prefix + java_param_name[0].upper() + java_param_name[1:]

    def _to_java_param_name(self, name):
        """
        Translates a Python param name into its Java counterpart.

        :raises ValueError: if ``name`` is not a key of ``PYTHON_PARAM_NAME_TO_JAVA_PARM_NAME``.
                            (``ValueError`` subclasses ``Exception``, so handlers written for
                            the previous bare ``Exception`` still catch it.)
        """
        try:
            return self.PYTHON_PARAM_NAME_TO_JAVA_PARM_NAME[name]
        except KeyError:
            raise ValueError('Unknown param %s' % name) from None
+
+
class JavaStage(Stage, JavaWithParams, ABC):
    """
    Wrapper class for a Java Stage; persistence is delegated to the wrapped Java object.
    """

    def __init__(self, java_stage):
        super(JavaStage, self).__init__(java_stage)

    def save(self, path: str) -> None:
        """
        Saves this stage by calling ``save(path)`` on the wrapped Java stage.
        """
        java_stage = self._java_obj
        java_stage.save(path)
+
+
class JavaAlgoOperator(AlgoOperator, JavaStage, ABC):
    """
    Wrapper class for a Java AlgoOperator.
    """

    def __init__(self, java_algo_operator):
        super(JavaAlgoOperator, self).__init__(java_algo_operator)

    def transform(self, *inputs: Table) -> List[Table]:
        """
        Runs the Java ``transform`` on ``inputs``. Each Java table it returns is wrapped as a
        Python :class:`Table` bound to the table environment of the first input table.
        """
        java_results = self._java_obj.transform(_to_java_tables(*inputs))
        wrapped_tables = []
        for java_table in java_results:
            # The first input's environment is looked up per result (only when there is one).
            wrapped_tables.append(Table(java_table, inputs[0]._t_env))
        return wrapped_tables
+
+
class JavaTransformer(Transformer, JavaAlgoOperator, ABC):
    """
    Wrapper class for a Java Transformer. All behavior is inherited from
    :class:`JavaAlgoOperator`; this class only adds the Transformer interface.
    """

    def __init__(self, java_transformer):
        super(JavaTransformer, self).__init__(java_transformer)
+
+
class JavaModel(Model, JavaTransformer, ABC):
    """
    Wrapper class for a Java Model.

    Subclasses implement :meth:`_java_model_path` to return the fully qualified name of the
    Java model class they wrap.
    """

    def __init__(self, java_model=None):
        """
        :param java_model: an existing Java model to wrap. When ``None`` (the default), a new
                           instance of the Java class named by :meth:`_java_model_path` is
                           created with its no-argument constructor.
        """
        if java_model is None:
            java_model = _new_java_obj(self._java_model_path())
        super(JavaModel, self).__init__(java_model)
        # Table environment of the tables last passed to set_model_data; used to wrap the Java
        # tables returned by get_model_data.
        self._t_env = None

    def set_model_data(self, *inputs: Table) -> Model:
        """
        Sets the model data on the wrapped Java model and remembers the table environment of
        the first input table.

        :return: this model, so calls can be chained.
        """
        self._t_env = inputs[0]._t_env
        self._java_obj.setModelData(_to_java_tables(*inputs))
        return self

    def get_model_data(self) -> List[Table]:
        """
        Returns the Java model's data wrapped as Python tables.

        NOTE(review): if set_model_data has not been called on this instance (e.g. after
        ``load``), ``self._t_env`` is still ``None`` and the tables are wrapped without an
        environment - confirm this is acceptable for callers.
        """
        return [Table(t, self._t_env) for t in self._java_obj.getModelData()]

    @classmethod
    def load(cls, env: StreamExecutionEnvironment, path: str):
        """
        Loads the Java model saved at ``path`` and wraps it in a new instance of ``cls``.
        """
        java_model = _to_java_reference(cls._java_model_path()).load(
            env._j_stream_execution_environment, path)
        return cls(java_model)

    @classmethod
    @abstractmethod
    def _java_model_path(cls) -> str:
        """Returns the fully qualified class name of the wrapped Java model."""
        pass
+
+
class JavaEstimator(Estimator, JavaStage, ABC):
    """
    Wrapper class for a Java Estimator.

    Subclasses implement :meth:`_java_estimator_path` to name the fully qualified Java
    estimator class, and :meth:`_create_model` to wrap the Java model produced by :meth:`fit`.
    """

    def __init__(self):
        # Each Python estimator owns a freshly constructed Java estimator (no-arg constructor).
        super(JavaEstimator, self).__init__(_new_java_obj(self._java_estimator_path()))

    def fit(self, *inputs: Table) -> Model:
        """
        Fits the wrapped Java estimator on ``inputs`` and wraps the resulting Java model via
        :meth:`_create_model`.
        """
        return self._create_model(self._java_obj.fit(_to_java_tables(*inputs)))

    @classmethod
    def _create_model(cls, java_model) -> Model:
        """
        Creates a Python model wrapping the given Java model reference.

        Subclasses must override this; previously the base implementation silently returned
        ``None``, which made :meth:`fit` return ``None`` instead of a model.

        :raises NotImplementedError: if the subclass does not override this method.
        """
        raise NotImplementedError(
            '%s must implement _create_model to wrap the fitted Java model.' % cls.__name__)

    @classmethod
    def load(cls, env: StreamExecutionEnvironment, path: str):
        """
        Instantiates a new stage instance based on the data read from the given path.
        """
        java_estimator = _to_java_reference(cls._java_estimator_path()).load(
            env._j_stream_execution_environment, path)
        # cls() runs __init__, which builds a default Java estimator; it is then replaced by the
        # estimator loaded from ``path``.
        instance = cls()
        instance._java_obj = java_estimator
        return instance

    @classmethod
    @abstractmethod
    def _java_estimator_path(cls) -> str:
        """Returns the fully qualified class name of the wrapped Java estimator."""
        pass
+
+
def _to_java_reference(java_class: str):
    """
    Resolves a dotted Java class name (e.g. ``'org.apache.flink.table.api.Table'``) by
    performing one attribute lookup per dotted component, starting from
    ``get_gateway().jvm``, and returns the final attribute.
    """
    reference = get_gateway().jvm
    components = java_class.split(".")
    for component in components:
        reference = getattr(reference, component)
    return reference
+
+
def _new_java_obj(java_class: str, *java_args):
    """
    Instantiates the Java class named ``java_class`` with ``java_args`` and returns the new
    Java object.
    """
    return _to_java_reference(java_class)(*java_args)
+
+
def _to_java_tables(*inputs: Table):
    """
    Packs the Java tables backing the given Python tables (their ``_j_table`` attributes)
    into a Java array of ``org.apache.flink.table.api.Table``.
    """
    table_class = get_gateway().jvm.org.apache.flink.table.api.Table
    java_tables = [table._j_table for table in inputs]
    return to_jarray(table_class, java_tables)
diff --git a/flink-ml-python/pyflink/ml/lib/param.py b/flink-ml-python/pyflink/ml/lib/param.py
new file mode 100644
index 0000000..a24902c
--- /dev/null
+++ b/flink-ml-python/pyflink/ml/lib/param.py
@@ -0,0 +1,371 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+from abc import ABC
+from typing import List
+
+from pyflink.ml.core.param import WithParams, Param, ParamValidators, StringParam, IntParam, \
+ StringArrayParam, FloatParam
+
+
class HasDistanceMeasure(WithParams, ABC):
    """
    Mixin providing the shared ``distance_measure`` param.

    Defaults to ``'euclidean'``; the param is declared with a validator restricting it to
    ``'euclidean'`` or ``'cosine'``.
    """
    DISTANCE_MEASURE: Param[str] = StringParam(
        "distance_measure",
        "Distance measure. Supported options: 'euclidean' and 'cosine'.",
        "euclidean",
        ParamValidators.in_array(['euclidean', 'cosine']))

    @property
    def distance_measure(self) -> str:
        """Current value of the ``distance_measure`` param."""
        return self.get_distance_measure()

    def set_distance_measure(self, distance_measure: str):
        """Sets the ``distance_measure`` param and returns the result of :meth:`set`."""
        return self.set(self.DISTANCE_MEASURE, distance_measure)

    def get_distance_measure(self) -> str:
        """Returns the value of the ``distance_measure`` param."""
        return self.get(self.DISTANCE_MEASURE)
+
+
class HasFeaturesCol(WithParams, ABC):
    """
    Mixin providing the shared ``features_col`` param.

    Defaults to ``'features'``; the param is declared with a not-null validator.
    """
    FEATURES_COL: Param[str] = StringParam(
        "features_col",
        "Features column name.",
        "features",
        ParamValidators.not_null())

    @property
    def features_col(self) -> str:
        """Current value of the ``features_col`` param."""
        return self.get_features_col()

    def set_features_col(self, col: str):
        """Sets the ``features_col`` param and returns the result of :meth:`set`."""
        return self.set(self.FEATURES_COL, col)

    def get_features_col(self) -> str:
        """Returns the value of the ``features_col`` param."""
        return self.get(self.FEATURES_COL)
+
+
class HasGlobalBatchSize(WithParams, ABC):
    """
    Mixin providing the shared ``global_batch_size`` param.

    Defaults to ``32``; the param is declared with a validator requiring a value > 0.
    """
    GLOBAL_BATCH_SIZE: Param[int] = IntParam(
        "global_batch_size",
        "Global batch size of training algorithms.",
        32,
        ParamValidators.gt(0))

    @property
    def global_batch_size(self) -> int:
        """Current value of the ``global_batch_size`` param."""
        return self.get_global_batch_size()

    def set_global_batch_size(self, global_batch_size: int):
        """Sets the ``global_batch_size`` param and returns the result of :meth:`set`."""
        return self.set(self.GLOBAL_BATCH_SIZE, global_batch_size)

    def get_global_batch_size(self) -> int:
        """Returns the value of the ``global_batch_size`` param."""
        return self.get(self.GLOBAL_BATCH_SIZE)
+
+
class HasHandleInvalid(WithParams, ABC):
    """
    Mixin providing the shared ``handle_invalid`` param.

    Supported options and the corresponding behavior for invalid entries:

    - ``'error'``: raise an exception.
    - ``'skip'``: filter out rows with bad values.

    Defaults to ``'error'``.
    """
    HANDLE_INVALID: Param[str] = StringParam(
        "handle_invalid",
        "Strategy to handle invalid entries.",
        "error",
        ParamValidators.in_array(['error', 'skip']))

    @property
    def handle_invalid(self) -> str:
        """Current value of the ``handle_invalid`` param."""
        return self.get_handle_invalid()

    def set_handle_invalid(self, value: str):
        """Sets the ``handle_invalid`` param and returns the result of :meth:`set`."""
        return self.set(self.HANDLE_INVALID, value)

    def get_handle_invalid(self) -> str:
        """Returns the value of the ``handle_invalid`` param."""
        return self.get(self.HANDLE_INVALID)
+
+
class HasInputCols(WithParams, ABC):
    """
    Mixin providing the shared ``input_cols`` param.

    There is no default value (``None``); the param is declared with a non-empty-array
    validator.
    """
    INPUT_COLS: Param[List[str]] = StringArrayParam(
        "input_cols",
        "Input column names.",
        None,
        ParamValidators.non_empty_array())

    @property
    def input_cols(self) -> List[str]:
        """Current value of the ``input_cols`` param."""
        return self.get_input_cols()

    def set_input_cols(self, cols: List[str]):
        """Sets the ``input_cols`` param and returns the result of :meth:`set`."""
        return self.set(self.INPUT_COLS, cols)

    def get_input_cols(self) -> List[str]:
        """Returns the value of the ``input_cols`` param."""
        return self.get(self.INPUT_COLS)
+
+
class HasLabelCol(WithParams, ABC):
    """
    Mixin providing the shared ``label_col`` param.

    Defaults to ``'label'``; the param is declared with a not-null validator.
    """
    LABEL_COL: Param[str] = StringParam(
        "label_col",
        "Label column name.",
        "label",
        ParamValidators.not_null())

    @property
    def label_col(self) -> str:
        """Current value of the ``label_col`` param."""
        return self.get_label_col()

    def set_label_col(self, col: str):
        """Sets the ``label_col`` param and returns the result of :meth:`set`."""
        return self.set(self.LABEL_COL, col)

    def get_label_col(self) -> str:
        """Returns the value of the ``label_col`` param."""
        return self.get(self.LABEL_COL)
+
+
class HasLearningRate(WithParams, ABC):
    """
    Mixin providing the shared ``learning_rate`` param.

    Defaults to ``0.1``; the param is declared with a validator requiring a value > 0.
    """
    LEARNING_RATE: Param[float] = FloatParam(
        "learning_rate",
        "Learning rate of optimization method.",
        0.1,
        ParamValidators.gt(0))

    @property
    def learning_rate(self) -> float:
        """Current value of the ``learning_rate`` param."""
        return self.get_learning_rate()

    def set_learning_rate(self, learning_rate: float):
        """Sets the ``learning_rate`` param and returns the result of :meth:`set`."""
        return self.set(self.LEARNING_RATE, learning_rate)

    def get_learning_rate(self) -> float:
        """Returns the value of the ``learning_rate`` param."""
        return self.get(self.LEARNING_RATE)
+
+
class HasMaxIter(WithParams, ABC):
    """
    Mixin providing the shared ``max_iter`` param.

    Defaults to ``20``; the param is declared with a validator requiring a value > 0.
    """
    MAX_ITER: Param[int] = IntParam(
        "max_iter",
        "Maximum number of iterations.",
        20,
        ParamValidators.gt(0))

    @property
    def max_iter(self) -> int:
        """Current value of the ``max_iter`` param."""
        return self.get_max_iter()

    def set_max_iter(self, max_iter: int):
        """Sets the ``max_iter`` param and returns the result of :meth:`set`."""
        return self.set(self.MAX_ITER, max_iter)

    def get_max_iter(self) -> int:
        """Returns the value of the ``max_iter`` param."""
        return self.get(self.MAX_ITER)
+
+
class HasMultiClass(WithParams, ABC):
    """
    Mixin providing the shared ``multi_class`` param.

    Supported options:

    - ``'auto'``: selects the classification type based on the number of classes. If the
      number of unique label values in the input data is one or two, it is set to
      ``'binomial'``; otherwise, to ``'multinomial'``.
    - ``'binomial'``: binary logistic regression.
    - ``'multinomial'``: multinomial logistic regression.

    Defaults to ``'auto'``.
    """
    MULTI_CLASS: Param[str] = StringParam(
        "multi_class",
        "Classification type. Supported options: 'auto', 'binomial' and 'multinomial'.",
        'auto',
        ParamValidators.in_array(['auto', 'binomial', 'multinomial']))

    @property
    def multi_class(self) -> str:
        """Current value of the ``multi_class`` param."""
        return self.get_multi_class()

    def set_multi_class(self, class_type: str):
        """Sets the ``multi_class`` param and returns the result of :meth:`set`."""
        return self.set(self.MULTI_CLASS, class_type)

    def get_multi_class(self) -> str:
        """Returns the value of the ``multi_class`` param."""
        return self.get(self.MULTI_CLASS)
+
+
class HasOutputCols(WithParams, ABC):
    """
    Mixin providing the shared ``output_cols`` param.

    There is no default value (``None``); the param is declared with a non-empty-array
    validator.
    """
    OUTPUT_COLS: Param[List[str]] = StringArrayParam(
        "output_cols",
        "Output column names.",
        None,
        ParamValidators.non_empty_array())

    @property
    def output_cols(self) -> List[str]:
        """Current value of the ``output_cols`` param."""
        return self.get_output_cols()

    def set_output_cols(self, cols: List[str]):
        """Sets the ``output_cols`` param and returns the result of :meth:`set`."""
        return self.set(self.OUTPUT_COLS, cols)

    def get_output_cols(self) -> List[str]:
        """Returns the value of the ``output_cols`` param."""
        return self.get(self.OUTPUT_COLS)
+
+
class HasPredictionCol(WithParams, ABC):
    """
    Mixin providing the shared ``prediction_col`` param.

    Defaults to ``'prediction'``; the param is declared with a not-null validator.
    """
    PREDICTION_COL: Param[str] = StringParam(
        "prediction_col",
        "Prediction column name.",
        "prediction",
        ParamValidators.not_null())

    @property
    def prediction_col(self) -> str:
        """Current value of the ``prediction_col`` param."""
        return self.get_prediction_col()

    def set_prediction_col(self, col: str):
        """Sets the ``prediction_col`` param and returns the result of :meth:`set`."""
        return self.set(self.PREDICTION_COL, col)

    def get_prediction_col(self) -> str:
        """Returns the value of the ``prediction_col`` param."""
        return self.get(self.PREDICTION_COL)
+
+
class HasRawPredictionCol(WithParams, ABC):
    """
    Mixin providing the shared ``raw_prediction_col`` param.

    Defaults to ``'raw_prediction'``; no validator is declared for this param.
    """
    RAW_PREDICTION_COL: Param[str] = StringParam(
        "raw_prediction_col",
        "Raw prediction column name.",
        "raw_prediction")

    @property
    def raw_prediction_col(self) -> str:
        """Current value of the ``raw_prediction_col`` param."""
        return self.get_raw_prediction_col()

    def set_raw_prediction_col(self, col: str):
        """Sets the ``raw_prediction_col`` param and returns the result of :meth:`set`."""
        return self.set(self.RAW_PREDICTION_COL, col)

    def get_raw_prediction_col(self) -> str:
        """Returns the value of the ``raw_prediction_col`` param."""
        return self.get(self.RAW_PREDICTION_COL)
+
+
class HasReg(WithParams, ABC):
    """
    Mixin providing the shared ``reg`` (regularization) param.

    Defaults to ``0.0``; the param is declared with a validator requiring a value >= 0.
    """
    REG: Param[float] = FloatParam(
        "reg",
        "Regularization parameter.",
        0.,
        ParamValidators.gt_eq(0.))

    @property
    def reg(self) -> float:
        """Current value of the ``reg`` param."""
        return self.get_reg()

    def set_reg(self, value: float):
        """Sets the ``reg`` param and returns the result of :meth:`set`."""
        return self.set(self.REG, value)

    def get_reg(self) -> float:
        """Returns the value of the ``reg`` param."""
        return self.get(self.REG)
+
+
class HasSeed(WithParams, ABC):
    """
    Mixin providing the shared ``seed`` param.

    The param has no default value (``None``). Passing ``None`` to :meth:`set_seed` stores a
    seed derived from the concrete class name instead.
    """
    SEED: Param[int] = IntParam(
        "seed",
        "The random seed.",
        None)

    def set_seed(self, seed: int):
        """
        Sets the ``seed`` param and returns the result of :meth:`set`.

        If ``seed`` is ``None``, a seed derived from the class name is stored instead.
        Previously ``set_seed(None)`` returned that derived value without setting anything,
        which also broke call chaining.
        """
        if seed is None:
            # zlib.crc32 is used instead of the builtin hash(): str hashes are salted per
            # process (PYTHONHASHSEED), so hash() would give a different seed on every run.
            import zlib
            seed = zlib.crc32(self.__class__.__name__.encode('utf-8'))
        return self.set(self.SEED, seed)

    def get_seed(self) -> int:
        """Returns the value of the ``seed`` param."""
        return self.get(self.SEED)

    @property
    def seed(self) -> int:
        """Current value of the ``seed`` param."""
        return self.get_seed()
+
+
class HasTol(WithParams, ABC):
    """
    Mixin providing the shared ``tol`` (convergence tolerance) param.

    Defaults to ``1e-6``; the param is declared with a validator requiring a value >= 0.
    """
    TOL: Param[float] = FloatParam(
        "tol",
        "Convergence tolerance for iterative algorithms.",
        1e-6,
        ParamValidators.gt_eq(0))

    @property
    def tol(self) -> float:
        """Current value of the ``tol`` param."""
        return self.get_tol()

    def set_tol(self, value: float):
        """Sets the ``tol`` param and returns the result of :meth:`set`."""
        return self.set(self.TOL, value)

    def get_tol(self) -> float:
        """Returns the value of the ``tol`` param."""
        return self.get(self.TOL)
+
+
class HasWeightCol(WithParams, ABC):
    """
    Mixin providing the shared ``weight_col`` param.

    There is no default value (``None``). If this is not set, all instance weights are treated
    as 1.0.
    """
    WEIGHT_COL: Param[str] = StringParam(
        "weight_col",
        "Weight column name.",
        None)

    @property
    def weight_col(self) -> str:
        """Current value of the ``weight_col`` param."""
        return self.get_weight_col()

    def set_weight_col(self, col: str):
        """Sets the ``weight_col`` param and returns the result of :meth:`set`."""
        return self.set(self.WEIGHT_COL, col)

    def get_weight_col(self) -> str:
        """Returns the value of the ``weight_col`` param."""
        return self.get(self.WEIGHT_COL)
diff --git a/flink-ml-python/pyflink/ml/tests/test_param.py b/flink-ml-python/pyflink/ml/tests/test_param.py
new file mode 100644
index 0000000..f9f98d2
--- /dev/null
+++ b/flink-ml-python/pyflink/ml/tests/test_param.py
@@ -0,0 +1,201 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+import unittest
+from typing import Dict, Any
+
+from pyflink.ml.core.param import Param
+from pyflink.ml.lib.param import HasDistanceMeasure, HasFeaturesCol, HasGlobalBatchSize, \
+ HasHandleInvalid, HasInputCols, HasLabelCol, HasLearningRate, HasMaxIter, HasMultiClass, \
+ HasOutputCols, HasPredictionCol, HasRawPredictionCol, HasReg, HasSeed, HasTol, HasWeightCol
+
+
class TestParams(HasDistanceMeasure, HasFeaturesCol, HasGlobalBatchSize, HasHandleInvalid,
                 HasInputCols, HasLabelCol, HasLearningRate, HasMaxIter, HasMultiClass,
                 HasOutputCols, HasPredictionCol, HasRawPredictionCol, HasReg, HasSeed, HasTol,
                 HasWeightCol):
    """
    Concrete params holder that mixes in every shared param class, backed by a plain dict.
    """

    def __init__(self):
        # Param storage returned by get_param_map(); presumably read and written by the
        # WithParams get/set implementation - confirm against pyflink.ml.core.param.
        self._param_map = {}

    def get_param_map(self) -> Dict['Param[Any]', Any]:
        """Returns the dict holding this instance's param values."""
        return self._param_map
+
+
class ParamTests(unittest.TestCase):
    """
    For each shared param mixin, checks the declared name, description and default value,
    and that the setter/getter pair round-trips a value on :class:`TestParams`.
    """

    def _assert_param_definition(self, param, name, description, default_value):
        """Asserts that ``param`` declares the given name, description and default value."""
        self.assertEqual(param.name, name)
        self.assertEqual(param.description, description)
        self.assertEqual(param.default_value, default_value)

    def test_distance_measure_param(self):
        params = TestParams()
        self._assert_param_definition(
            params.DISTANCE_MEASURE,
            "distance_measure",
            "Distance measure. Supported options: 'euclidean' and 'cosine'.",
            "euclidean")

        params.set_distance_measure("cosine")
        self.assertEqual(params.get_distance_measure(), "cosine")

    def test_feature_col_param(self):
        params = TestParams()
        self._assert_param_definition(
            params.FEATURES_COL, "features_col", "Features column name.", "features")

        params.set_features_col("test_features")
        self.assertEqual(params.get_features_col(), "test_features")

    def test_global_batch_size_param(self):
        params = TestParams()
        self._assert_param_definition(
            params.GLOBAL_BATCH_SIZE,
            "global_batch_size",
            "Global batch size of training algorithms.",
            32)

        params.set_global_batch_size(100)
        self.assertEqual(params.get_global_batch_size(), 100)

    def test_handle_invalid_param(self):
        params = TestParams()
        self._assert_param_definition(
            params.HANDLE_INVALID,
            "handle_invalid",
            "Strategy to handle invalid entries.",
            "error")

        params.set_handle_invalid("skip")
        self.assertEqual(params.get_handle_invalid(), "skip")

    def test_input_cols_param(self):
        params = TestParams()
        self._assert_param_definition(
            params.INPUT_COLS, "input_cols", "Input column names.", None)

        params.set_input_cols(['a', 'b', 'c'])
        self.assertEqual(params.get_input_cols(), ['a', 'b', 'c'])

    def test_label_col_param(self):
        params = TestParams()
        self._assert_param_definition(
            params.LABEL_COL, "label_col", "Label column name.", "label")

        params.set_label_col('test_label')
        self.assertEqual(params.get_label_col(), 'test_label')

    def test_learning_rate_param(self):
        params = TestParams()
        self._assert_param_definition(
            params.LEARNING_RATE,
            "learning_rate",
            "Learning rate of optimization method.",
            0.1)

        params.set_learning_rate(0.2)
        self.assertEqual(params.get_learning_rate(), 0.2)

    def test_max_iter_param(self):
        params = TestParams()
        self._assert_param_definition(
            params.MAX_ITER, "max_iter", "Maximum number of iterations.", 20)

        params.set_max_iter(50)
        self.assertEqual(params.get_max_iter(), 50)

    def test_multi_class_param(self):
        params = TestParams()
        self._assert_param_definition(
            params.MULTI_CLASS,
            "multi_class",
            "Classification type. Supported options: "
            "'auto', 'binomial' and 'multinomial'.",
            'auto')

        params.set_multi_class('binomial')
        self.assertEqual(params.get_multi_class(), 'binomial')

    def test_output_cols_param(self):
        params = TestParams()
        self._assert_param_definition(
            params.OUTPUT_COLS, "output_cols", "Output column names.", None)

        params.set_output_cols(['a', 'b'])
        self.assertEqual(params.get_output_cols(), ['a', 'b'])

    def test_prediction_col_param(self):
        params = TestParams()
        self._assert_param_definition(
            params.PREDICTION_COL, "prediction_col", "Prediction column name.", "prediction")

        params.set_prediction_col('test_prediction')
        self.assertEqual(params.get_prediction_col(), 'test_prediction')

    def test_raw_prediction_col_param(self):
        params = TestParams()
        self._assert_param_definition(
            params.RAW_PREDICTION_COL,
            "raw_prediction_col",
            "Raw prediction column name.",
            "raw_prediction")

        params.set_raw_prediction_col('test_raw_prediction')
        self.assertEqual(params.get_raw_prediction_col(), 'test_raw_prediction')

    def test_reg_param(self):
        params = TestParams()
        self._assert_param_definition(params.REG, "reg", "Regularization parameter.", 0.)

        params.set_reg(0.4)
        self.assertEqual(params.get_reg(), 0.4)

    def test_seed_param(self):
        params = TestParams()
        self._assert_param_definition(params.SEED, "seed", "The random seed.", None)

        params.set_seed(1)
        self.assertEqual(params.get_seed(), 1)

    def test_tol(self):
        params = TestParams()
        self._assert_param_definition(
            params.TOL, "tol", "Convergence tolerance for iterative algorithms.", 1e-6)

        params.set_tol(1e-5)
        self.assertEqual(params.get_tol(), 1e-5)

    def test_weight_col(self):
        params = TestParams()
        self._assert_param_definition(
            params.WEIGHT_COL, "weight_col", "Weight column name.", None)

        params.set_weight_col('test_weight_col')
        self.assertEqual(params.get_weight_col(), 'test_weight_col')