You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by hx...@apache.org on 2022/02/24 09:26:40 UTC

[flink-ml] branch master updated: [FLINK-26267][ml][python] Add common params interface in ML Python API

This is an automated email from the ASF dual-hosted git repository.

hxb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git


The following commit(s) were added to refs/heads/master by this push:
     new d14e1f4  [FLINK-26267][ml][python] Add common params interface in ML Python API
d14e1f4 is described below

commit d14e1f481ec3e8b154a46a003e3a4fd59614f34a
Author: huangxingbo <hx...@gmail.com>
AuthorDate: Wed Feb 23 11:24:49 2022 +0800

    [FLINK-26267][ml][python] Add common params interface in ML Python API
    
    This closes #65.
---
 flink-ml-python/pyflink/ml/core/api.py         |   4 +-
 flink-ml-python/pyflink/ml/core/param.py       |  18 +-
 flink-ml-python/pyflink/ml/core/wrapper.py     | 208 ++++++++++++++
 flink-ml-python/pyflink/ml/lib/param.py        | 371 +++++++++++++++++++++++++
 flink-ml-python/pyflink/ml/tests/test_param.py | 201 ++++++++++++++
 5 files changed, 796 insertions(+), 6 deletions(-)

diff --git a/flink-ml-python/pyflink/ml/core/api.py b/flink-ml-python/pyflink/ml/core/api.py
index 541dffc..9358303 100644
--- a/flink-ml-python/pyflink/ml/core/api.py
+++ b/flink-ml-python/pyflink/ml/core/api.py
@@ -87,10 +87,10 @@ class Model(Transformer[T], ABC):
     the extra APIs to set and get model data.
     """
 
-    def set_model_data(self, *inputs: Table) -> None:
+    def set_model_data(self, *inputs: Table) -> 'Model':
+        """
+        Sets the tables representing the model data.
+
+        This default implementation always raises; subclasses that support model data must
+        override it (and return the model itself).
+
+        :raises Exception: always, in this default implementation.
+        """
         raise Exception("This operation is not supported.")
 
-    def get_model_data(self) -> None:
+    def get_model_data(self) -> List[Table]:
         """
         Gets a list of tables representing the model data. Each table could be an unbounded stream
         of model data changes.
diff --git a/flink-ml-python/pyflink/ml/core/param.py b/flink-ml-python/pyflink/ml/core/param.py
index c9da75a..d6df601 100644
--- a/flink-ml-python/pyflink/ml/core/param.py
+++ b/flink-ml-python/pyflink/ml/core/param.py
@@ -216,6 +216,18 @@ class ParamValidators(object):
 
         return NotNull()
 
+    @staticmethod
+    def non_empty_array() -> ParamValidator[List[T]]:
+        """
+        Checks if the parameter value is not empty array.
+        """
+
+        class NonEmptyArray(ParamValidator[List[T]]):
+            def validate(self, value: List[T]) -> bool:
+                return value is not None and len(value) > 0
+
+        return NonEmptyArray()
+
 
 class Param(Generic[T]):
     """
@@ -233,8 +245,7 @@ class Param(Generic[T]):
         if default_value is not None and not validator.validate(default_value):
             raise ValueError(f"Parameter {name} is given an invalid value {default_value}")
 
-    @staticmethod
-    def json_encode(value: T) -> str:
+    def json_encode(self, value: T) -> str:
         """
         Encodes the given object into a json-formatted string.
 
@@ -243,8 +254,7 @@ class Param(Generic[T]):
         """
         return str(jsonpickle.encode(value, keys=True))
 
-    @staticmethod
-    def json_decode(json: str) -> T:
+    def json_decode(self, json: str) -> T:
         """
         Decodes the given string into an object of class type T.
 
