You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@ignite.apache.org by sb...@apache.org on 2018/10/01 05:55:25 UTC
[15/21] ignite git commit: IGNITE-9717: [ML] Add setters methods to
Logistic Regression and fix examples/tests
IGNITE-9717: [ML] Add setters methods to Logistic Regression and
fix examples/tests
this closes #4865
Project: http://git-wip-us.apache.org/repos/asf/ignite/repo
Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/4da48e6f
Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/4da48e6f
Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/4da48e6f
Branch: refs/heads/ignite-gg-14206
Commit: 4da48e6f90ceb7ee585b66af4f384cc868f6ca8e
Parents: a373486
Author: zaleslaw <za...@gmail.com>
Authored: Fri Sep 28 16:05:39 2018 +0300
Committer: Yury Babak <yb...@gridgain.com>
Committed: Fri Sep 28 16:05:39 2018 +0300
----------------------------------------------------------------------
.../LogisticRegressionSGDTrainerExample.java | 16 ++++---
.../ml/tutorial/Step_1_Read_and_Learn.java | 2 +-
.../examples/ml/tutorial/Step_2_Imputing.java | 2 +-
.../examples/ml/tutorial/Step_3_Categorial.java | 2 +-
.../Step_3_Categorial_with_One_Hot_Encoder.java | 2 +-
.../ml/tutorial/Step_4_Add_age_fare.java | 2 +-
.../examples/ml/tutorial/Step_5_Scaling.java | 2 +-
.../tutorial/Step_5_Scaling_with_Pipeline.java | 2 +-
.../ignite/examples/ml/tutorial/Step_6_KNN.java | 2 +-
.../ml/tutorial/Step_7_Split_train_test.java | 2 +-
.../ignite/examples/ml/tutorial/Step_8_CV.java | 2 +-
.../ml/tutorial/Step_8_CV_with_Param_Grid.java | 2 +-
.../ml/tutorial/Step_9_Go_to_LogReg.java | 27 ++++++-----
.../ml/tutorial/TutorialStepByStepExample.java | 2 +-
.../binomial/LogisticRegressionSGDTrainer.java | 47 ++++++++++----------
.../LogRegressionMultiClassTrainer.java | 29 +++++++-----
.../SVMLinearBinaryClassificationTrainer.java | 2 +-
...VMLinearMultiClassClassificationTrainer.java | 2 +-
.../apache/ignite/ml/pipeline/PipelineTest.java | 18 +++-----
.../logistic/LogRegMultiClassTrainerTest.java | 1 -
.../logistic/LogisticRegressionModelTest.java | 17 +++----
.../LogisticRegressionSGDTrainerTest.java | 24 +++++-----
22 files changed, 111 insertions(+), 96 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/ignite/blob/4da48e6f/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerExample.java
index 8d4218d..15330d0 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerExample.java
@@ -60,11 +60,16 @@ public class LogisticRegressionSGDTrainerExample {
IgniteCache<Integer, double[]> dataCache = new TestCache(ignite).fillCacheWith(data);
System.out.println(">>> Create new logistic regression trainer object.");
- LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>(new UpdatesStrategy<>(
- new SimpleGDUpdateCalculator(0.2),
- SimpleGDParameterUpdate::sumLocal,
- SimpleGDParameterUpdate::avg
- ), 100000, 10, 100, 123L);
+ LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>()
+ .withUpdatesStgy(new UpdatesStrategy<>(
+ new SimpleGDUpdateCalculator(0.2),
+ SimpleGDParameterUpdate::sumLocal,
+ SimpleGDParameterUpdate::avg
+ ))
+ .withMaxIterations(100000)
+ .withLocIterations(100)
+ .withBatchSize(10)
+ .withSeed(123L);
System.out.println(">>> Perform the training to get the model.");
LogisticRegressionModel mdl = trainer.fit(
@@ -218,5 +223,4 @@ public class LogisticRegressionSGDTrainerExample {
{1, 5.1, 2.5, 3, 1.1},
{1, 5.7, 2.8, 4.1, 1.3},
};
-
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/4da48e6f/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_1_Read_and_Learn.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_1_Read_and_Learn.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_1_Read_and_Learn.java
index 264dbf4..481fa1d 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_1_Read_and_Learn.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_1_Read_and_Learn.java
@@ -42,7 +42,7 @@ import org.apache.ignite.ml.tree.DecisionTreeNode;
*/
public class Step_1_Read_and_Learn {
/** Run example. */
- public static void main(String[] args) throws InterruptedException {
+ public static void main(String[] args) {
System.out.println();
System.out.println(">>> Tutorial step 1 (read and learn) example started.");
http://git-wip-us.apache.org/repos/asf/ignite/blob/4da48e6f/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_2_Imputing.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_2_Imputing.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_2_Imputing.java
index df73235..d60dc4b 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_2_Imputing.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_2_Imputing.java
@@ -44,7 +44,7 @@ import org.apache.ignite.ml.tree.DecisionTreeNode;
*/
public class Step_2_Imputing {
/** Run example. */
- public static void main(String[] args) throws InterruptedException {
+ public static void main(String[] args) {
System.out.println();
System.out.println(">>> Tutorial step 2 (imputing) example started.");
http://git-wip-us.apache.org/repos/asf/ignite/blob/4da48e6f/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial.java
index 463a6ba..ac2fe08 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial.java
@@ -47,7 +47,7 @@ import org.apache.ignite.ml.tree.DecisionTreeNode;
*/
public class Step_3_Categorial {
/** Run example. */
- public static void main(String[] args) throws InterruptedException {
+ public static void main(String[] args) {
System.out.println();
System.out.println(">>> Tutorial step 3 (categorial) example started.");
http://git-wip-us.apache.org/repos/asf/ignite/blob/4da48e6f/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial_with_One_Hot_Encoder.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial_with_One_Hot_Encoder.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial_with_One_Hot_Encoder.java
index 93e7e79..f0b6efe 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial_with_One_Hot_Encoder.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial_with_One_Hot_Encoder.java
@@ -48,7 +48,7 @@ import org.apache.ignite.ml.tree.DecisionTreeNode;
*/
public class Step_3_Categorial_with_One_Hot_Encoder {
/** Run example. */
- public static void main(String[] args) throws InterruptedException {
+ public static void main(String[] args) {
System.out.println();
System.out.println(">>> Tutorial step 3 (categorial with One-hot encoder) example started.");
http://git-wip-us.apache.org/repos/asf/ignite/blob/4da48e6f/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_4_Add_age_fare.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_4_Add_age_fare.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_4_Add_age_fare.java
index bbeedb6..71e9efd 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_4_Add_age_fare.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_4_Add_age_fare.java
@@ -45,7 +45,7 @@ import org.apache.ignite.ml.tree.DecisionTreeNode;
*/
public class Step_4_Add_age_fare {
/** Run example. */
- public static void main(String[] args) throws InterruptedException {
+ public static void main(String[] args) {
System.out.println();
System.out.println(">>> Tutorial step 4 (add age and fare) example started.");
http://git-wip-us.apache.org/repos/asf/ignite/blob/4da48e6f/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling.java
index 7d934d7..fe7bf91 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling.java
@@ -48,7 +48,7 @@ import org.apache.ignite.ml.tree.DecisionTreeNode;
*/
public class Step_5_Scaling {
/** Run example. */
- public static void main(String[] args) throws InterruptedException {
+ public static void main(String[] args) {
System.out.println();
System.out.println(">>> Tutorial step 5 (scaling) example started.");
http://git-wip-us.apache.org/repos/asf/ignite/blob/4da48e6f/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling_with_Pipeline.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling_with_Pipeline.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling_with_Pipeline.java
index cc0a278..bd7cc21 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling_with_Pipeline.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling_with_Pipeline.java
@@ -48,7 +48,7 @@ import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
*/
public class Step_5_Scaling_with_Pipeline {
/** Run example. */
- public static void main(String[] args) throws InterruptedException {
+ public static void main(String[] args) {
System.out.println();
System.out.println(">>> Tutorial step 5 (scaling) via Pipeline example started.");
http://git-wip-us.apache.org/repos/asf/ignite/blob/4da48e6f/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_6_KNN.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_6_KNN.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_6_KNN.java
index 0c8b562..a35b841 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_6_KNN.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_6_KNN.java
@@ -49,7 +49,7 @@ import org.apache.ignite.ml.selection.scoring.metric.Accuracy;
*/
public class Step_6_KNN {
/** Run example. */
- public static void main(String[] args) throws InterruptedException {
+ public static void main(String[] args) {
System.out.println();
System.out.println(">>> Tutorial step 6 (kNN) example started.");
http://git-wip-us.apache.org/repos/asf/ignite/blob/4da48e6f/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_7_Split_train_test.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_7_Split_train_test.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_7_Split_train_test.java
index c6d033c..53d4d0a 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_7_Split_train_test.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_7_Split_train_test.java
@@ -51,7 +51,7 @@ import org.apache.ignite.ml.tree.DecisionTreeNode;
*/
public class Step_7_Split_train_test {
/** Run example. */
- public static void main(String[] args) throws InterruptedException {
+ public static void main(String[] args) {
System.out.println();
System.out.println(">>> Tutorial step 7 (split to train and test) example started.");
http://git-wip-us.apache.org/repos/asf/ignite/blob/4da48e6f/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV.java
index d83e14a..feedccf 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV.java
@@ -63,7 +63,7 @@ import org.apache.ignite.ml.tree.DecisionTreeNode;
*/
public class Step_8_CV {
/** Run example. */
- public static void main(String[] args) throws InterruptedException {
+ public static void main(String[] args) {
System.out.println();
System.out.println(">>> Tutorial step 8 (cross-validation) example started.");
http://git-wip-us.apache.org/repos/asf/ignite/blob/4da48e6f/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid.java
index 594c0eb..670f025 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid.java
@@ -65,7 +65,7 @@ import org.apache.ignite.ml.tree.DecisionTreeNode;
*/
public class Step_8_CV_with_Param_Grid {
/** Run example. */
- public static void main(String[] args) throws InterruptedException {
+ public static void main(String[] args) {
System.out.println();
System.out.println(">>> Tutorial step 8 (cross-validation with param grid) example started.");
http://git-wip-us.apache.org/repos/asf/ignite/blob/4da48e6f/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_9_Go_to_LogReg.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_9_Go_to_LogReg.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_9_Go_to_LogReg.java
index 4e1e005..b98b0eb 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_9_Go_to_LogReg.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_9_Go_to_LogReg.java
@@ -56,7 +56,7 @@ import org.apache.ignite.ml.selection.split.TrainTestSplit;
*/
public class Step_9_Go_to_LogReg {
/** Run example. */
- public static void main(String[] args) throws InterruptedException {
+ public static void main(String[] args) {
System.out.println();
System.out.println(">>> Tutorial step 9 (logistic regression) example started.");
@@ -124,12 +124,13 @@ public class Step_9_Go_to_LogReg {
minMaxScalerPreprocessor
);
- LogisticRegressionSGDTrainer<?> trainer
- = new LogisticRegressionSGDTrainer<>(new UpdatesStrategy<>(
- new SimpleGDUpdateCalculator(learningRate),
- SimpleGDParameterUpdate::sumLocal,
- SimpleGDParameterUpdate::avg
- ), maxIterations, batchSize, locIterations, 123L);
+ LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>()
+ .withUpdatesStgy(new UpdatesStrategy<>(new SimpleGDUpdateCalculator(learningRate),
+ SimpleGDParameterUpdate::sumLocal, SimpleGDParameterUpdate::avg))
+ .withMaxIterations(maxIterations)
+ .withLocIterations(locIterations)
+ .withBatchSize(batchSize)
+ .withSeed(123L);
CrossValidation<LogisticRegressionModel, Double, Integer, Object[]>
scoreCalculator = new CrossValidation<>();
@@ -187,11 +188,13 @@ public class Step_9_Go_to_LogReg {
minMaxScalerPreprocessor
);
- LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>(new UpdatesStrategy<>(
- new SimpleGDUpdateCalculator(bestLearningRate),
- SimpleGDParameterUpdate::sumLocal,
- SimpleGDParameterUpdate::avg
- ), bestMaxIterations, bestBatchSize, bestLocIterations, 123L);
+ LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>()
+ .withUpdatesStgy(new UpdatesStrategy<>(new SimpleGDUpdateCalculator(bestLearningRate),
+ SimpleGDParameterUpdate::sumLocal, SimpleGDParameterUpdate::avg))
+ .withMaxIterations(bestMaxIterations)
+ .withLocIterations(bestLocIterations)
+ .withBatchSize(bestBatchSize)
+ .withSeed(123L);
System.out.println(">>> Perform the training to get the model.");
LogisticRegressionModel bestMdl = trainer.fit(
http://git-wip-us.apache.org/repos/asf/ignite/blob/4da48e6f/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/TutorialStepByStepExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/TutorialStepByStepExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/TutorialStepByStepExample.java
index 67f4bf5..a376ae6 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/TutorialStepByStepExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/TutorialStepByStepExample.java
@@ -23,7 +23,7 @@ package org.apache.ignite.examples.ml.tutorial;
*/
public class TutorialStepByStepExample {
/** Run examples with default settings. */
- public static void main(String[] args) throws InterruptedException {
+ public static void main(String[] args) {
Step_1_Read_and_Learn.main(args);
Step_2_Imputing.main(args);
Step_3_Categorial.main(args);
http://git-wip-us.apache.org/repos/asf/ignite/blob/4da48e6f/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionSGDTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionSGDTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionSGDTrainer.java
index fb5d5a0..74a296d 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionSGDTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionSGDTrainer.java
@@ -33,6 +33,8 @@ import org.apache.ignite.ml.nn.MultilayerPerceptron;
import org.apache.ignite.ml.nn.UpdatesStrategy;
import org.apache.ignite.ml.nn.architecture.MLPArchitecture;
import org.apache.ignite.ml.optimization.LossFunctions;
+import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate;
+import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;
import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;
import org.jetbrains.annotations.NotNull;
@@ -41,37 +43,23 @@ import org.jetbrains.annotations.NotNull;
*/
public class LogisticRegressionSGDTrainer<P extends Serializable> extends SingleLabelDatasetTrainer<LogisticRegressionModel> {
/** Update strategy. */
- private UpdatesStrategy<? super MultilayerPerceptron, P> updatesStgy;
+ private UpdatesStrategy updatesStgy = new UpdatesStrategy<>(
+ new SimpleGDUpdateCalculator(0.2),
+ SimpleGDParameterUpdate::sumLocal,
+ SimpleGDParameterUpdate::avg
+ );
/** Max number of iteration. */
- private int maxIterations;
+ private int maxIterations = 100;
/** Batch size. */
- private int batchSize;
+ private int batchSize = 100;
/** Number of local iterations. */
- private int locIterations;
+ private int locIterations = 100;
/** Seed for random generator. */
- private long seed;
-
- /**
- * Constructs a new instance of linear regression SGD trainer.
- *
- * @param updatesStgy Update strategy.
- * @param maxIterations Max number of iteration.
- * @param batchSize Batch size.
- * @param locIterations Number of local iterations.
- * @param seed Seed for random generator.
- */
- public LogisticRegressionSGDTrainer(UpdatesStrategy<? super MultilayerPerceptron, P> updatesStgy, int maxIterations,
- int batchSize, int locIterations, long seed) {
- this.updatesStgy = updatesStgy;
- this.maxIterations = maxIterations;
- this.batchSize = batchSize;
- this.locIterations = locIterations;
- this.seed = seed;
- }
+ private long seed = 1234L;
/** {@inheritDoc} */
@Override public <K, V> LogisticRegressionModel fit(DatasetBuilder<K, V> datasetBuilder,
@@ -202,11 +190,22 @@ public class LogisticRegressionSGDTrainer<P extends Serializable> extends Single
}
/**
+ * Set up the updates strategy.
+ *
+ * @param updatesStgy Update strategy.
+ * @return Trainer with new update strategy parameter value.
+ */
+ public LogisticRegressionSGDTrainer withUpdatesStgy(UpdatesStrategy updatesStgy) {
+ this.updatesStgy = updatesStgy;
+ return this;
+ }
+
+ /**
* Get the update strategy.
*
* @return The property value.
*/
- public UpdatesStrategy<? super MultilayerPerceptron, P> getUpdatesStgy() {
+ public UpdatesStrategy getUpdatesStgy() {
return updatesStgy;
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/4da48e6f/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassTrainer.java
index b9cdcc7..71d54fa 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/multiclass/LogRegressionMultiClassTrainer.java
@@ -32,8 +32,9 @@ import org.apache.ignite.ml.dataset.PartitionDataBuilder;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
-import org.apache.ignite.ml.nn.MultilayerPerceptron;
import org.apache.ignite.ml.nn.UpdatesStrategy;
+import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate;
+import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;
import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel;
import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionSGDTrainer;
import org.apache.ignite.ml.structures.partition.LabelPartitionDataBuilderOnHeap;
@@ -46,19 +47,23 @@ import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;
public class LogRegressionMultiClassTrainer<P extends Serializable>
extends SingleLabelDatasetTrainer<LogRegressionMultiClassModel> {
/** Update strategy. */
- private UpdatesStrategy<? super MultilayerPerceptron, P> updatesStgy;
+ private UpdatesStrategy updatesStgy = new UpdatesStrategy<>(
+ new SimpleGDUpdateCalculator(0.2),
+ SimpleGDParameterUpdate::sumLocal,
+ SimpleGDParameterUpdate::avg
+ );
/** Max number of iteration. */
- private int amountOfIterations;
+ private int amountOfIterations = 100;
/** Batch size. */
- private int batchSize;
+ private int batchSize = 100;
/** Number of local iterations. */
- private int amountOfLocIterations;
+ private int amountOfLocIterations = 100;
/** Seed for random generator. */
- private long seed;
+ private long seed = 1234L;
/**
* Trains model based on the specified data.
@@ -90,7 +95,11 @@ public class LogRegressionMultiClassTrainer<P extends Serializable>
classes.forEach(clsLb -> {
LogisticRegressionSGDTrainer<?> trainer =
- new LogisticRegressionSGDTrainer<>(updatesStgy, amountOfIterations, batchSize, amountOfLocIterations, seed);
+ new LogisticRegressionSGDTrainer<>()
+ .withBatchSize(batchSize)
+ .withLocIterations(amountOfLocIterations)
+ .withMaxIterations(amountOfIterations)
+ .withSeed(seed);
IgniteBiFunction<K, V, Double> lbTransformer = (k, v) -> {
Double lb = lbExtractor.apply(k, v);
@@ -238,7 +247,7 @@ public class LogRegressionMultiClassTrainer<P extends Serializable>
}
/**
- * Set up the regularization parameter.
+ * Set up the updates strategy.
*
* @param updatesStgy Update strategy.
* @return Trainer with new update strategy parameter value.
@@ -249,11 +258,11 @@ public class LogRegressionMultiClassTrainer<P extends Serializable>
}
/**
- * Get the update strategy..
+ * Get the update strategy.
*
* @return The parameter value.
*/
- public UpdatesStrategy<? super MultilayerPerceptron, P> getUpdatesStgy() {
+ public UpdatesStrategy getUpdatesStgy() {
return updatesStgy;
}
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/4da48e6f/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java
index 2c621c8..47666f4 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java
@@ -50,7 +50,7 @@ public class SVMLinearBinaryClassificationTrainer extends SingleLabelDatasetTrai
private double lambda = 0.4;
/** The seed number. */
- private long seed;
+ private long seed = 1234L;
/**
* Trains model based on the specified data.
http://git-wip-us.apache.org/repos/asf/ignite/blob/4da48e6f/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java
index ec60034..b161914 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java
@@ -52,7 +52,7 @@ public class SVMLinearMultiClassClassificationTrainer
private double lambda = 0.2;
/** The seed number. */
- private long seed;
+ private long seed = 1234L;
/**
* Trains model based on the specified data.
http://git-wip-us.apache.org/repos/asf/ignite/blob/4da48e6f/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineTest.java
index 91bbcd4..d517ce6 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineTest.java
@@ -51,11 +51,13 @@ public class PipelineTest extends TrainerTest {
cacheMock.put(i, convertedRow);
}
- LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>(new UpdatesStrategy<>(
- new SimpleGDUpdateCalculator().withLearningRate(0.2),
- SimpleGDParameterUpdate::sumLocal,
- SimpleGDParameterUpdate::avg
- ), 100000, 10, 100, 123L);
+ LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>()
+ .withUpdatesStgy(new UpdatesStrategy<>(new SimpleGDUpdateCalculator(0.2),
+ SimpleGDParameterUpdate::sumLocal, SimpleGDParameterUpdate::avg))
+ .withMaxIterations(100000)
+ .withLocIterations(100)
+ .withBatchSize(10)
+ .withSeed(123L);
PipelineMdl<Integer, Double[]> mdl = new Pipeline<Integer, Double[], Vector>()
.addFeatureExtractor((k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)))
@@ -88,12 +90,6 @@ public class PipelineTest extends TrainerTest {
cacheMock.put(i, convertedRow);
}
- LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>(new UpdatesStrategy<>(
- new SimpleGDUpdateCalculator().withLearningRate(0.2),
- SimpleGDParameterUpdate::sumLocal,
- SimpleGDParameterUpdate::avg
- ), 100000, 10, 100, 123L);
-
PipelineMdl<Integer, Double[]> mdl = new Pipeline<Integer, Double[], Vector>()
.addFeatureExtractor((k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)))
.addLabelExtractor((k, v) -> v[0])
http://git-wip-us.apache.org/repos/asf/ignite/blob/4da48e6f/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogRegMultiClassTrainerTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogRegMultiClassTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogRegMultiClassTrainerTest.java
index 78cd08d..c99bf02 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogRegMultiClassTrainerTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogRegMultiClassTrainerTest.java
@@ -133,7 +133,6 @@ public class LogRegMultiClassTrainerTest extends TrainerTest {
VectorUtils.of(10, -10)
);
-
for (Vector vec : vectors) {
TestUtils.assertEquals(originalMdl.apply(vec), updatedOnSameDS.apply(vec), PRECISION);
TestUtils.assertEquals(originalMdl.apply(vec), updatedOnEmptyDS.apply(vec), PRECISION);
http://git-wip-us.apache.org/repos/asf/ignite/blob/4da48e6f/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionModelTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionModelTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionModelTest.java
index 89c9cca..e8aaacd 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionModelTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionModelTest.java
@@ -38,7 +38,7 @@ public class LogisticRegressionModelTest {
/** */
@Test
public void testPredict() {
- Vector weights = new DenseVector(new double[]{2.0, 3.0});
+ Vector weights = new DenseVector(new double[] {2.0, 3.0});
assertFalse(new LogisticRegressionModel(weights, 1.0).isKeepingRawLabels());
@@ -57,35 +57,36 @@ public class LogisticRegressionModelTest {
/** */
@Test(expected = CardinalityException.class)
public void testPredictOnAnObservationWithWrongCardinality() {
- Vector weights = new DenseVector(new double[]{2.0, 3.0});
+ Vector weights = new DenseVector(new double[] {2.0, 3.0});
LogisticRegressionModel mdl = new LogisticRegressionModel(weights, 1.0);
- Vector observation = new DenseVector(new double[]{1.0});
+ Vector observation = new DenseVector(new double[] {1.0});
mdl.apply(observation);
}
/** */
private void verifyPredict(LogisticRegressionModel mdl) {
- Vector observation = new DenseVector(new double[]{1.0, 1.0});
+ Vector observation = new DenseVector(new double[] {1.0, 1.0});
TestUtils.assertEquals(sigmoid(1.0 + 2.0 * 1.0 + 3.0 * 1.0), mdl.apply(observation), PRECISION);
- observation = new DenseVector(new double[]{2.0, 1.0});
+ observation = new DenseVector(new double[] {2.0, 1.0});
TestUtils.assertEquals(sigmoid(1.0 + 2.0 * 2.0 + 3.0 * 1.0), mdl.apply(observation), PRECISION);
- observation = new DenseVector(new double[]{1.0, 2.0});
+ observation = new DenseVector(new double[] {1.0, 2.0});
TestUtils.assertEquals(sigmoid(1.0 + 2.0 * 1.0 + 3.0 * 2.0), mdl.apply(observation), PRECISION);
- observation = new DenseVector(new double[]{-2.0, 1.0});
+ observation = new DenseVector(new double[] {-2.0, 1.0});
TestUtils.assertEquals(sigmoid(1.0 - 2.0 * 2.0 + 3.0 * 1.0), mdl.apply(observation), PRECISION);
- observation = new DenseVector(new double[]{1.0, -2.0});
+ observation = new DenseVector(new double[] {1.0, -2.0});
TestUtils.assertEquals(sigmoid(1.0 + 2.0 * 1.0 - 3.0 * 2.0), mdl.apply(observation), PRECISION);
}
/**
* Sigmoid function.
+ *
* @param z The regression value.
* @return The result.
*/
http://git-wip-us.apache.org/repos/asf/ignite/blob/4da48e6f/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java
index 723677c..d9b6f7a 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainerTest.java
@@ -45,11 +45,13 @@ public class LogisticRegressionSGDTrainerTest extends TrainerTest {
for (int i = 0; i < twoLinearlySeparableClasses.length; i++)
cacheMock.put(i, twoLinearlySeparableClasses[i]);
- LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>(new UpdatesStrategy<>(
- new SimpleGDUpdateCalculator().withLearningRate(0.2),
- SimpleGDParameterUpdate::sumLocal,
- SimpleGDParameterUpdate::avg
- ), 100000, 10, 100, 123L);
+ LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>()
+ .withUpdatesStgy(new UpdatesStrategy<>(new SimpleGDUpdateCalculator(0.2),
+ SimpleGDParameterUpdate::sumLocal, SimpleGDParameterUpdate::avg))
+ .withMaxIterations(100000)
+ .withLocIterations(100)
+ .withBatchSize(10)
+ .withSeed(123L);
LogisticRegressionModel mdl = trainer.fit(
cacheMock,
@@ -70,11 +72,13 @@ public class LogisticRegressionSGDTrainerTest extends TrainerTest {
for (int i = 0; i < twoLinearlySeparableClasses.length; i++)
cacheMock.put(i, twoLinearlySeparableClasses[i]);
- LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>(new UpdatesStrategy<>(
- new SimpleGDUpdateCalculator().withLearningRate(0.2),
- SimpleGDParameterUpdate::sumLocal,
- SimpleGDParameterUpdate::avg
- ), 100000, 10, 100, 123L);
+ LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>()
+ .withUpdatesStgy(new UpdatesStrategy<>(new SimpleGDUpdateCalculator(0.2),
+ SimpleGDParameterUpdate::sumLocal, SimpleGDParameterUpdate::avg))
+ .withMaxIterations(100000)
+ .withLocIterations(100)
+ .withBatchSize(10)
+ .withSeed(123L);
LogisticRegressionModel originalMdl = trainer.fit(
cacheMock,