You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@ignite.apache.org by sb...@apache.org on 2018/10/01 05:55:25 UTC

[15/21] ignite git commit: IGNITE-9717: [ML] Add setters methods to Logistic Regression and fix examples/tests

IGNITE-9717: [ML] Add setters methods to Logistic Regression and
fix examples/tests

this closes #4865


Project: http://git-wip-us.apache.org/repos/asf/ignite/repo
Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/4da48e6f
Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/4da48e6f
Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/4da48e6f

Branch: refs/heads/ignite-gg-14206
Commit: 4da48e6f90ceb7ee585b66af4f384cc868f6ca8e
Parents: a373486
Author: zaleslaw <za...@gmail.com>
Authored: Fri Sep 28 16:05:39 2018 +0300
Committer: Yury Babak <yb...@gridgain.com>
Committed: Fri Sep 28 16:05:39 2018 +0300

----------------------------------------------------------------------
 .../LogisticRegressionSGDTrainerExample.java    | 16 ++++---
 .../ml/tutorial/Step_1_Read_and_Learn.java      |  2 +-
 .../examples/ml/tutorial/Step_2_Imputing.java   |  2 +-
 .../examples/ml/tutorial/Step_3_Categorial.java |  2 +-
 .../Step_3_Categorial_with_One_Hot_Encoder.java |  2 +-
 .../ml/tutorial/Step_4_Add_age_fare.java        |  2 +-
 .../examples/ml/tutorial/Step_5_Scaling.java    |  2 +-
 .../tutorial/Step_5_Scaling_with_Pipeline.java  |  2 +-
 .../ignite/examples/ml/tutorial/Step_6_KNN.java |  2 +-
 .../ml/tutorial/Step_7_Split_train_test.java    |  2 +-
 .../ignite/examples/ml/tutorial/Step_8_CV.java  |  2 +-
 .../ml/tutorial/Step_8_CV_with_Param_Grid.java  |  2 +-
 .../ml/tutorial/Step_9_Go_to_LogReg.java        | 27 ++++++-----
 .../ml/tutorial/TutorialStepByStepExample.java  |  2 +-
 .../binomial/LogisticRegressionSGDTrainer.java  | 47 ++++++++++----------
 .../LogRegressionMultiClassTrainer.java         | 29 +++++++-----
 .../SVMLinearBinaryClassificationTrainer.java   |  2 +-
 ...VMLinearMultiClassClassificationTrainer.java |  2 +-
 .../apache/ignite/ml/pipeline/PipelineTest.java | 18 +++-----
 .../logistic/LogRegMultiClassTrainerTest.java   |  1 -
 .../logistic/LogisticRegressionModelTest.java   | 17 +++----
 .../LogisticRegressionSGDTrainerTest.java       | 24 +++++-----
 22 files changed, 111 insertions(+), 96 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/ignite/blob/4da48e6f/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerExample.java