diff --git a/flink-ml-python/pyflink/ml/core/wrapper.py b/flink-ml-python/pyflink/ml/core/wrapper.py
new file mode 100644
index 0000000..789e652
--- /dev/null
+++ b/flink-ml-python/pyflink/ml/core/wrapper.py
@@ -0,0 +1,208 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+from abc import ABC, abstractmethod
+from typing import List, Dict, Any
+
+from pyflink.datastream import StreamExecutionEnvironment
+from pyflink.java_gateway import get_gateway
+from pyflink.table import Table
+from pyflink.util.java_utils import to_jarray
+
+from pyflink.ml.core.api import Model, Transformer, AlgoOperator, Stage, Estimator
+from pyflink.ml.core.param import Param, WithParams
+
+
+class JavaWrapper(ABC):
+    """
+    Wrapper class for a Java object
+    """
+
+    def __init__(self, java_obj):
+        self._java_obj = java_obj
+
+
+class JavaWithParams(WithParams, JavaWrapper):
+    """
+    Wrapper class for a Java WithParams.
+    """
+    PYTHON_PARAM_NAME_TO_JAVA_PARM_NAME = {
+        'distance_measure': 'distanceMeasure',
+        'features_col': 'featuresCol',
+        'global_batch_size': 'globalBatchSize',
+        'handle_invalid': 'handleInvalid',
+        'input_cols': 'inputCols',
+        'label_col': 'labelCol',
+        'learning_rate': 'learningRate',
+        'max_iter': 'maxIter',
+        'multi_class': 'multiClass',
+        'output_cols': 'outputCols',
+        'prediction_col': 'predictionCol',
+        'raw_prediction_col': 'rawPredictionCol',
+        'reg': 'reg',
+        'seed': 'seed',
+        'tol': 'tol',
+        'weight_col': 'weightCol'
+    }
+
+    def __init__(self, java_params):
+        super(JavaWithParams, self).__init__(java_params)
+
+    def set(self, param: Param, value) -> WithParams:
+        java_param_name = self._to_java_param_name(param.name)
+        set_method_name = ''.join(['set', java_param_name[0].upper(), java_param_name[1:]])
+        getattr(self._java_obj, set_method_name)(value)
+        return self
+
+    def get(self, param: Param):
+        java_param_name = self._to_java_param_name(param.name)
+        get_method_name = ''.join(['get', java_param_name[0].upper(), java_param_name[1:]])
+        return getattr(self._java_obj, get_method_name)()
+
+    def get_param_map(self) -> Dict[Param, Any]:
+        return self._java_obj.getParamMap()
+
+    def _to_java_param_name(self, name):
+        if name in self.PYTHON_PARAM_NAME_TO_JAVA_PARM_NAME:
+            return self.PYTHON_PARAM_NAME_TO_JAVA_PARM_NAME[name]
+        else:
+            raise Exception('Unknown param exception %s' % name)
+
+
+class JavaStage(Stage, JavaWithParams, ABC):
+    """
+    Wrapper class for a Java Stage.
+    """
+
+    def __init__(self, java_stage):
+        super(JavaStage, self).__init__(java_stage)
+
+    def save(self, path: str) -> None:
+        self._java_obj.save(path)
+
+
+class JavaAlgoOperator(AlgoOperator, JavaStage, ABC):
+    """
+    Wrapper class for a Java AlgoOperator.
+    """
+
+    def __init__(self, java_algo_operator):
+        super(JavaAlgoOperator, self).__init__(java_algo_operator)
+
+    def transform(self, *inputs: Table) -> List[Table]:
+        results = self._java_obj.transform(_to_java_tables(*inputs))
+        return [Table(t, inputs[0]._t_env) for t in results]
+
+
+class JavaTransformer(Transformer, JavaAlgoOperator, ABC):
+    """
+    Wrapper class for a Java Transformer.
+    """
+
+    def __init__(self, java_transformer):
+        super(JavaTransformer, self).__init__(java_transformer)
+
+
+class JavaModel(Model, JavaTransformer, ABC):
+    """
+    Wrapper class for a Java Model.
+    """
+
+    def __init__(self, java_model):
+        if java_model is None:
+            super(JavaModel, self).__init__(_to_java_reference(self._java_model_path())())
+        else:
+            super(JavaModel, self).__init__(java_model)
+        self._t_env = None
+
+    def set_model_data(self, *inputs: Table) -> Model:
+        self._t_env = inputs[0]._t_env
+        self._java_obj.setModelData(_to_java_tables(*inputs))
+        return self
+
+    def get_model_data(self) -> List[Table]:
+        return [Table(t, self._t_env) for t in self._java_obj.getModelData()]
+
+    @classmethod
+    def load(cls, env: StreamExecutionEnvironment, path: str):
+        java_model = _to_java_reference(cls._java_model_path()).load(
+            env._j_stream_execution_environment, path)
+        instance = cls(java_model)
+        return instance
+
+    @classmethod
+    @abstractmethod
+    def _java_model_path(cls) -> str:
+        pass
+
+
+class JavaEstimator(Estimator, JavaStage, ABC):
+    """
+    Wrapper class for a Java Estimator.
+    """
+
+    def __init__(self):
+        super(JavaEstimator, self).__init__(_new_java_obj(self._java_estimator_path()))
+
+    def fit(self, *inputs: Table) -> Model:
+        return self._create_model(self._java_obj.fit(_to_java_tables(*inputs)))
+
+    @classmethod
+    def _create_model(cls, java_model) -> Model:
+        """
+        Creates a model from the input Java model reference.
+        """
+        pass
+
+    @classmethod
+    def load(cls, env: StreamExecutionEnvironment, path: str):
+        """
+        Instantiates a new stage instance based on the data read from the given path.
+        """
+        java_estimator = _to_java_reference(cls._java_estimator_path()).load(
+            env._j_stream_execution_environment, path)
+        instance = cls()
+        instance._java_obj = java_estimator
+        return instance
+
+    @classmethod
+    @abstractmethod
+    def _java_estimator_path(cls) -> str:
+        pass
+
+
+def _to_java_reference(java_class: str):
+    java_obj = get_gateway().jvm
+    for name in java_class.split("."):
+        java_obj = getattr(java_obj, name)
+    return java_obj
+
+
+def _new_java_obj(java_class: str, *java_args):
+    """
+    Returns a new Java object.
+    """
+    java_obj = _to_java_reference(java_class)
+    return java_obj(*java_args)
+
+
+def _to_java_tables(*inputs: Table):
+    """
+    Converts Python Tables to Java tables.
+    """
+    gateway = get_gateway()
+    return to_jarray(gateway.jvm.org.apache.flink.table.api.Table, [t._j_table for t in inputs])
diff --git a/flink-ml-python/pyflink/ml/lib/param.py b/flink-ml-python/pyflink/ml/lib/param.py
new file mode 100644
index 0000000..a24902c
--- /dev/null
+++ b/flink-ml-python/pyflink/ml/lib/param.py
@@ -0,0 +1,371 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+from abc import ABC
+from typing import List
+
+from pyflink.ml.core.param import WithParams, Param, ParamValidators, StringParam, IntParam, \
+    StringArrayParam, FloatParam
+
+
+class HasDistanceMeasure(WithParams, ABC):
+    """
+    Base class for the shared distance_measure param.
+    """
+    DISTANCE_MEASURE: Param[str] = StringParam(
+        "distance_measure",
+        "Distance measure. Supported options: 'euclidean' and 'cosine'.",
+        "euclidean",
+        ParamValidators.in_array(['euclidean', 'cosine']))
+
+    def set_distance_measure(self, distance_measure: str):
+        return self.set(self.DISTANCE_MEASURE, distance_measure)
+
+    def get_distance_measure(self) -> str:
+        return self.get(self.DISTANCE_MEASURE)
+
+    @property
+    def distance_measure(self) -> str:
+        return self.get_distance_measure()
+
+
+class HasFeaturesCol(WithParams, ABC):
+    """
+    Base class for the shared feature_col param.
+    """
+    FEATURES_COL: Param[str] = StringParam(
+        "features_col",
+        "Features column name.",
+        "features",
+        ParamValidators.not_null())
+
+    def set_features_col(self, col):
+        return self.set(self.FEATURES_COL, col)
+
+    def get_features_col(self) -> str:
+        return self.get(self.FEATURES_COL)
+
+    @property
+    def features_col(self) -> str:
+        return self.get_features_col()
+
+
+class HasGlobalBatchSize(WithParams, ABC):
+    """
+    Base class for the shared global_batch_size param.
+    """
+    GLOBAL_BATCH_SIZE: Param[int] = IntParam(
+        "global_batch_size",
+        "Global batch size of training algorithms.",
+        32,
+        ParamValidators.gt(0))
+
+    def set_global_batch_size(self, global_batch_size: int):
+        return self.set(self.GLOBAL_BATCH_SIZE, global_batch_size)
+
+    def get_global_batch_size(self) -> int:
+        return self.get(self.GLOBAL_BATCH_SIZE)
+
+    @property
+    def global_batch_size(self) -> int:
+        return self.get_global_batch_size()
+
+
+class HasHandleInvalid(WithParams, ABC):
+    """
+    Base class for the shared handle_invalid param.
+
+    Supported options and the corresponding behavior to handle invalid entries is listed as follows.
+
+    <ul>
+        <li>error: raise an exception.
+        <li>skip: filter out rows with bad values.
+    </ul>
+    """
+    HANDLE_INVALID: Param[str] = StringParam(
+        "handle_invalid",
+        "Strategy to handle invalid entries.",
+        "error",
+        ParamValidators.in_array(['error', 'skip']))
+
+    def set_handle_invalid(self, value: str):
+        return self.set(self.HANDLE_INVALID, value)
+
+    def get_handle_invalid(self) -> str:
+        return self.get(self.HANDLE_INVALID)
+
+    @property
+    def handle_invalid(self) -> str:
+        return self.get_handle_invalid()
+
+
+class HasInputCols(WithParams, ABC):
+    """
+    Base class for the shared input cols param.
+    """
+    INPUT_COLS: Param[List[str]] = StringArrayParam(
+        "input_cols",
+        "Input column names.",
+        None,
+        ParamValidators.non_empty_array())
+
+    def set_input_cols(self, cols: List[str]):
+        return self.set(self.INPUT_COLS, cols)
+
+    def get_input_cols(self) -> List[str]:
+        return self.get(self.INPUT_COLS)
+
+    @property
+    def input_cols(self) -> List[str]:
+        return self.get_input_cols()
+
+
+class HasLabelCol(WithParams, ABC):
+    """
+    Base class for the shared label column param.
+    """
+    LABEL_COL: Param[str] = StringParam(
+        "label_col",
+        "Label column name.",
+        "label",
+        ParamValidators.not_null())
+
+    def set_label_col(self, col: str):
+        return self.set(self.LABEL_COL, col)
+
+    def get_label_col(self) -> str:
+        return self.get(self.LABEL_COL)
+
+    @property
+    def label_col(self) -> str:
+        return self.get_label_col()
+
+
+class HasLearningRate(WithParams, ABC):
+    """
+    Base class for the shared learning rate param.
+    """
+
+    LEARNING_RATE: Param[float] = FloatParam(
+        "learning_rate",
+        "Learning rate of optimization method.",
+        0.1,
+        ParamValidators.gt(0))
+
+    def set_learning_rate(self, learning_rate: float):
+        return self.set(self.LEARNING_RATE, learning_rate)
+
+    def get_learning_rate(self) -> float:
+        return self.get(self.LEARNING_RATE)
+
+    @property
+    def learning_rate(self) -> float:
+        return self.get_learning_rate()
+
+
+class HasMaxIter(WithParams, ABC):
+    """
+    Base class for the shared maxIter param.
+    """
+    MAX_ITER: Param[int] = IntParam(
+        "max_iter",
+        "Maximum number of iterations.",
+        20,
+        ParamValidators.gt(0))
+
+    def set_max_iter(self, max_iter: int):
+        return self.set(self.MAX_ITER, max_iter)
+
+    def get_max_iter(self) -> int:
+        return self.get(self.MAX_ITER)
+
+    @property
+    def max_iter(self) -> int:
+        return self.get_max_iter()
+
+
+class HasMultiClass(WithParams, ABC):
+    """
+    Base class for the shared multi class param.
+
+    Supported options:
+        <li>auto: selects the classification type based on the number of classes:
+            If the number of unique label values from the input data is one or two,
+            set to "binomial". Otherwise, set to "multinomial".
+        <li>binomial: binary logistic regression.
+        <li>multinomial: multinomial logistic regression.
+    """
+    MULTI_CLASS: Param[str] = StringParam(
+        "multi_class",
+        "Classification type. Supported options: 'auto', 'binomial' and 'multinomial'.",
+        'auto',
+        ParamValidators.in_array(['auto', 'binomial', 'multinomial']))
+
+    def set_multi_class(self, class_type: str):
+        return self.set(self.MULTI_CLASS, class_type)
+
+    def get_multi_class(self) -> str:
+        return self.get(self.MULTI_CLASS)
+
+    @property
+    def multi_class(self) -> str:
+        return self.get_multi_class()
+
+
+class HasOutputCols(WithParams, ABC):
+    """
+    Base class for the shared output_cols param.
+    """
+    OUTPUT_COLS: Param[List[str]] = StringArrayParam(
+        "output_cols",
+        "Output column names.",
+        None,
+        ParamValidators.non_empty_array())
+
+    def set_output_cols(self, cols: List[str]):
+        return self.set(self.OUTPUT_COLS, cols)
+
+    def get_output_cols(self) -> List[str]:
+        return self.get(self.OUTPUT_COLS)
+
+    @property
+    def output_cols(self) -> List[str]:
+        return self.get_output_cols()
+
+
+class HasPredictionCol(WithParams, ABC):
+    """
+    Base class for the shared prediction column param.
+    """
+    PREDICTION_COL: Param[str] = StringParam(
+        "prediction_col",
+        "Prediction column name.",
+        "prediction",
+        ParamValidators.not_null())
+
+    def set_prediction_col(self, col: str):
+        return self.set(self.PREDICTION_COL, col)
+
+    def get_prediction_col(self) -> str:
+        return self.get(self.PREDICTION_COL)
+
+    @property
+    def prediction_col(self) -> str:
+        return self.get_prediction_col()
+
+
+class HasRawPredictionCol(WithParams, ABC):
+    """
+    Base class for the shared raw prediction column param.
+    """
+    RAW_PREDICTION_COL: Param[str] = StringParam(
+        "raw_prediction_col",
+        "Raw prediction column name.",
+        "raw_prediction")
+
+    def set_raw_prediction_col(self, col: str):
+        return self.set(self.RAW_PREDICTION_COL, col)
+
+    def get_raw_prediction_col(self):
+        return self.get(self.RAW_PREDICTION_COL)
+
+    @property
+    def raw_prediction_col(self) -> str:
+        return self.get_raw_prediction_col()
+
+
+class HasReg(WithParams, ABC):
+    """
+    Base class for the shared regularization param.
+    """
+    REG: Param[float] = FloatParam(
+        "reg",
+        "Regularization parameter.",
+        0.,
+        ParamValidators.gt_eq(0.))
+
+    def set_reg(self, value: float):
+        return self.set(self.REG, value)
+
+    def get_reg(self) -> float:
+        return self.get(self.REG)
+
+    @property
+    def reg(self) -> float:
+        return self.get_reg()
+
+
+class HasSeed(WithParams, ABC):
+    """
+    Base class for the shared seed param.
+    """
+    SEED: Param[int] = IntParam(
+        "seed",
+        "The random seed.",
+        None)
+
+    def set_seed(self, seed: int):
+        return self.set(self.SEED, seed) if seed is not None else hash(self.__class__.__name__)
+
+    def get_seed(self) -> int:
+        return self.get(self.SEED)
+
+    @property
+    def seed(self) -> int:
+        return self.get_seed()
+
+
+class HasTol(WithParams, ABC):
+    """
+    Base class for the shared tolerance param.
+    """
+    TOL: Param[float] = FloatParam(
+        "tol",
+        "Convergence tolerance for iterative algorithms.",
+        1e-6,
+        ParamValidators.gt_eq(0))
+
+    def set_tol(self, value: float):
+        return self.set(self.TOL, value)
+
+    def get_tol(self) -> float:
+        return self.get(self.TOL)
+
+    @property
+    def tol(self) -> float:
+        return self.get_tol()
+
+
+class HasWeightCol(WithParams, ABC):
+    """
+    Base class for the shared weight column param. If this is not set, we treat all instance weights
+    as 1.0.
+    """
+    WEIGHT_COL: Param[str] = StringParam(
+        "weight_col",
+        "Weight column name.",
+        None)
+
+    def set_weight_col(self, col: str):
+        return self.set(self.WEIGHT_COL, col)
+
+    def get_weight_col(self) -> str:
+        return self.get(self.WEIGHT_COL)
+
+    @property
+    def weight_col(self):
+        return self.get_weight_col()
diff --git a/flink-ml-python/pyflink/ml/tests/test_param.py b/flink-ml-python/pyflink/ml/tests/test_param.py
new file mode 100644
index 0000000..f9f98d2
--- /dev/null
+++ b/flink-ml-python/pyflink/ml/tests/test_param.py
@@ -0,0 +1,201 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+import unittest
+from typing import Dict, Any
+
+from pyflink.ml.core.param import Param
+from pyflink.ml.lib.param import HasDistanceMeasure, HasFeaturesCol, HasGlobalBatchSize, \
+    HasHandleInvalid, HasInputCols, HasLabelCol, HasLearningRate, HasMaxIter, HasMultiClass, \
+    HasOutputCols, HasPredictionCol, HasRawPredictionCol, HasReg, HasSeed, HasTol, HasWeightCol
+
+
+class TestParams(HasDistanceMeasure, HasFeaturesCol, HasGlobalBatchSize, HasHandleInvalid,
+                 HasInputCols, HasLabelCol, HasLearningRate, HasMaxIter, HasMultiClass,
+                 HasOutputCols, HasPredictionCol, HasRawPredictionCol, HasReg, HasSeed, HasTol,
+                 HasWeightCol):
+    def __init__(self):
+        self._param_map = {}
+
+    def get_param_map(self) -> Dict['Param[Any]', Any]:
+        return self._param_map
+
+
+class ParamTests(unittest.TestCase):
+    def test_distance_measure_param(self):
+        param = TestParams()
+        distance_measure = param.DISTANCE_MEASURE
+        self.assertEqual(distance_measure.name, "distance_measure")
+        self.assertEqual(distance_measure.description,
+                         "Distance measure. Supported options: 'euclidean' and 'cosine'.")
+        self.assertEqual(distance_measure.default_value, "euclidean")
+
+        param.set_distance_measure("cosine")
+        self.assertEqual(param.get_distance_measure(), "cosine")
+
+    def test_feature_col_param(self):
+        param = TestParams()
+        feature_col = param.FEATURES_COL
+        self.assertEqual(feature_col.name, "features_col")
+        self.assertEqual(feature_col.description, "Features column name.")
+        self.assertEqual(feature_col.default_value, "features")
+
+        param.set_features_col("test_features")
+        self.assertEqual(param.get_features_col(), "test_features")
+
+    def test_global_batch_size_param(self):
+        param = TestParams()
+        global_batch_size = param.GLOBAL_BATCH_SIZE
+        self.assertEqual(global_batch_size.name, "global_batch_size")
+        self.assertEqual(global_batch_size.description,
+                         "Global batch size of training algorithms.")
+        self.assertEqual(global_batch_size.default_value, 32)
+
+        param.set_global_batch_size(100)
+        self.assertEqual(param.get_global_batch_size(), 100)
+
+    def test_handle_invalid_param(self):
+        param = TestParams()
+        handle_invalid = param.HANDLE_INVALID
+        self.assertEqual(handle_invalid.name, "handle_invalid")
+        self.assertEqual(handle_invalid.description, "Strategy to handle invalid entries.")
+        self.assertEqual(handle_invalid.default_value, "error")
+
+        param.set_handle_invalid("skip")
+        self.assertEqual(param.get_handle_invalid(), "skip")
+
+    def test_input_cols_param(self):
+        param = TestParams()
+        input_cols = param.INPUT_COLS
+        self.assertEqual(input_cols.name, "input_cols")
+        self.assertEqual(input_cols.description, "Input column names.")
+        self.assertEqual(input_cols.default_value, None)
+
+        param.set_input_cols(['a', 'b', 'c'])
+        self.assertEqual(param.get_input_cols(), ['a', 'b', 'c'])
+
+    def test_label_col_param(self):
+        param = TestParams()
+        label_col = param.LABEL_COL
+        self.assertEqual(label_col.name, "label_col")
+        self.assertEqual(label_col.description, "Label column name.")
+        self.assertEqual(label_col.default_value, "label")
+
+        param.set_label_col('test_label')
+        self.assertEqual(param.get_label_col(), 'test_label')
+
+    def test_learning_rate_param(self):
+        param = TestParams()
+        learning_rate = param.LEARNING_RATE
+        self.assertEqual(learning_rate.name, "learning_rate")
+        self.assertEqual(learning_rate.description, "Learning rate of optimization method.")
+        self.assertEqual(learning_rate.default_value, 0.1)
+
+        param.set_learning_rate(0.2)
+        self.assertEqual(param.get_learning_rate(), 0.2)
+
+    def test_max_iter_param(self):
+        param = TestParams()
+        max_iter = param.MAX_ITER
+        self.assertEqual(max_iter.name, "max_iter")
+        self.assertEqual(max_iter.description, "Maximum number of iterations.")
+        self.assertEqual(max_iter.default_value, 20)
+
+        param.set_max_iter(50)
+        self.assertEqual(param.get_max_iter(), 50)
+
+    def test_multi_class_param(self):
+        param = TestParams()
+        multi_class = param.MULTI_CLASS
+        self.assertEqual(multi_class.name, "multi_class")
+        self.assertEqual(multi_class.description,
+                         "Classification type. Supported options: "
+                         "'auto', 'binomial' and 'multinomial'.")
+        self.assertEqual(multi_class.default_value, 'auto')
+
+        param.set_multi_class('binomial')
+        self.assertEqual(param.get_multi_class(), 'binomial')
+
+    def test_output_cols_param(self):
+        param = TestParams()
+        output_cols = param.OUTPUT_COLS
+        self.assertEqual(output_cols.name, "output_cols")
+        self.assertEqual(output_cols.description, "Output column names.")
+        self.assertEqual(output_cols.default_value, None)
+
+        param.set_output_cols(['a', 'b'])
+        self.assertEqual(param.get_output_cols(), ['a', 'b'])
+
+    def test_prediction_col_param(self):
+        param = TestParams()
+        prediction_col = param.PREDICTION_COL
+        self.assertEqual(prediction_col.name, "prediction_col")
+        self.assertEqual(prediction_col.description, "Prediction column name.")
+        self.assertEqual(prediction_col.default_value, "prediction")
+
+        param.set_prediction_col('test_prediction')
+        self.assertEqual(param.get_prediction_col(), 'test_prediction')
+
+    def test_raw_prediction_col_param(self):
+        param = TestParams()
+        raw_prediction_col = param.RAW_PREDICTION_COL
+        self.assertEqual(raw_prediction_col.name, "raw_prediction_col")
+        self.assertEqual(raw_prediction_col.description, "Raw prediction column name.")
+        self.assertEqual(raw_prediction_col.default_value, "raw_prediction")
+
+        param.set_raw_prediction_col('test_raw_prediction')
+        self.assertEqual(param.get_raw_prediction_col(), 'test_raw_prediction')
+
+    def test_reg_param(self):
+        param = TestParams()
+        reg = param.REG
+        self.assertEqual(reg.name, "reg")
+        self.assertEqual(reg.description, "Regularization parameter.")
+        self.assertEqual(reg.default_value, 0.)
+
+        param.set_reg(0.4)
+        self.assertEqual(param.get_reg(), 0.4)
+
+    def test_seed_param(self):
+        param = TestParams()
+        seed = param.SEED
+        self.assertEqual(seed.name, "seed")
+        self.assertEqual(seed.description, "The random seed.")
+        self.assertEqual(seed.default_value, None)
+
+        param.set_seed(1)
+        self.assertEqual(param.get_seed(), 1)
+
+    def test_tol(self):
+        param = TestParams()
+        tol = param.TOL
+        self.assertEqual(tol.name, "tol")
+        self.assertEqual(tol.description, "Convergence tolerance for iterative algorithms.")
+        self.assertEqual(tol.default_value, 1e-6)
+
+        param.set_tol(1e-5)
+        self.assertEqual(param.get_tol(), 1e-5)
+
+    def test_weight_col(self):
+        param = TestParams()
+        weight_col = param.WEIGHT_COL
+        self.assertEqual(weight_col.name, "weight_col")
+        self.assertEqual(weight_col.description, "Weight column name.")
+        self.assertEqual(weight_col.default_value, None)
+
+        param.set_weight_col('test_weight_col')
+        self.assertEqual(param.get_weight_col(), 'test_weight_col')