index 8d4218d..15330d0 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerExample.java
@@ -60,11 +60,16 @@ public class LogisticRegressionSGDTrainerExample {
             IgniteCache<Integer, double[]> dataCache = new TestCache(ignite).fillCacheWith(data);
 
             System.out.println(">>> Create new logistic regression trainer object.");
-            LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>(new UpdatesStrategy<>(
-                new SimpleGDUpdateCalculator(0.2),
-                SimpleGDParameterUpdate::sumLocal,
-                SimpleGDParameterUpdate::avg
-            ), 100000,  10, 100, 123L);
+            LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>()
+                .withUpdatesStgy(new UpdatesStrategy<>(
+                    new SimpleGDUpdateCalculator(0.2),
+                    SimpleGDParameterUpdate::sumLocal,
+                    SimpleGDParameterUpdate::avg
+                ))
+                .withMaxIterations(100000)
+                .withLocIterations(100)
+                .withBatchSize(10)
+                .withSeed(123L);
 
             System.out.println(">>> Perform the training to get the model.");
             LogisticRegressionModel mdl = trainer.fit(
@@ -218,5 +223,4 @@ public class LogisticRegressionSGDTrainerExample {
         {1, 5.1, 2.5, 3, 1.1},
         {1, 5.7, 2.8, 4.1, 1.3},
     };
-
 }

http://git-wip-us.apache.org/repos/asf/ignite/blob/4da48e6f/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_1_Read_and_Learn.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_1_Read_and_Learn.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_1_Read_and_Learn.java
index 264dbf4..481fa1d 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_1_Read_and_Learn.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_1_Read_and_Learn.java
@@ -42,7 +42,7 @@ import org.apache.ignite.ml.tree.DecisionTreeNode;
  */
 public class Step_1_Read_and_Learn {
     /** Run example. */
-    public static void main(String[] args) throws InterruptedException {
+    public static void main(String[] args) {
         System.out.println();
         System.out.println(">>> Tutorial step 1 (read and learn) example started.");
 

http://git-wip-us.apache.org/repos/asf/ignite/blob/4da48e6f/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_2_Imputing.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_2_Imputing.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_2_Imputing.java
index df73235..d60dc4b 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_2_Imputing.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_2_Imputing.java
@@ -44,7 +44,7 @@ import org.apache.ignite.ml.tree.DecisionTreeNode;
  */
 public class Step_2_Imputing {
     /** Run example. */
-    public static void main(String[] args) throws InterruptedException {
+    public static void main(String[] args) {
         System.out.println();
         System.out.println(">>> Tutorial step 2 (imputing) example started.");
 

http://git-wip-us.apache.org/repos/asf/ignite/blob/4da48e6f/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial.java
index 463a6ba..ac2fe08 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial.java
@@ -47,7 +47,7 @@ import org.apache.ignite.ml.tree.DecisionTreeNode;
  */
 public class Step_3_Categorial {
     /** Run example. */
-    public static void main(String[] args) throws InterruptedException {
+    public static void main(String[] args) {
         System.out.println();
         System.out.println(">>> Tutorial step 3 (categorial) example started.");
 

http://git-wip-us.apache.org/repos/asf/ignite/blob/4da48e6f/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial_with_One_Hot_Encoder.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial_with_One_Hot_Encoder.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial_with_One_Hot_Encoder.java
index 93e7e79..f0b6efe 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial_with_One_Hot_Encoder.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial_with_One_Hot_Encoder.java
@@ -48,7 +48,7 @@ import org.apache.ignite.ml.tree.DecisionTreeNode;
  */
 public class Step_3_Categorial_with_One_Hot_Encoder {
     /** Run example. */
-    public static void main(String[] args) throws InterruptedException {
+    public static void main(String[] args) {
         System.out.println();
         System.out.println(">>> Tutorial step 3 (categorial with One-hot encoder) example started.");
 

http://git-wip-us.apache.org/repos/asf/ignite/blob/4da48e6f/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_4_Add_age_fare.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_4_Add_age_fare.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_4_Add_age_fare.java
index bbeedb6..71e9efd 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_4_Add_age_fare.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_4_Add_age_fare.java
@@ -45,7 +45,7 @@ import org.apache.ignite.ml.tree.DecisionTreeNode;
  */
 public class Step_4_Add_age_fare {
     /** Run example. */
-    public static void main(String[] args) throws InterruptedException {
+    public static void main(String[] args) {
         System.out.println();
         System.out.println(">>> Tutorial step 4 (add age and fare) example started.");
 

http://git-wip-us.apache.org/repos/asf/ignite/blob/4da48e6f/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling.java
index 7d934d7..fe7bf91 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling.java
@@ -48,7 +48,7 @@ import org.apache.ignite.ml.tree.DecisionTreeNode;
  */
 public class Step_5_Scaling {
     /** Run example. */
-    public static void main(String[] args) throws InterruptedException {
+    public static void main(String[] args) {
         System.out.println();
         System.out.println(">>> Tutorial step 5 (scaling) example started.");
 

http://git-wip-us.apache.org/repos/asf/ignite/blob/4da48e6f/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling_with_Pipeline.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling_with_Pipeline.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling_with_Pipeline.java
index cc0a278..bd7cc21 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling_with_Pipeline.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling_with_Pipeline.java
@@ -48,7 +48,7 @@ import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
  */
 public class Step_5_Scaling_with_Pipeline {
     /** Run example. */
-    public static void main(String[] args) throws InterruptedException {
+    public static void main(String[] args) {
         System.out.println();
         System.out.println(">>> Tutorial step 5 (scaling) via Pipeline example started.");
 

http://git-wip-us.apache.org/repos/asf/ignite/blob/4da48e6f/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_6_KNN.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_6_KNN.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_6_KNN.java
index 0c8b562..a35b841 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_6_KNN.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_6_KNN.java
@@ -49,7 +49,7 @@ import org.apache.ignite.ml.selection.scoring.metric.Accuracy;
  */
 public class Step_6_KNN {
     /** Run example. */
-    public static void main(String[] args) throws InterruptedException {
+    public static void main(String[] args) {
         System.out.println();
         System.out.println(">>> Tutorial step 6 (kNN) example started.");
 

http://git-wip-us.apache.org/repos/asf/ignite/blob/4da48e6f/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_7_Split_train_test.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_7_Split_train_test.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_7_Split_train_test.java
index c6d033c..53d4d0a 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_7_Split_train_test.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_7_Split_train_test.java
@@ -51,7 +51,7 @@ import org.apache.ignite.ml.tree.DecisionTreeNode;
  */
 public class Step_7_Split_train_test {
     /** Run example. */
-    public static void main(String[] args) throws InterruptedException {
+    public static void main(String[] args) {
         System.out.println();
         System.out.println(">>> Tutorial step 7 (split to train and test) example started.");
 

http://git-wip-us.apache.org/repos/asf/ignite/blob/4da48e6f/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV.java
index d83e14a..feedccf 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV.java
@@ -63,7 +63,7 @@ import org.apache.ignite.ml.tree.DecisionTreeNode;
  */
 public class Step_8_CV {
     /** Run example. */
-    public static void main(String[] args) throws InterruptedException {
+    public static void main(String[] args) {
         System.out.println();
         System.out.println(">>> Tutorial step 8 (cross-validation) example started.");
 

http://git-wip-us.apache.org/repos/asf/ignite/blob/4da48e6f/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid.java
index 594c0eb..670f025 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid.java
@@ -65,7 +65,7 @@ import org.apache.ignite.ml.tree.DecisionTreeNode;
  */
 public class Step_8_CV_with_Param_Grid {
     /** Run example. */
-    public static void main(String[] args) throws InterruptedException {
+    public static void main(String[] args) {
         System.out.println();
         System.out.println(">>> Tutorial step 8 (cross-validation with param grid) example started.");
 

http://git-wip-us.apache.org/repos/asf/ignite/blob/4da48e6f/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_9_Go_to_LogReg.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_9_Go_to_LogReg.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_9_Go_to_LogReg.java
index 4e1e005..b98b0eb 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_9_Go_to_LogReg.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_9_Go_to_LogReg.java
@@ -56,7 +56,7 @@ import org.apache.ignite.ml.selection.split.TrainTestSplit;
  */
 public class Step_9_Go_to_LogReg {
     /** Run example. */
-    public static void main(String[] args) throws InterruptedException {
+    public static void main(String[] args) {
         System.out.println();
         System.out.println(">>> Tutorial step 9 (logistic regression) example started.");
 
@@ -124,12 +124,13 @@ public class Step_9_Go_to_LogReg {
                                             minMaxScalerPreprocessor
                                         );
 
-                                    LogisticRegressionSGDTrainer<?> trainer
-                                        = new LogisticRegressionSGDTrainer<>(new UpdatesStrategy<>(
-                                        new SimpleGDUpdateCalculator(learningRate),
-                                        SimpleGDParameterUpdate::sumLocal,
-                                        SimpleGDParameterUpdate::avg
-                                    ), maxIterations, batchSize, locIterations, 123L);
+                                    LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>()
+                                        .withUpdatesStgy(new UpdatesStrategy<>(new SimpleGDUpdateCalculator(learningRate),
+                                            SimpleGDParameterUpdate::sumLocal, SimpleGDParameterUpdate::avg))
+                                        .withMaxIterations(maxIterations)
+                                        .withLocIterations(locIterations)
+                                        .withBatchSize(batchSize)
+                                        .withSeed(123L);
 
                                     CrossValidation<LogisticRegressionModel, Double, Integer, Object[]>
                                         scoreCalculator = new CrossValidation<>();
@@ -187,11 +188,13 @@ public class Step_9_Go_to_LogReg {
                         minMaxScalerPreprocessor
                     );
 
-                LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>(new UpdatesStrategy<>(
-                    new SimpleGDUpdateCalculator(bestLearningRate),
-                    SimpleGDParameterUpdate::sumLocal,
-                    SimpleGDParameterUpdate::avg
-                ), bestMaxIterations,  bestBatchSize, bestLocIterations, 123L);
+                LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>()
+                    .withUpdatesStgy(new UpdatesStrategy<>(new SimpleGDUpdateCalculator(bestLearningRate),
+                        SimpleGDParameterUpdate::sumLocal, SimpleGDParameterUpdate::avg))
+                    .withMaxIterations(bestMaxIterations)
+                    .withLocIterations(bestLocIterations)
+                    .withBatchSize(bestBatchSize)
+                    .withSeed(123L);
 
                 System.out.println(">>> Perform the training to get the model.");
                 LogisticRegressionModel bestMdl = trainer.fit(

http://git-wip-us.apache.org/repos/asf/ignite/blob/4da48e6f/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/TutorialStepByStepExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/TutorialStepByStepExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/TutorialStepByStepExample.java
index 67f4bf5..a376ae6 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/TutorialStepByStepExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/TutorialStepByStepExample.java
@@ -23,7 +23,7 @@ package org.apache.ignite.examples.ml.tutorial;
  */
 public class TutorialStepByStepExample {
     /** Run examples with default settings. */
-    public static void main(String[] args) throws InterruptedException {
+    public static void main(String[] args) {
         Step_1_Read_and_Learn.main(args);
         Step_2_Imputing.main(args);
         Step_3_Categorial.main(args);

http://git-wip-us.apache.org/repos/asf/ignite/blob/4da48e6f/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionSGDTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionSGDTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionSGDTrainer.java
index fb5d5a0..74a296d 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionSGDTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionSGDTrainer.java
@@ -33,6 +33,8 @@ import org.apache.ignite.ml.nn.MultilayerPerceptron;
 import org.apache.ignite.ml.nn.UpdatesStrategy;
 import org.apache.ignite.ml.nn.architecture.MLPArchitecture;
 import org.apache.ignite.ml.optimization.LossFunctions;
+import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate;
+import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;
 import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;
 import org.jetbrains.annotations.NotNull;
 
@@ -41,37 +43,23 @@ import org.jetbrains.annotations.NotNull;
  */
 public class LogisticRegressionSGDTrainer<P extends Serializable> extends SingleLabelDatasetTrainer<LogisticRegressionModel> {
     /** Update strategy. */
-    private UpdatesStrategy<? super MultilayerPerceptron, P> updatesStgy;
+    private UpdatesStrategy updatesStgy = new UpdatesStrategy<>(
+        new SimpleGDUpdateCalculator(0.2),
+        SimpleGDParameterUpdate::sumLocal,
+        SimpleGDParameterUpdate::avg
+    );
 
     /** Max number of iteration. */
-    private int maxIterations;
+    private int maxIterations = 100;
 
     /** Batch size. */
-    private int batchSize;
+    private int batchSize = 100;
 
     /** Number of local iterations. */
-    private int locIterations;
+    private int locIterations = 100;
 
     /** Seed for random generator. */
-    private long seed;
-
-    /**
-     * Constructs a new instance of linear regression SGD trainer.
-     *
-     * @param updatesStgy Update strategy.
-     * @param maxIterations Max number of iteration.
-     * @param batchSize Batch size.
-     * @param locIterations Number of local iterations.
-     * @param seed Seed for random generator.
-     */
-    public LogisticRegressionSGDTrainer(UpdatesStrategy<? super MultilayerPerceptron, P> updatesStgy, int maxIterations,
-        int batchSize, int locIterations, long seed) {
-        this.updatesStgy = updatesStgy;
-        this.maxIterations = maxIterations;
-        this.batchSize = batchSize;
-        this.locIterations = locIterations;
-        this.seed = seed;
-    }
+    private long seed = 1234L;
 
     /** {@inheritDoc} */
     @Override public <K, V> LogisticRegressionModel fit(DatasetBuilder<K, V> datasetBuilder,
@@ -202,11 +190,22 @@ public class LogisticRegressionSGDTrainer<P extends Serializable> extends Single
     }
 
     /**
+     * Set up the updates strategy.
+     *
+     * @param updatesStgy Update strategy.
+     * @return Trainer with new update strategy parameter value.
+     */
+    public LogisticRegressionSGDTrainer withUpdatesStgy(UpdatesStrategy updatesStgy) {
+        this.updatesStgy = updatesStgy;
+        return this;
+    }
+
+    /**
      * Get the update strategy.
      *
      * @return The property value.
      */
-    public UpdatesStrategy<? super MultilayerPerceptron, P> getUpdatesStgy() {
+    public UpdatesStrategy getUpdatesStgy() {
         return updatesStgy;
     }
 

http://git-wip-us.apache.org/repos/asf/ignite/blob/4da48e6f/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassTrainer.java
index b9cdcc7..71d54fa 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassTrainer.java
@@ -32,8 +32,9 @@ import org.apache.ignite.ml.dataset.PartitionDataBuilder;
 import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
 import org.apache.ignite.ml.math.functions.IgniteBiFunction;
 import org.apache.ignite.ml.math.primitives.vector.Vector;
-import org.apache.ignite.ml.nn.MultilayerPerceptron;
 import org.apache.ignite.ml.nn.UpdatesStrategy;
+import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate;
+import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;
 import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel;
 import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionSGDTrainer;
 import org.apache.ignite.ml.structures.partition.LabelPartitionDataBuilderOnHeap;
@@ -46,19 +47,23 @@ import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;
 public class LogRegressionMultiClassTrainer<P extends Serializable>
     extends SingleLabelDatasetTrainer<LogRegressionMultiClassModel> {
     /** Update strategy. */
-    private UpdatesStrategy<? super MultilayerPerceptron, P> updatesStgy;
+    private UpdatesStrategy updatesStgy = new UpdatesStrategy<>(
+        new SimpleGDUpdateCalculator(0.2),
+        SimpleGDParameterUpdate::sumLocal,
+        SimpleGDParameterUpdate::avg
+    );
 
     /** Max number of iteration. */
-    private int amountOfIterations;
+    private int amountOfIterations = 100;
 
     /** Batch size. */
-    private int batchSize;
+    private int batchSize = 100;
 
     /** Number of local iterations. */
-    private int amountOfLocIterations;
+    private int amountOfLocIterations = 100;
 
     /** Seed for random generator. */
-    private long seed;
+    private long seed = 1234L;
 
     /**
      * Trains model based on the specified data.
@@ -90,7 +95,11 @@ public class LogRegressionMultiClassTrainer<P extends Serializable>
 
         classes.forEach(clsLb -> {
             LogisticRegressionSGDTrainer<?> trainer =
-                new LogisticRegressionSGDTrainer<>(updatesStgy, amountOfIterations, batchSize, amountOfLocIterations, seed);
+                new LogisticRegressionSGDTrainer<>()
+                    .withBatchSize(batchSize)
+                    .withLocIterations(amountOfLocIterations)
+                    .withMaxIterations(amountOfIterations)
+                    .withSeed(seed);
 
             IgniteBiFunction<K, V, Double> lbTransformer = (k, v) -> {
                 Double lb = lbExtractor.apply(k, v);
@@ -238,7 +247,7 @@ public class LogRegressionMultiClassTrainer<P extends Serializable>
     }
 
     /**
-     * Set up the regularization parameter.
+     * Set up the updates strategy.
      *
      * @param updatesStgy Update strategy.
      * @return Trainer with new update strategy parameter value.
@@ -249,11 +258,11 @@ public class LogRegressionMultiClassTrainer<P extends Serializable>
     }
 
     /**
-     * Get the update strategy..
+     * Get the update strategy.
      *
      * @return The parameter value.
      */
-    public UpdatesStrategy<? super MultilayerPerceptron, P> getUpdatesStgy() {
+    public UpdatesStrategy getUpdatesStgy() {
         return updatesStgy;
     }
 }

http://git-wip-us.apache.org/repos/asf/ignite/blob/4da48e6f/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java
index 2c621c8..47666f4 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java
@@ -50,7 +50,7 @@ public class SVMLinearBinaryClassificationTrainer extends SingleLabelDatasetTrai
     private double lambda = 0.4;
 
     /** The seed number. */
-    private long seed;
+    private long seed = 1234L;
 
     /**
      * Trains model based on the specified data.

http://git-wip-us.apache.org/repos/asf/ignite/blob/4da48e6f/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java
index ec60034..b161914 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java
@@ -52,7 +52,7 @@ public class SVMLinearMultiClassClassificationTrainer
     private double lambda = 0.2;
 
     /** The seed number. */
-    private long seed;
+    private long seed = 1234L;
 
     /**
      * Trains model based on the specified data.

http://git-wip-us.apache.org/repos/asf/ignite/blob/4da48e6f/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineTest.java
index 91bbcd4..d517ce6 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineTest.java
@@ -51,11 +51,13 @@ public class PipelineTest extends TrainerTest {
             cacheMock.put(i, convertedRow);
         }
 
-        LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>(new UpdatesStrategy<>(
-            new SimpleGDUpdateCalculator().withLearningRate(0.2),
-            SimpleGDParameterUpdate::sumLocal,
-            SimpleGDParameterUpdate::avg
-        ), 100000, 10, 100, 123L);
+        LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>()
+            .withUpdatesStgy(new UpdatesStrategy<>(new SimpleGDUpdateCalculator(0.2),
+                SimpleGDParameterUpdate::sumLocal, SimpleGDParameterUpdate::avg))
+            .withMaxIterations(100000)
+            .withLocIterations(100)
+            .withBatchSize(10)
+            .withSeed(123L);
 
         PipelineMdl<Integer, Double[]> mdl = new Pipeline<Integer, Double[], Vector>()
             .addFeatureExtractor((k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)))
@@ -88,12 +90,6 @@ public class PipelineTest extends TrainerTest {
             cacheMock.put(i, convertedRow);
         }
 
-        LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>(new UpdatesStrategy<>(
-            new SimpleGDUpdateCalculator().withLearningRate(0.2),
-            SimpleGDParameterUpdate::sumLocal,
-            SimpleGDParameterUpdate::avg
-        ), 100000, 10, 100, 123L);
-
         PipelineMdl<Integer, Double[]> mdl = new Pipeline<Integer, Double[], Vector>()
             .addFeatureExtractor((k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)))
             .addLabelExtractor((k, v) -> v[0])

http://git-wip-us.apache.org/repos/asf/ignite/blob/4da48e6f/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogRegMultiClassTrainerTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogRegMultiClassTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogRegMultiClassTrainerTest.java
index 78cd08d..c99bf02 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogRegMultiClassTrainerTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogRegMultiClassTrainerTest.java
@@ -133,7 +133,6 @@ public class LogRegMultiClassTrainerTest extends TrainerTest {
             VectorUtils.of(10, -10)
         );
 
-
         for (Vector vec : vectors) {
             TestUtils.assertEquals(originalMdl.apply(vec), updatedOnSameDS.apply(vec), PRECISION);
             TestUtils.assertEquals(originalMdl.apply(vec), updatedOnEmptyDS.apply(vec), PRECISION);

http://git-wip-us.apache.org/repos/asf/ignite/blob/4da48e6f/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionModelTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionModelTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionModelTest.java
index 89c9cca..e8aaacd 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionModelTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionModelTest.java
@@ -38,7 +38,7 @@ public class LogisticRegressionModelTest {
     /** */
     @Test
     public void testPredict() {
-        Vector weights = new DenseVector(new double[]{2.0, 3.0});
+        Vector weights = new DenseVector(new double[] {2.0, 3.0});
 
         assertFalse(new LogisticRegressionModel(weights, 1.0).isKeepingRawLabels());
 
@@ -57,35 +57,36 @@ public class LogisticRegressionModelTest {
     /** */
     @Test(expected = CardinalityException.class)
     public void testPredictOnAnObservationWithWrongCardinality() {
-        Vector weights = new DenseVector(new double[]{2.0, 3.0});
+        Vector weights = new DenseVector(new double[] {2.0, 3.0});
 
         LogisticRegressionModel mdl = new LogisticRegressionModel(weights, 1.0);
 
-        Vector observation = new DenseVector(new double[]{1.0});
+        Vector observation = new DenseVector(new double[] {1.0});
 
         mdl.apply(observation);
     }
 
     /** */
     private void verifyPredict(LogisticRegressionModel mdl) {
-        Vector observation = new DenseVector(new double[]{1.0, 1.0});
+        Vector observation = new DenseVector(new double[] {1.0, 1.0});
         TestUtils.assertEquals(sigmoid(1.0 + 2.0 * 1.0 + 3.0 * 1.0), mdl.apply(observation), PRECISION);
 
-        observation = new DenseVector(new double[]{2.0, 1.0});
+        observation = new DenseVector(new double[] {2.0, 1.0});
         TestUtils.assertEquals(sigmoid(1.0 + 2.0 * 2.0 + 3.0 * 1.0), mdl.apply(observation), PRECISION);
 
-        observation = new DenseVector(new double[]{1.0, 2.0});
+        observation = new DenseVector(new double[] {1.0, 2.0});
         TestUtils.assertEquals(sigmoid(1.0 + 2.0 * 1.0 + 3.0 * 2.0), mdl.apply(observation), PRECISION);
 
-        observation = new DenseVector(new double[]{-2.0, 1.0});
+        observation = new DenseVector(new double[] {-2.0, 1.0});
         TestUtils.assertEquals(sigmoid(1.0 - 2.0 * 2.0 + 3.0 * 1.0), mdl.apply(observation), PRECISION);
 
-        observation = new DenseVector(new double[]{1.0, -2.0});
+        observation = new DenseVector(new double[] {1.0, -2.0});
         TestUtils.assertEquals(sigmoid(1.0 + 2.0 * 1.0 - 3.0 * 2.0), mdl.apply(observation), PRECISION);
     }
 
     /**
      * Sigmoid function.
+     *
      * @param z The regression value.
      * @return The result.
      */

http://git-wip-us.apache.org/repos/asf/ignite/blob/4da48e6f/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java
index 723677c..d9b6f7a 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java
@@ -45,11 +45,13 @@ public class LogisticRegressionSGDTrainerTest extends TrainerTest {
         for (int i = 0; i < twoLinearlySeparableClasses.length; i++)
             cacheMock.put(i, twoLinearlySeparableClasses[i]);
 
-        LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>(new UpdatesStrategy<>(
-            new SimpleGDUpdateCalculator().withLearningRate(0.2),
-            SimpleGDParameterUpdate::sumLocal,
-            SimpleGDParameterUpdate::avg
-        ), 100000, 10, 100, 123L);
+        LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>()
+            .withUpdatesStgy(new UpdatesStrategy<>(new SimpleGDUpdateCalculator(0.2),
+                SimpleGDParameterUpdate::sumLocal, SimpleGDParameterUpdate::avg))
+            .withMaxIterations(100000)
+            .withLocIterations(100)
+            .withBatchSize(10)
+            .withSeed(123L);
 
         LogisticRegressionModel mdl = trainer.fit(
             cacheMock,
@@ -70,11 +72,13 @@ public class LogisticRegressionSGDTrainerTest extends TrainerTest {
         for (int i = 0; i < twoLinearlySeparableClasses.length; i++)
             cacheMock.put(i, twoLinearlySeparableClasses[i]);
 
-        LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>(new UpdatesStrategy<>(
-            new SimpleGDUpdateCalculator().withLearningRate(0.2),
-            SimpleGDParameterUpdate::sumLocal,
-            SimpleGDParameterUpdate::avg
-        ), 100000, 10, 100, 123L);
+        LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>()
+            .withUpdatesStgy(new UpdatesStrategy<>(new SimpleGDUpdateCalculator(0.2),
+                SimpleGDParameterUpdate::sumLocal, SimpleGDParameterUpdate::avg))
+            .withMaxIterations(100000)
+            .withLocIterations(100)
+            .withBatchSize(10)
+            .withSeed(123L);
 
         LogisticRegressionModel originalMdl = trainer.fit(
             cacheMock,