Posted to commits@ignite.apache.org by ag...@apache.org on 2018/10/04 16:03:15 UTC
[07/50] [abbrv] ignite git commit: IGNITE-9711: [ML] Remove IgniteThread wrapper from ml examples
IGNITE-9711: [ML] Remove IgniteThread wrapper from ml examples
this closes #4849
Project: http://git-wip-us.apache.org/repos/asf/ignite/repo
Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/609266fe
Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/609266fe
Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/609266fe
Branch: refs/heads/ignite-5797
Commit: 609266fe2797c07599a893625f933740a25d049d
Parents: c7227cf
Author: YuriBabak <y....@gmail.com>
Authored: Fri Sep 28 11:57:58 2018 +0300
Committer: Yury Babak <yb...@gridgain.com>
Committed: Fri Sep 28 11:57:58 2018 +0300
----------------------------------------------------------------------
.../clustering/KMeansClusterizationExample.java | 75 +++--
.../ml/knn/ANNClassificationExample.java | 97 +++---
.../ml/knn/KNNClassificationExample.java | 69 ++---
.../examples/ml/knn/KNNRegressionExample.java | 79 +++--
.../examples/ml/nn/MLPTrainerExample.java | 130 ++++----
.../LinearRegressionLSQRTrainerExample.java | 56 ++--
...ssionLSQRTrainerWithMinMaxScalerExample.java | 69 ++---
.../LinearRegressionSGDTrainerExample.java | 65 ++--
.../LogisticRegressionSGDTrainerExample.java | 84 +++---
...gressionMultiClassClassificationExample.java | 169 +++++------
.../ml/selection/cv/CrossValidationExample.java | 58 ++--
.../split/TrainTestDatasetSplitterExample.java | 69 ++---
.../binary/SVMBinaryClassificationExample.java | 79 +++--
.../SVMMultiClassClassificationExample.java | 151 +++++-----
...ecisionTreeClassificationTrainerExample.java | 74 ++---
.../DecisionTreeRegressionTrainerExample.java | 63 ++--
.../GDBOnTreesClassificationTrainerExample.java | 58 ++--
.../GDBOnTreesRegressionTrainerExample.java | 55 ++--
.../RandomForestClassificationExample.java | 76 +++--
.../RandomForestRegressionExample.java | 91 +++---
.../ml/tutorial/Step_1_Read_and_Learn.java | 61 ++--
.../examples/ml/tutorial/Step_2_Imputing.java | 71 ++---
.../examples/ml/tutorial/Step_3_Categorial.java | 96 +++---
.../Step_3_Categorial_with_One_Hot_Encoder.java | 98 +++---
.../ml/tutorial/Step_4_Add_age_fare.java | 98 +++---
.../examples/ml/tutorial/Step_5_Scaling.java | 125 ++++----
.../tutorial/Step_5_Scaling_with_Pipeline.java | 77 +++--
.../ignite/examples/ml/tutorial/Step_6_KNN.java | 127 ++++----
.../ml/tutorial/Step_7_Split_train_test.java | 136 ++++-----
.../ignite/examples/ml/tutorial/Step_8_CV.java | 218 +++++++-------
.../ml/tutorial/Step_8_CV_with_Param_Grid.java | 200 ++++++-------
.../ml/tutorial/Step_9_Go_to_LogReg.java | 296 +++++++++----------
32 files changed, 1507 insertions(+), 1763 deletions(-)
----------------------------------------------------------------------
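For readers skimming the diff below: every listed file receives the same mechanical refactoring. The example logic that was previously wrapped in an IgniteThread (started and joined explicitly) now runs directly inside the try-with-resources block that starts the grid. The following is a condensed before/after sketch, not a complete file; it reuses the calls from the KMeans example shown in this diff, and the imports, the TestCache helper and the data array are elided exactly as they appear in the full example source.

// Before IGNITE-9711: the example body ran on a separate IgniteThread.
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
    IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
        KMeansClusterizationExample.class.getSimpleName(), () -> {
            // ... fill the cache, build the trainer, fit and evaluate the model ...
        });

    igniteThread.start();
    igniteThread.join();
}

// After IGNITE-9711: the same logic runs on the caller's thread.
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
    IgniteCache<Integer, double[]> dataCache = new TestCache(ignite).fillCacheWith(data);

    KMeansTrainer trainer = new KMeansTrainer()
        .withSeed(7867L);

    KMeansModel mdl = trainer.fit(
        ignite,
        dataCache,
        (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), // feature extractor
        (k, v) -> v[0]                                                // label extractor
    );

    // ... scan the cache with a ScanQuery, compare mdl.apply(...) against the label, print accuracy ...
}

The per-file diffs that follow apply this pattern to each example.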
http://git-wip-us.apache.org/repos/asf/ignite/blob/609266fe/examples/src/main/java/org/apache/ignite/examples/ml/clustering/KMeansClusterizationExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/clustering/KMeansClusterizationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/clustering/KMeansClusterizationExample.java
index 152375a..567775b 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/clustering/KMeansClusterizationExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/clustering/KMeansClusterizationExample.java
@@ -30,7 +30,6 @@ import org.apache.ignite.ml.clustering.kmeans.KMeansTrainer;
import org.apache.ignite.ml.math.Tracer;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
-import org.apache.ignite.thread.IgniteThread;
/**
* Run KMeans clustering algorithm ({@link KMeansTrainer}) over distributed dataset.
@@ -55,58 +54,52 @@ public class KMeansClusterizationExample {
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Ignite grid started.");
- IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
- KMeansClusterizationExample.class.getSimpleName(), () -> {
- IgniteCache<Integer, double[]> dataCache = new TestCache(ignite).fillCacheWith(data);
+ IgniteCache<Integer, double[]> dataCache = new TestCache(ignite).fillCacheWith(data);
- KMeansTrainer trainer = new KMeansTrainer()
- .withSeed(7867L);
+ KMeansTrainer trainer = new KMeansTrainer()
+ .withSeed(7867L);
- KMeansModel mdl = trainer.fit(
- ignite,
- dataCache,
- (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
- (k, v) -> v[0]
- );
+ KMeansModel mdl = trainer.fit(
+ ignite,
+ dataCache,
+ (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
+ (k, v) -> v[0]
+ );
- System.out.println(">>> KMeans centroids");
- Tracer.showAscii(mdl.getCenters()[0]);
- Tracer.showAscii(mdl.getCenters()[1]);
- System.out.println(">>>");
+ System.out.println(">>> KMeans centroids");
+ Tracer.showAscii(mdl.getCenters()[0]);
+ Tracer.showAscii(mdl.getCenters()[1]);
+ System.out.println(">>>");
- System.out.println(">>> -----------------------------------");
- System.out.println(">>> | Predicted cluster\t| Real Label\t|");
- System.out.println(">>> -----------------------------------");
+ System.out.println(">>> -----------------------------------");
+ System.out.println(">>> | Predicted cluster\t| Real Label\t|");
+ System.out.println(">>> -----------------------------------");
- int amountOfErrors = 0;
- int totalAmount = 0;
+ int amountOfErrors = 0;
+ int totalAmount = 0;
- try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) {
- for (Cache.Entry<Integer, double[]> observation : observations) {
- double[] val = observation.getValue();
- double[] inputs = Arrays.copyOfRange(val, 1, val.length);
- double groundTruth = val[0];
+ try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) {
+ for (Cache.Entry<Integer, double[]> observation : observations) {
+ double[] val = observation.getValue();
+ double[] inputs = Arrays.copyOfRange(val, 1, val.length);
+ double groundTruth = val[0];
- double prediction = mdl.apply(new DenseVector(inputs));
+ double prediction = mdl.apply(new DenseVector(inputs));
- totalAmount++;
- if (groundTruth != prediction)
- amountOfErrors++;
+ totalAmount++;
+ if (groundTruth != prediction)
+ amountOfErrors++;
- System.out.printf(">>> | %.4f\t\t\t| %.4f\t\t|\n", prediction, groundTruth);
- }
-
- System.out.println(">>> ---------------------------------");
+ System.out.printf(">>> | %.4f\t\t\t| %.4f\t\t|\n", prediction, groundTruth);
+ }
- System.out.println("\n>>> Absolute amount of errors " + amountOfErrors);
- System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double)totalAmount));
+ System.out.println(">>> ---------------------------------");
- System.out.println(">>> KMeans clustering algorithm over cached dataset usage example completed.");
- }
- });
+ System.out.println("\n>>> Absolute amount of errors " + amountOfErrors);
+ System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double)totalAmount));
- igniteThread.start();
- igniteThread.join();
+ System.out.println(">>> KMeans clustering algorithm over cached dataset usage example completed.");
+ }
}
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/609266fe/examples/src/main/java/org/apache/ignite/examples/ml/knn/ANNClassificationExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/knn/ANNClassificationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/knn/ANNClassificationExample.java
index 8a2d786..c9490fc 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/knn/ANNClassificationExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/knn/ANNClassificationExample.java
@@ -34,7 +34,6 @@ import org.apache.ignite.ml.math.distances.EuclideanDistance;
import org.apache.ignite.ml.math.distances.ManhattanDistance;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
-import org.apache.ignite.thread.IgniteThread;
/**
* Run ANN multi-class classification trainer ({@link ANNClassificationTrainer}) over distributed dataset.
@@ -59,73 +58,67 @@ public class ANNClassificationExample {
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Ignite grid started.");
- IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
- ANNClassificationExample.class.getSimpleName(), () -> {
- IgniteCache<Integer, double[]> dataCache = getTestCache(ignite);
+ IgniteCache<Integer, double[]> dataCache = getTestCache(ignite);
- ANNClassificationTrainer trainer = new ANNClassificationTrainer()
- .withDistance(new ManhattanDistance())
- .withK(50)
- .withMaxIterations(1000)
- .withSeed(1234L)
- .withEpsilon(1e-2);
+ ANNClassificationTrainer trainer = new ANNClassificationTrainer()
+ .withDistance(new ManhattanDistance())
+ .withK(50)
+ .withMaxIterations(1000)
+ .withSeed(1234L)
+ .withEpsilon(1e-2);
- long startTrainingTime = System.currentTimeMillis();
+ long startTrainingTime = System.currentTimeMillis();
- NNClassificationModel knnMdl = trainer.fit(
- ignite,
- dataCache,
- (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
- (k, v) -> v[0]
- ).withK(5)
- .withDistanceMeasure(new EuclideanDistance())
- .withStrategy(NNStrategy.WEIGHTED);
+ NNClassificationModel knnMdl = trainer.fit(
+ ignite,
+ dataCache,
+ (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
+ (k, v) -> v[0]
+ ).withK(5)
+ .withDistanceMeasure(new EuclideanDistance())
+ .withStrategy(NNStrategy.WEIGHTED);
- long endTrainingTime = System.currentTimeMillis();
+ long endTrainingTime = System.currentTimeMillis();
- System.out.println(">>> ---------------------------------");
- System.out.println(">>> | Prediction\t| Ground Truth\t|");
- System.out.println(">>> ---------------------------------");
-
- int amountOfErrors = 0;
- int totalAmount = 0;
+ System.out.println(">>> ---------------------------------");
+ System.out.println(">>> | Prediction\t| Ground Truth\t|");
+ System.out.println(">>> ---------------------------------");
- long totalPredictionTime = 0L;
+ int amountOfErrors = 0;
+ int totalAmount = 0;
- try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) {
- for (Cache.Entry<Integer, double[]> observation : observations) {
- double[] val = observation.getValue();
- double[] inputs = Arrays.copyOfRange(val, 1, val.length);
- double groundTruth = val[0];
+ long totalPredictionTime = 0L;
- long startPredictionTime = System.currentTimeMillis();
- double prediction = knnMdl.apply(new DenseVector(inputs));
- long endPredictionTime = System.currentTimeMillis();
+ try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) {
+ for (Cache.Entry<Integer, double[]> observation : observations) {
+ double[] val = observation.getValue();
+ double[] inputs = Arrays.copyOfRange(val, 1, val.length);
+ double groundTruth = val[0];
- totalPredictionTime += (endPredictionTime - startPredictionTime);
+ long startPredictionTime = System.currentTimeMillis();
+ double prediction = knnMdl.apply(new DenseVector(inputs));
+ long endPredictionTime = System.currentTimeMillis();
- totalAmount++;
- if (groundTruth != prediction)
- amountOfErrors++;
+ totalPredictionTime += (endPredictionTime - startPredictionTime);
- System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
- }
+ totalAmount++;
+ if (groundTruth != prediction)
+ amountOfErrors++;
- System.out.println(">>> ---------------------------------");
+ System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
+ }
- System.out.println("Training costs = " + (endTrainingTime - startTrainingTime));
- System.out.println("Prediction costs = " + totalPredictionTime);
+ System.out.println(">>> ---------------------------------");
- System.out.println("\n>>> Absolute amount of errors " + amountOfErrors);
- System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double) totalAmount));
- System.out.println(totalAmount);
+ System.out.println("Training costs = " + (endTrainingTime - startTrainingTime));
+ System.out.println("Prediction costs = " + totalPredictionTime);
- System.out.println(">>> ANN multi-class classification algorithm over cached dataset usage example completed.");
- }
- });
+ System.out.println("\n>>> Absolute amount of errors " + amountOfErrors);
+ System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double) totalAmount));
+ System.out.println(totalAmount);
- igniteThread.start();
- igniteThread.join();
+ System.out.println(">>> ANN multi-class classification algorithm over cached dataset usage example completed.");
+ }
}
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/609266fe/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNClassificationExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNClassificationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNClassificationExample.java
index cf285a4..5cbb2ad 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNClassificationExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNClassificationExample.java
@@ -31,7 +31,6 @@ import org.apache.ignite.ml.knn.classification.NNStrategy;
import org.apache.ignite.ml.math.distances.EuclideanDistance;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
-import org.apache.ignite.thread.IgniteThread;
/**
* Run kNN multi-class classification trainer ({@link KNNClassificationTrainer}) over distributed dataset.
@@ -56,54 +55,48 @@ public class KNNClassificationExample {
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Ignite grid started.");
- IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
- KNNClassificationExample.class.getSimpleName(), () -> {
- IgniteCache<Integer, double[]> dataCache = new TestCache(ignite).fillCacheWith(data);
+ IgniteCache<Integer, double[]> dataCache = new TestCache(ignite).fillCacheWith(data);
- KNNClassificationTrainer trainer = new KNNClassificationTrainer();
+ KNNClassificationTrainer trainer = new KNNClassificationTrainer();
- NNClassificationModel knnMdl = trainer.fit(
- ignite,
- dataCache,
- (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
- (k, v) -> v[0]
- ).withK(3)
- .withDistanceMeasure(new EuclideanDistance())
- .withStrategy(NNStrategy.WEIGHTED);
+ NNClassificationModel knnMdl = trainer.fit(
+ ignite,
+ dataCache,
+ (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
+ (k, v) -> v[0]
+ ).withK(3)
+ .withDistanceMeasure(new EuclideanDistance())
+ .withStrategy(NNStrategy.WEIGHTED);
- System.out.println(">>> ---------------------------------");
- System.out.println(">>> | Prediction\t| Ground Truth\t|");
- System.out.println(">>> ---------------------------------");
-
- int amountOfErrors = 0;
- int totalAmount = 0;
+ System.out.println(">>> ---------------------------------");
+ System.out.println(">>> | Prediction\t| Ground Truth\t|");
+ System.out.println(">>> ---------------------------------");
- try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) {
- for (Cache.Entry<Integer, double[]> observation : observations) {
- double[] val = observation.getValue();
- double[] inputs = Arrays.copyOfRange(val, 1, val.length);
- double groundTruth = val[0];
+ int amountOfErrors = 0;
+ int totalAmount = 0;
- double prediction = knnMdl.apply(new DenseVector(inputs));
+ try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) {
+ for (Cache.Entry<Integer, double[]> observation : observations) {
+ double[] val = observation.getValue();
+ double[] inputs = Arrays.copyOfRange(val, 1, val.length);
+ double groundTruth = val[0];
- totalAmount++;
- if (groundTruth != prediction)
- amountOfErrors++;
+ double prediction = knnMdl.apply(new DenseVector(inputs));
- System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
- }
+ totalAmount++;
+ if (groundTruth != prediction)
+ amountOfErrors++;
- System.out.println(">>> ---------------------------------");
+ System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
+ }
- System.out.println("\n>>> Absolute amount of errors " + amountOfErrors);
- System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double) totalAmount));
+ System.out.println(">>> ---------------------------------");
- System.out.println(">>> kNN multi-class classification algorithm over cached dataset usage example completed.");
- }
- });
+ System.out.println("\n>>> Absolute amount of errors " + amountOfErrors);
+ System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double) totalAmount));
- igniteThread.start();
- igniteThread.join();
+ System.out.println(">>> kNN multi-class classification algorithm over cached dataset usage example completed.");
+ }
}
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/609266fe/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNRegressionExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNRegressionExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNRegressionExample.java
index 78f38c8..3969f0c 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNRegressionExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNRegressionExample.java
@@ -31,7 +31,6 @@ import org.apache.ignite.ml.knn.regression.KNNRegressionTrainer;
import org.apache.ignite.ml.math.distances.ManhattanDistance;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
-import org.apache.ignite.thread.IgniteThread;
/**
* Run kNN regression trainer ({@link KNNRegressionTrainer}) over distributed dataset.
@@ -57,61 +56,55 @@ public class KNNRegressionExample {
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Ignite grid started.");
- IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
- KNNRegressionExample.class.getSimpleName(), () -> {
- IgniteCache<Integer, double[]> dataCache = new TestCache(ignite).fillCacheWith(data);
+ IgniteCache<Integer, double[]> dataCache = new TestCache(ignite).fillCacheWith(data);
- KNNRegressionTrainer trainer = new KNNRegressionTrainer();
+ KNNRegressionTrainer trainer = new KNNRegressionTrainer();
- KNNRegressionModel knnMdl = (KNNRegressionModel) trainer.fit(
- ignite,
- dataCache,
- (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
- (k, v) -> v[0]
- ).withK(5)
- .withDistanceMeasure(new ManhattanDistance())
- .withStrategy(NNStrategy.WEIGHTED);
+ KNNRegressionModel knnMdl = (KNNRegressionModel) trainer.fit(
+ ignite,
+ dataCache,
+ (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
+ (k, v) -> v[0]
+ ).withK(5)
+ .withDistanceMeasure(new ManhattanDistance())
+ .withStrategy(NNStrategy.WEIGHTED);
- System.out.println(">>> ---------------------------------");
- System.out.println(">>> | Prediction\t| Ground Truth\t|");
- System.out.println(">>> ---------------------------------");
-
- int totalAmount = 0;
- // Calculate mean squared error (MSE)
- double mse = 0.0;
- // Calculate mean absolute error (MAE)
- double mae = 0.0;
+ System.out.println(">>> ---------------------------------");
+ System.out.println(">>> | Prediction\t| Ground Truth\t|");
+ System.out.println(">>> ---------------------------------");
- try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) {
- for (Cache.Entry<Integer, double[]> observation : observations) {
- double[] val = observation.getValue();
- double[] inputs = Arrays.copyOfRange(val, 1, val.length);
- double groundTruth = val[0];
+ int totalAmount = 0;
+ // Calculate mean squared error (MSE)
+ double mse = 0.0;
+ // Calculate mean absolute error (MAE)
+ double mae = 0.0;
- double prediction = knnMdl.apply(new DenseVector(inputs));
+ try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) {
+ for (Cache.Entry<Integer, double[]> observation : observations) {
+ double[] val = observation.getValue();
+ double[] inputs = Arrays.copyOfRange(val, 1, val.length);
+ double groundTruth = val[0];
- mse += Math.pow(prediction - groundTruth, 2.0);
- mae += Math.abs(prediction - groundTruth);
+ double prediction = knnMdl.apply(new DenseVector(inputs));
- totalAmount++;
+ mse += Math.pow(prediction - groundTruth, 2.0);
+ mae += Math.abs(prediction - groundTruth);
- System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
- }
+ totalAmount++;
- System.out.println(">>> ---------------------------------");
+ System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
+ }
- mse = mse / totalAmount;
- System.out.println("\n>>> Mean squared error (MSE) " + mse);
+ System.out.println(">>> ---------------------------------");
- mae = mae / totalAmount;
- System.out.println("\n>>> Mean absolute error (MAE) " + mae);
+ mse = mse / totalAmount;
+ System.out.println("\n>>> Mean squared error (MSE) " + mse);
- System.out.println(">>> kNN regression over cached dataset usage example completed.");
- }
- });
+ mae = mae / totalAmount;
+ System.out.println("\n>>> Mean absolute error (MAE) " + mae);
- igniteThread.start();
- igniteThread.join();
+ System.out.println(">>> kNN regression over cached dataset usage example completed.");
+ }
}
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/609266fe/examples/src/main/java/org/apache/ignite/examples/ml/nn/MLPTrainerExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/nn/MLPTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/nn/MLPTrainerExample.java
index 3e5a98c..6d5745e 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/nn/MLPTrainerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/nn/MLPTrainerExample.java
@@ -34,7 +34,6 @@ import org.apache.ignite.ml.nn.architecture.MLPArchitecture;
import org.apache.ignite.ml.optimization.LossFunctions;
import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate;
import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;
-import org.apache.ignite.thread.IgniteThread;
/**
* Example of using distributed {@link MultilayerPerceptron}.
@@ -70,76 +69,65 @@ public class MLPTrainerExample {
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Ignite grid started.");
- // Create IgniteThread, we must work with SparseDistributedMatrix inside IgniteThread
- // because we create ignite cache internally.
- IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
- MLPTrainerExample.class.getSimpleName(), () -> {
-
- // Create cache with training data.
- CacheConfiguration<Integer, LabeledPoint> trainingSetCfg = new CacheConfiguration<>();
- trainingSetCfg.setName("TRAINING_SET");
- trainingSetCfg.setAffinity(new RendezvousAffinityFunction(false, 10));
-
- IgniteCache<Integer, LabeledPoint> trainingSet = ignite.createCache(trainingSetCfg);
-
- // Fill cache with training data.
- trainingSet.put(0, new LabeledPoint(0, 0, 0));
- trainingSet.put(1, new LabeledPoint(0, 1, 1));
- trainingSet.put(2, new LabeledPoint(1, 0, 1));
- trainingSet.put(3, new LabeledPoint(1, 1, 0));
-
- // Define a layered architecture.
- MLPArchitecture arch = new MLPArchitecture(2).
- withAddedLayer(10, true, Activators.RELU).
- withAddedLayer(1, false, Activators.SIGMOID);
-
- // Define a neural network trainer.
- MLPTrainer<SimpleGDParameterUpdate> trainer = new MLPTrainer<>(
- arch,
- LossFunctions.MSE,
- new UpdatesStrategy<>(
- new SimpleGDUpdateCalculator(0.1),
- SimpleGDParameterUpdate::sumLocal,
- SimpleGDParameterUpdate::avg
- ),
- 3000,
- 4,
- 50,
- 123L
- );
-
- // Train neural network and get multilayer perceptron model.
- MultilayerPerceptron mlp = trainer.fit(
- ignite,
- trainingSet,
- (k, v) -> VectorUtils.of(v.x, v.y),
- (k, v) -> new double[] {v.lb}
- );
-
- int totalCnt = 4;
- int failCnt = 0;
-
- // Calculate score.
- for (int i = 0; i < 4; i++) {
- LabeledPoint pnt = trainingSet.get(i);
- Matrix predicted = mlp.apply(new DenseMatrix(new double[][] {{pnt.x, pnt.y}}));
-
- double predictedVal = predicted.get(0, 0);
- double lbl = pnt.lb;
- System.out.printf(">>> key: %d\t\t predicted: %.4f\t\tlabel: %.4f\n", i, predictedVal, lbl);
- failCnt += Math.abs(predictedVal - lbl) < 0.5 ? 0 : 1;
- }
-
- double failRatio = (double)failCnt / totalCnt;
-
- System.out.println("\n>>> Fail percentage: " + (failRatio * 100) + "%.");
-
- System.out.println("\n>>> Distributed multilayer perceptron example completed.");
- });
-
- igniteThread.start();
-
- igniteThread.join();
+ // Create cache with training data.
+ CacheConfiguration<Integer, LabeledPoint> trainingSetCfg = new CacheConfiguration<>();
+ trainingSetCfg.setName("TRAINING_SET");
+ trainingSetCfg.setAffinity(new RendezvousAffinityFunction(false, 10));
+
+ IgniteCache<Integer, LabeledPoint> trainingSet = ignite.createCache(trainingSetCfg);
+
+ // Fill cache with training data.
+ trainingSet.put(0, new LabeledPoint(0, 0, 0));
+ trainingSet.put(1, new LabeledPoint(0, 1, 1));
+ trainingSet.put(2, new LabeledPoint(1, 0, 1));
+ trainingSet.put(3, new LabeledPoint(1, 1, 0));
+
+ // Define a layered architecture.
+ MLPArchitecture arch = new MLPArchitecture(2).
+ withAddedLayer(10, true, Activators.RELU).
+ withAddedLayer(1, false, Activators.SIGMOID);
+
+ // Define a neural network trainer.
+ MLPTrainer<SimpleGDParameterUpdate> trainer = new MLPTrainer<>(
+ arch,
+ LossFunctions.MSE,
+ new UpdatesStrategy<>(
+ new SimpleGDUpdateCalculator(0.1),
+ SimpleGDParameterUpdate::sumLocal,
+ SimpleGDParameterUpdate::avg
+ ),
+ 3000,
+ 4,
+ 50,
+ 123L
+ );
+
+ // Train neural network and get multilayer perceptron model.
+ MultilayerPerceptron mlp = trainer.fit(
+ ignite,
+ trainingSet,
+ (k, v) -> VectorUtils.of(v.x, v.y),
+ (k, v) -> new double[] {v.lb}
+ );
+
+ int totalCnt = 4;
+ int failCnt = 0;
+
+ // Calculate score.
+ for (int i = 0; i < 4; i++) {
+ LabeledPoint pnt = trainingSet.get(i);
+ Matrix predicted = mlp.apply(new DenseMatrix(new double[][] {{pnt.x, pnt.y}}));
+
+ double predictedVal = predicted.get(0, 0);
+ double lbl = pnt.lb;
+ System.out.printf(">>> key: %d\t\t predicted: %.4f\t\tlabel: %.4f\n", i, predictedVal, lbl);
+ failCnt += Math.abs(predictedVal - lbl) < 0.5 ? 0 : 1;
+ }
+
+ double failRatio = (double)failCnt / totalCnt;
+
+ System.out.println("\n>>> Fail percentage: " + (failRatio * 100) + "%.");
+ System.out.println("\n>>> Distributed multilayer perceptron example completed.");
}
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/609266fe/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerExample.java
index 6ac445c..862a37f 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerExample.java
@@ -29,7 +29,6 @@ import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
import org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainer;
import org.apache.ignite.ml.regressions.linear.LinearRegressionModel;
-import org.apache.ignite.thread.IgniteThread;
/**
* Run linear regression model based on <a href="http://web.stanford.edu/group/SOL/software/lsqr/">LSQR algorithm</a>
@@ -110,47 +109,40 @@ public class LinearRegressionLSQRTrainerExample {
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Ignite grid started.");
- IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
- LinearRegressionLSQRTrainerExample.class.getSimpleName(), () -> {
- IgniteCache<Integer, double[]> dataCache = new TestCache(ignite).fillCacheWith(data);
+ IgniteCache<Integer, double[]> dataCache = new TestCache(ignite).fillCacheWith(data);
- System.out.println(">>> Create new linear regression trainer object.");
- LinearRegressionLSQRTrainer trainer = new LinearRegressionLSQRTrainer();
+ System.out.println(">>> Create new linear regression trainer object.");
+ LinearRegressionLSQRTrainer trainer = new LinearRegressionLSQRTrainer();
- System.out.println(">>> Perform the training to get the model.");
- LinearRegressionModel mdl = trainer.fit(
- ignite,
- dataCache,
- (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
- (k, v) -> v[0]
- );
+ System.out.println(">>> Perform the training to get the model.");
+ LinearRegressionModel mdl = trainer.fit(
+ ignite,
+ dataCache,
+ (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
+ (k, v) -> v[0]
+ );
- System.out.println(">>> Linear regression model: " + mdl);
+ System.out.println(">>> Linear regression model: " + mdl);
- System.out.println(">>> ---------------------------------");
- System.out.println(">>> | Prediction\t| Ground Truth\t|");
- System.out.println(">>> ---------------------------------");
+ System.out.println(">>> ---------------------------------");
+ System.out.println(">>> | Prediction\t| Ground Truth\t|");
+ System.out.println(">>> ---------------------------------");
- try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) {
- for (Cache.Entry<Integer, double[]> observation : observations) {
- double[] val = observation.getValue();
- double[] inputs = Arrays.copyOfRange(val, 1, val.length);
- double groundTruth = val[0];
+ try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) {
+ for (Cache.Entry<Integer, double[]> observation : observations) {
+ double[] val = observation.getValue();
+ double[] inputs = Arrays.copyOfRange(val, 1, val.length);
+ double groundTruth = val[0];
- double prediction = mdl.apply(new DenseVector(inputs));
+ double prediction = mdl.apply(new DenseVector(inputs));
- System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
- }
+ System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
}
+ }
- System.out.println(">>> ---------------------------------");
+ System.out.println(">>> ---------------------------------");
- System.out.println(">>> Linear regression model over cache based dataset usage example completed.");
- });
-
- igniteThread.start();
-
- igniteThread.join();
+ System.out.println(">>> Linear regression model over cache based dataset usage example completed.");
}
}
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/609266fe/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerWithMinMaxScalerExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerWithMinMaxScalerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerWithMinMaxScalerExample.java
index 320d464..5692cb3 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerWithMinMaxScalerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerWithMinMaxScalerExample.java
@@ -32,7 +32,6 @@ import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerPreprocessor
import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerTrainer;
import org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainer;
import org.apache.ignite.ml.regressions.linear.LinearRegressionModel;
-import org.apache.ignite.thread.IgniteThread;
/**
* Run linear regression model based on <a href="http://web.stanford.edu/group/SOL/software/lsqr/">LSQR algorithm</a>
@@ -116,55 +115,47 @@ public class LinearRegressionLSQRTrainerWithMinMaxScalerExample {
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Ignite grid started.");
- IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
- LinearRegressionLSQRTrainerWithMinMaxScalerExample.class.getSimpleName(), () -> {
- IgniteCache<Integer, Vector> dataCache = new TestCache(ignite).getVectors(data);
+ IgniteCache<Integer, Vector> dataCache = new TestCache(ignite).getVectors(data);
- System.out.println(">>> Create new minmaxscaling trainer object.");
- MinMaxScalerTrainer<Integer, Vector> normalizationTrainer = new MinMaxScalerTrainer<>();
+ System.out.println(">>> Create new minmaxscaling trainer object.");
+ MinMaxScalerTrainer<Integer, Vector> normalizationTrainer = new MinMaxScalerTrainer<>();
- System.out.println(">>> Perform the training to get the minmaxscaling preprocessor.");
- IgniteBiFunction<Integer, Vector, Vector> preprocessor = normalizationTrainer.fit(
- ignite,
- dataCache,
- (k, v) -> {
- double[] arr = v.asArray();
- return VectorUtils.of(Arrays.copyOfRange(arr, 1, arr.length));
- }
- );
+ System.out.println(">>> Perform the training to get the minmaxscaling preprocessor.");
+ IgniteBiFunction<Integer, Vector, Vector> preprocessor = normalizationTrainer.fit(
+ ignite,
+ dataCache,
+ (k, v) -> {
+ double[] arr = v.asArray();
+ return VectorUtils.of(Arrays.copyOfRange(arr, 1, arr.length));
+ }
+ );
- System.out.println(">>> Create new linear regression trainer object.");
- LinearRegressionLSQRTrainer trainer = new LinearRegressionLSQRTrainer();
+ System.out.println(">>> Create new linear regression trainer object.");
+ LinearRegressionLSQRTrainer trainer = new LinearRegressionLSQRTrainer();
- System.out.println(">>> Perform the training to get the model.");
- LinearRegressionModel mdl = trainer.fit(ignite, dataCache, preprocessor, (k, v) -> v.get(0));
+ System.out.println(">>> Perform the training to get the model.");
+ LinearRegressionModel mdl = trainer.fit(ignite, dataCache, preprocessor, (k, v) -> v.get(0));
- System.out.println(">>> Linear regression model: " + mdl);
+ System.out.println(">>> Linear regression model: " + mdl);
- System.out.println(">>> ---------------------------------");
- System.out.println(">>> | Prediction\t| Ground Truth\t|");
- System.out.println(">>> ---------------------------------");
+ System.out.println(">>> ---------------------------------");
+ System.out.println(">>> | Prediction\t| Ground Truth\t|");
+ System.out.println(">>> ---------------------------------");
- try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) {
- for (Cache.Entry<Integer, Vector> observation : observations) {
- Integer key = observation.getKey();
- Vector val = observation.getValue();
- double groundTruth = val.get(0);
+ try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) {
+ for (Cache.Entry<Integer, Vector> observation : observations) {
+ Integer key = observation.getKey();
+ Vector val = observation.getValue();
+ double groundTruth = val.get(0);
- double prediction = mdl.apply(preprocessor.apply(key, val));
+ double prediction = mdl.apply(preprocessor.apply(key, val));
- System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
- }
+ System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
}
+ }
- System.out.println(">>> ---------------------------------");
-
- System.out.println(">>> Linear regression model with minmaxscaling preprocessor over cache based dataset usage example completed.");
- });
-
- igniteThread.start();
-
- igniteThread.join();
+ System.out.println(">>> ---------------------------------");
+ System.out.println(">>> Linear regression model with minmaxscaling preprocessor over cache based dataset usage example completed.");
}
}
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/609266fe/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionSGDTrainerExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionSGDTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionSGDTrainerExample.java
index 9fdc0df..1e9bd5a 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionSGDTrainerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionSGDTrainerExample.java
@@ -32,7 +32,6 @@ import org.apache.ignite.ml.optimization.updatecalculators.RPropParameterUpdate;
import org.apache.ignite.ml.optimization.updatecalculators.RPropUpdateCalculator;
import org.apache.ignite.ml.regressions.linear.LinearRegressionModel;
import org.apache.ignite.ml.regressions.linear.LinearRegressionSGDTrainer;
-import org.apache.ignite.thread.IgniteThread;
/**
* Run linear regression model based on based on
@@ -114,52 +113,44 @@ public class LinearRegressionSGDTrainerExample {
// Start ignite grid.
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Ignite grid started.");
- IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
- LinearRegressionSGDTrainerExample.class.getSimpleName(), () -> {
- IgniteCache<Integer, double[]> dataCache = new TestCache(ignite).fillCacheWith(data);
+ IgniteCache<Integer, double[]> dataCache = new TestCache(ignite).fillCacheWith(data);
- System.out.println(">>> Create new linear regression trainer object.");
- LinearRegressionSGDTrainer<?> trainer = new LinearRegressionSGDTrainer<>(new UpdatesStrategy<>(
- new RPropUpdateCalculator(),
- RPropParameterUpdate::sumLocal,
- RPropParameterUpdate::avg
- ), 100000, 10, 100, 123L);
+ System.out.println(">>> Create new linear regression trainer object.");
+ LinearRegressionSGDTrainer<?> trainer = new LinearRegressionSGDTrainer<>(new UpdatesStrategy<>(
+ new RPropUpdateCalculator(),
+ RPropParameterUpdate::sumLocal,
+ RPropParameterUpdate::avg
+ ), 100000, 10, 100, 123L);
- System.out.println(">>> Perform the training to get the model.");
- LinearRegressionModel mdl = trainer.fit(
- ignite,
- dataCache,
- (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
- (k, v) -> v[0]
- );
+ System.out.println(">>> Perform the training to get the model.");
+ LinearRegressionModel mdl = trainer.fit(
+ ignite,
+ dataCache,
+ (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
+ (k, v) -> v[0]
+ );
- System.out.println(">>> Linear regression model: " + mdl);
+ System.out.println(">>> Linear regression model: " + mdl);
- System.out.println(">>> ---------------------------------");
- System.out.println(">>> | Prediction\t| Ground Truth\t|");
- System.out.println(">>> ---------------------------------");
+ System.out.println(">>> ---------------------------------");
+ System.out.println(">>> | Prediction\t| Ground Truth\t|");
+ System.out.println(">>> ---------------------------------");
- try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) {
- for (Cache.Entry<Integer, double[]> observation : observations) {
- double[] val = observation.getValue();
- double[] inputs = Arrays.copyOfRange(val, 1, val.length);
- double groundTruth = val[0];
+ try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) {
+ for (Cache.Entry<Integer, double[]> observation : observations) {
+ double[] val = observation.getValue();
+ double[] inputs = Arrays.copyOfRange(val, 1, val.length);
+ double groundTruth = val[0];
- double prediction = mdl.apply(new DenseVector(inputs));
+ double prediction = mdl.apply(new DenseVector(inputs));
- System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
- }
+ System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
}
+ }
- System.out.println(">>> ---------------------------------");
-
- System.out.println(">>> Linear regression model over cache based dataset usage example completed.");
- });
-
- igniteThread.start();
-
- igniteThread.join();
+ System.out.println(">>> ---------------------------------");
+ System.out.println(">>> Linear regression model over cache based dataset usage example completed.");
}
}
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/609266fe/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerExample.java
index 0a6ff01..8d4218d 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerExample.java
@@ -32,7 +32,6 @@ import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpda
import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;
import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel;
import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionSGDTrainer;
-import org.apache.ignite.thread.IgniteThread;
/**
* Run logistic regression model based on <a href="https://en.wikipedia.org/wiki/Stochastic_gradient_descent">
@@ -57,69 +56,62 @@ public class LogisticRegressionSGDTrainerExample {
// Start ignite grid.
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Ignite grid started.");
- IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
- LogisticRegressionSGDTrainerExample.class.getSimpleName(), () -> {
- IgniteCache<Integer, double[]> dataCache = new TestCache(ignite).fillCacheWith(data);
+ IgniteCache<Integer, double[]> dataCache = new TestCache(ignite).fillCacheWith(data);
- System.out.println(">>> Create new logistic regression trainer object.");
- LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>(new UpdatesStrategy<>(
- new SimpleGDUpdateCalculator(0.2),
- SimpleGDParameterUpdate::sumLocal,
- SimpleGDParameterUpdate::avg
- ), 100000, 10, 100, 123L);
+ System.out.println(">>> Create new logistic regression trainer object.");
+ LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>(new UpdatesStrategy<>(
+ new SimpleGDUpdateCalculator(0.2),
+ SimpleGDParameterUpdate::sumLocal,
+ SimpleGDParameterUpdate::avg
+ ), 100000, 10, 100, 123L);
- System.out.println(">>> Perform the training to get the model.");
- LogisticRegressionModel mdl = trainer.fit(
- ignite,
- dataCache,
- (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
- (k, v) -> v[0]
- );
+ System.out.println(">>> Perform the training to get the model.");
+ LogisticRegressionModel mdl = trainer.fit(
+ ignite,
+ dataCache,
+ (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
+ (k, v) -> v[0]
+ );
- System.out.println(">>> Logistic regression model: " + mdl);
+ System.out.println(">>> Logistic regression model: " + mdl);
- int amountOfErrors = 0;
- int totalAmount = 0;
+ int amountOfErrors = 0;
+ int totalAmount = 0;
- // Build confusion matrix. See https://en.wikipedia.org/wiki/Confusion_matrix
- int[][] confusionMtx = {{0, 0}, {0, 0}};
+ // Build confusion matrix. See https://en.wikipedia.org/wiki/Confusion_matrix
+ int[][] confusionMtx = {{0, 0}, {0, 0}};
- try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) {
- for (Cache.Entry<Integer, double[]> observation : observations) {
- double[] val = observation.getValue();
- double[] inputs = Arrays.copyOfRange(val, 1, val.length);
- double groundTruth = val[0];
+ try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) {
+ for (Cache.Entry<Integer, double[]> observation : observations) {
+ double[] val = observation.getValue();
+ double[] inputs = Arrays.copyOfRange(val, 1, val.length);
+ double groundTruth = val[0];
- double prediction = mdl.apply(new DenseVector(inputs));
+ double prediction = mdl.apply(new DenseVector(inputs));
- totalAmount++;
- if(groundTruth != prediction)
- amountOfErrors++;
+ totalAmount++;
+ if(groundTruth != prediction)
+ amountOfErrors++;
- int idx1 = (int)prediction;
- int idx2 = (int)groundTruth;
+ int idx1 = (int)prediction;
+ int idx2 = (int)groundTruth;
- confusionMtx[idx1][idx2]++;
+ confusionMtx[idx1][idx2]++;
- System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
- }
-
- System.out.println(">>> ---------------------------------");
-
- System.out.println("\n>>> Absolute amount of errors " + amountOfErrors);
- System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double)totalAmount));
+ System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
}
- System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtx));
System.out.println(">>> ---------------------------------");
- System.out.println(">>> Logistic regression model over partitioned dataset usage example completed.");
- });
+ System.out.println("\n>>> Absolute amount of errors " + amountOfErrors);
+ System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double)totalAmount));
+ }
- igniteThread.start();
+ System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtx));
+ System.out.println(">>> ---------------------------------");
- igniteThread.join();
+ System.out.println(">>> Logistic regression model over partitioned dataset usage example completed.");
}
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/609266fe/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/multiclass/LogRegressionMultiClassClassificationExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/multiclass/LogRegressionMultiClassClassificationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/multiclass/LogRegressionMultiClassClassificationExample.java
index e670f01..ff2761a 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/multiclass/LogRegressionMultiClassClassificationExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/multiclass/LogRegressionMultiClassClassificationExample.java
@@ -35,7 +35,6 @@ import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalcula
import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerTrainer;
import org.apache.ignite.ml.regressions.logistic.multiclass.LogRegressionMultiClassModel;
import org.apache.ignite.ml.regressions.logistic.multiclass.LogRegressionMultiClassTrainer;
-import org.apache.ignite.thread.IgniteThread;
/**
* Run Logistic Regression multi-class classification trainer ({@link LogRegressionMultiClassModel}) over distributed
@@ -62,115 +61,109 @@ public class LogRegressionMultiClassClassificationExample {
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Ignite grid started.");
- IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
- LogRegressionMultiClassClassificationExample.class.getSimpleName(), () -> {
- IgniteCache<Integer, Vector> dataCache = new TestCache(ignite).getVectors(data);
+ IgniteCache<Integer, Vector> dataCache = new TestCache(ignite).getVectors(data);
- LogRegressionMultiClassTrainer<?> trainer = new LogRegressionMultiClassTrainer<>()
- .withUpdatesStgy(new UpdatesStrategy<>(
- new SimpleGDUpdateCalculator(0.2),
- SimpleGDParameterUpdate::sumLocal,
- SimpleGDParameterUpdate::avg
- ))
- .withAmountOfIterations(100000)
- .withAmountOfLocIterations(10)
- .withBatchSize(100)
- .withSeed(123L);
+ LogRegressionMultiClassTrainer<?> trainer = new LogRegressionMultiClassTrainer<>()
+ .withUpdatesStgy(new UpdatesStrategy<>(
+ new SimpleGDUpdateCalculator(0.2),
+ SimpleGDParameterUpdate::sumLocal,
+ SimpleGDParameterUpdate::avg
+ ))
+ .withAmountOfIterations(100000)
+ .withAmountOfLocIterations(10)
+ .withBatchSize(100)
+ .withSeed(123L);
- LogRegressionMultiClassModel mdl = trainer.fit(
- ignite,
- dataCache,
- (k, v) -> {
- double[] arr = v.asArray();
- return VectorUtils.of(Arrays.copyOfRange(arr, 1, arr.length));
- },
- (k, v) -> v.get(0)
- );
+ LogRegressionMultiClassModel mdl = trainer.fit(
+ ignite,
+ dataCache,
+ (k, v) -> {
+ double[] arr = v.asArray();
+ return VectorUtils.of(Arrays.copyOfRange(arr, 1, arr.length));
+ },
+ (k, v) -> v.get(0)
+ );
- System.out.println(">>> SVM Multi-class model");
- System.out.println(mdl.toString());
+ System.out.println(">>> SVM Multi-class model");
+ System.out.println(mdl.toString());
- MinMaxScalerTrainer<Integer, Vector> normalizationTrainer = new MinMaxScalerTrainer<>();
+ MinMaxScalerTrainer<Integer, Vector> normalizationTrainer = new MinMaxScalerTrainer<>();
- IgniteBiFunction<Integer, Vector, Vector> preprocessor = normalizationTrainer.fit(
- ignite,
- dataCache,
- (k, v) -> {
- double[] arr = v.asArray();
- return VectorUtils.of(Arrays.copyOfRange(arr, 1, arr.length));
- }
- );
-
- LogRegressionMultiClassModel mdlWithNormalization = trainer.fit(
- ignite,
- dataCache,
- preprocessor,
- (k, v) -> v.get(0)
- );
+ IgniteBiFunction<Integer, Vector, Vector> preprocessor = normalizationTrainer.fit(
+ ignite,
+ dataCache,
+ (k, v) -> {
+ double[] arr = v.asArray();
+ return VectorUtils.of(Arrays.copyOfRange(arr, 1, arr.length));
+ }
+ );
- System.out.println(">>> Logistic Regression Multi-class model with minmaxscaling");
- System.out.println(mdlWithNormalization.toString());
+ LogRegressionMultiClassModel mdlWithNormalization = trainer.fit(
+ ignite,
+ dataCache,
+ preprocessor,
+ (k, v) -> v.get(0)
+ );
- System.out.println(">>> ----------------------------------------------------------------");
- System.out.println(">>> | Prediction\t| Prediction with Normalization\t| Ground Truth\t|");
- System.out.println(">>> ----------------------------------------------------------------");
+ System.out.println(">>> Logistic Regression Multi-class model with minmaxscaling");
+ System.out.println(mdlWithNormalization.toString());
- int amountOfErrors = 0;
- int amountOfErrorsWithNormalization = 0;
- int totalAmount = 0;
+ System.out.println(">>> ----------------------------------------------------------------");
+ System.out.println(">>> | Prediction\t| Prediction with Normalization\t| Ground Truth\t|");
+ System.out.println(">>> ----------------------------------------------------------------");
- // Build confusion matrix. See https://en.wikipedia.org/wiki/Confusion_matrix
- int[][] confusionMtx = {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}};
- int[][] confusionMtxWithNormalization = {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}};
+ int amountOfErrors = 0;
+ int amountOfErrorsWithNormalization = 0;
+ int totalAmount = 0;
- try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) {
- for (Cache.Entry<Integer, Vector> observation : observations) {
- double[] val = observation.getValue().asArray();
- double[] inputs = Arrays.copyOfRange(val, 1, val.length);
- double groundTruth = val[0];
+ // Build confusion matrix. See https://en.wikipedia.org/wiki/Confusion_matrix
+ int[][] confusionMtx = {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}};
+ int[][] confusionMtxWithNormalization = {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}};
- double prediction = mdl.apply(new DenseVector(inputs));
- double predictionWithNormalization = mdlWithNormalization.apply(new DenseVector(inputs));
+ try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) {
+ for (Cache.Entry<Integer, Vector> observation : observations) {
+ double[] val = observation.getValue().asArray();
+ double[] inputs = Arrays.copyOfRange(val, 1, val.length);
+ double groundTruth = val[0];
- totalAmount++;
+ double prediction = mdl.apply(new DenseVector(inputs));
+ double predictionWithNormalization = mdlWithNormalization.apply(new DenseVector(inputs));
- // Collect data for model
- if(groundTruth != prediction)
- amountOfErrors++;
+ totalAmount++;
- int idx1 = (int)prediction == 1 ? 0 : ((int)prediction == 3 ? 1 : 2);
- int idx2 = (int)groundTruth == 1 ? 0 : ((int)groundTruth == 3 ? 1 : 2);
+ // Collect data for model
+ if(groundTruth != prediction)
+ amountOfErrors++;
- confusionMtx[idx1][idx2]++;
+ int idx1 = (int)prediction == 1 ? 0 : ((int)prediction == 3 ? 1 : 2);
+ int idx2 = (int)groundTruth == 1 ? 0 : ((int)groundTruth == 3 ? 1 : 2);
- // Collect data for model with minmaxscaling
- if(groundTruth != predictionWithNormalization)
- amountOfErrorsWithNormalization++;
+ confusionMtx[idx1][idx2]++;
- idx1 = (int)predictionWithNormalization == 1 ? 0 : ((int)predictionWithNormalization == 3 ? 1 : 2);
- idx2 = (int)groundTruth == 1 ? 0 : ((int)groundTruth == 3 ? 1 : 2);
+ // Collect data for model with minmaxscaling
+ if(groundTruth != predictionWithNormalization)
+ amountOfErrorsWithNormalization++;
- confusionMtxWithNormalization[idx1][idx2]++;
+ idx1 = (int)predictionWithNormalization == 1 ? 0 : ((int)predictionWithNormalization == 3 ? 1 : 2);
+ idx2 = (int)groundTruth == 1 ? 0 : ((int)groundTruth == 3 ? 1 : 2);
- System.out.printf(">>> | %.4f\t\t| %.4f\t\t\t\t\t\t| %.4f\t\t|\n", prediction, predictionWithNormalization, groundTruth);
- }
- System.out.println(">>> ----------------------------------------------------------------");
- System.out.println("\n>>> -----------------Logistic Regression model-------------");
- System.out.println("\n>>> Absolute amount of errors " + amountOfErrors);
- System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double)totalAmount));
- System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtx));
+ confusionMtxWithNormalization[idx1][idx2]++;
- System.out.println("\n>>> -----------------Logistic Regression model with Normalization-------------");
- System.out.println("\n>>> Absolute amount of errors " + amountOfErrorsWithNormalization);
- System.out.println("\n>>> Accuracy " + (1 - amountOfErrorsWithNormalization / (double)totalAmount));
- System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtxWithNormalization));
-
- System.out.println(">>> Logistic Regression Multi-class classification model over cached dataset usage example completed.");
+ System.out.printf(">>> | %.4f\t\t| %.4f\t\t\t\t\t\t| %.4f\t\t|\n", prediction, predictionWithNormalization, groundTruth);
}
- });
+ System.out.println(">>> ----------------------------------------------------------------");
+ System.out.println("\n>>> -----------------Logistic Regression model-------------");
+ System.out.println("\n>>> Absolute amount of errors " + amountOfErrors);
+ System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double)totalAmount));
+ System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtx));
+
+ System.out.println("\n>>> -----------------Logistic Regression model with Normalization-------------");
+ System.out.println("\n>>> Absolute amount of errors " + amountOfErrorsWithNormalization);
+ System.out.println("\n>>> Accuracy " + (1 - amountOfErrorsWithNormalization / (double)totalAmount));
+ System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtxWithNormalization));
- igniteThread.start();
- igniteThread.join();
+ System.out.println(">>> Logistic Regression Multi-class classification model over cached dataset usage example completed.");
+ }
}
}
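The index arithmetic in the loop above is easy to misread: class labels 1 and 3 are folded onto rows/columns 0 and 1 of a 3x3 confusion matrix, and every other label onto 2. Below is a minimal standalone sketch of that bookkeeping in plain Java; the label and prediction arrays are hypothetical and stand in for the cache scan performed by the example.

import java.util.Arrays;

public class ConfusionMatrixSketch {
    public static void main(String[] args) {
        double[] groundTruth = {1, 3, 7, 1, 3};   // hypothetical labels
        double[] predictions = {1, 3, 3, 1, 7};   // hypothetical model output

        int[][] confusionMtx = new int[3][3];
        int amountOfErrors = 0;

        for (int i = 0; i < groundTruth.length; i++) {
            if (groundTruth[i] != predictions[i])
                amountOfErrors++;

            // Same mapping as the example: 1 -> 0, 3 -> 1, everything else -> 2.
            int idx1 = toIdx(predictions[i]);
            int idx2 = toIdx(groundTruth[i]);

            confusionMtx[idx1][idx2]++;
        }

        System.out.println("Accuracy: " + (1 - amountOfErrors / (double)groundTruth.length));
        System.out.println("Confusion matrix: " + Arrays.deepToString(confusionMtx));
    }

    /** Maps class label 1 to index 0, 3 to index 1 and any other label to index 2. */
    private static int toIdx(double lb) {
        return (int)lb == 1 ? 0 : ((int)lb == 3 ? 1 : 2);
    }
}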
http://git-wip-us.apache.org/repos/asf/ignite/blob/609266fe/examples/src/main/java/org/apache/ignite/examples/ml/selection/cv/CrossValidationExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/selection/cv/CrossValidationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/selection/cv/CrossValidationExample.java
index eb4c8f3..25ce156 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/selection/cv/CrossValidationExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/selection/cv/CrossValidationExample.java
@@ -24,13 +24,11 @@ import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
import org.apache.ignite.configuration.CacheConfiguration;
-import org.apache.ignite.examples.ml.tree.DecisionTreeClassificationTrainerExample;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.selection.cv.CrossValidation;
import org.apache.ignite.ml.selection.scoring.metric.Accuracy;
import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
import org.apache.ignite.ml.tree.DecisionTreeNode;
-import org.apache.ignite.thread.IgniteThread;
/**
* Run <a href="https://en.wikipedia.org/wiki/Decision_tree">decision tree</a> classification with
@@ -54,46 +52,38 @@ public class CrossValidationExample {
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Ignite grid started.");
- IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
- DecisionTreeClassificationTrainerExample.class.getSimpleName(), () -> {
+ // Create cache with training data.
+ CacheConfiguration<Integer, LabeledPoint> trainingSetCfg = new CacheConfiguration<>();
+ trainingSetCfg.setName("TRAINING_SET");
+ trainingSetCfg.setAffinity(new RendezvousAffinityFunction(false, 10));
- // Create cache with training data.
- CacheConfiguration<Integer, LabeledPoint> trainingSetCfg = new CacheConfiguration<>();
- trainingSetCfg.setName("TRAINING_SET");
- trainingSetCfg.setAffinity(new RendezvousAffinityFunction(false, 10));
+ IgniteCache<Integer, LabeledPoint> trainingSet = ignite.createCache(trainingSetCfg);
- IgniteCache<Integer, LabeledPoint> trainingSet = ignite.createCache(trainingSetCfg);
+ Random rnd = new Random(0);
- Random rnd = new Random(0);
+ // Fill training data.
+ for (int i = 0; i < 1000; i++)
+ trainingSet.put(i, generatePoint(rnd));
- // Fill training data.
- for (int i = 0; i < 1000; i++)
- trainingSet.put(i, generatePoint(rnd));
+ // Create classification trainer.
+ DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(4, 0);
- // Create classification trainer.
- DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(4, 0);
+ CrossValidation<DecisionTreeNode, Double, Integer, LabeledPoint> scoreCalculator
+ = new CrossValidation<>();
- CrossValidation<DecisionTreeNode, Double, Integer, LabeledPoint> scoreCalculator
- = new CrossValidation<>();
+ double[] scores = scoreCalculator.score(
+ trainer,
+ new Accuracy<>(),
+ ignite,
+ trainingSet,
+ (k, v) -> VectorUtils.of(v.x, v.y),
+ (k, v) -> v.lb,
+ 4
+ );
- double[] scores = scoreCalculator.score(
- trainer,
- new Accuracy<>(),
- ignite,
- trainingSet,
- (k, v) -> VectorUtils.of(v.x, v.y),
- (k, v) -> v.lb,
- 4
- );
+ System.out.println(">>> Accuracy: " + Arrays.toString(scores));
- System.out.println(">>> Accuracy: " + Arrays.toString(scores));
-
- System.out.println(">>> Cross validation score calculator example completed.");
- });
-
- igniteThread.start();
-
- igniteThread.join();
+ System.out.println(">>> Cross validation score calculator example completed.");
}
}
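With the IgniteThread wrapper gone, the whole k-fold scoring pipeline now runs directly on the caller's thread inside the try-with-resources block. Below is a condensed sketch of that post-patch structure; it mirrors the score(...) call from the example but assumes a hypothetical cache of {label, x, y} rows instead of the example's own LabeledPoint class and generatePoint helper.

import java.util.Arrays;
import java.util.Random;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.selection.cv.CrossValidation;
import org.apache.ignite.ml.selection.scoring.metric.Accuracy;
import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
import org.apache.ignite.ml.tree.DecisionTreeNode;

public class CrossValidationSketch {
    public static void main(String[] args) {
        try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
            IgniteCache<Integer, double[]> trainingSet = ignite.createCache("TRAINING_SET");

            // Hypothetical training rows of the form {label, x, y}.
            Random rnd = new Random(0);
            for (int i = 0; i < 1000; i++) {
                double x = rnd.nextDouble() - 0.5;
                double y = rnd.nextDouble() - 0.5;
                double lb = x * y > 0 ? 1 : 0;
                trainingSet.put(i, new double[] {lb, x, y});
            }

            DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(4, 0);

            CrossValidation<DecisionTreeNode, Double, Integer, double[]> scoreCalculator
                = new CrossValidation<>();

            // Same score(...) call shape as in the example: 4 folds, accuracy metric.
            double[] scores = scoreCalculator.score(
                trainer,
                new Accuracy<>(),
                ignite,
                trainingSet,
                (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
                (k, v) -> v[0],
                4
            );

            System.out.println(">>> Accuracy per fold: " + Arrays.toString(scores));
        }
    }
}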
http://git-wip-us.apache.org/repos/asf/ignite/blob/609266fe/examples/src/main/java/org/apache/ignite/examples/ml/selection/split/TrainTestDatasetSplitterExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/selection/split/TrainTestDatasetSplitterExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/selection/split/TrainTestDatasetSplitterExample.java
index fa1c2ca..8b104f5 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/selection/split/TrainTestDatasetSplitterExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/selection/split/TrainTestDatasetSplitterExample.java
@@ -31,7 +31,6 @@ import org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainer;
import org.apache.ignite.ml.regressions.linear.LinearRegressionModel;
import org.apache.ignite.ml.selection.split.TrainTestDatasetSplitter;
import org.apache.ignite.ml.selection.split.TrainTestSplit;
-import org.apache.ignite.thread.IgniteThread;
/**
* Run linear regression model over dataset split on train and test subsets ({@link TrainTestDatasetSplitter}).
@@ -113,55 +112,47 @@ public class TrainTestDatasetSplitterExample {
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Ignite grid started.");
- IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
- TrainTestDatasetSplitterExample.class.getSimpleName(), () -> {
- IgniteCache<Integer, double[]> dataCache = new TestCache(ignite).fillCacheWith(data);
+ IgniteCache<Integer, double[]> dataCache = new TestCache(ignite).fillCacheWith(data);
- System.out.println(">>> Create new linear regression trainer object.");
- LinearRegressionLSQRTrainer trainer = new LinearRegressionLSQRTrainer();
+ System.out.println(">>> Create new linear regression trainer object.");
+ LinearRegressionLSQRTrainer trainer = new LinearRegressionLSQRTrainer();
- System.out.println(">>> Create new training dataset splitter object.");
- TrainTestSplit<Integer, double[]> split = new TrainTestDatasetSplitter<Integer, double[]>()
- .split(0.75);
+ System.out.println(">>> Create new training dataset splitter object.");
+ TrainTestSplit<Integer, double[]> split = new TrainTestDatasetSplitter<Integer, double[]>()
+ .split(0.75);
- System.out.println(">>> Perform the training to get the model.");
- LinearRegressionModel mdl = trainer.fit(
- ignite,
- dataCache,
- split.getTrainFilter(),
- (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
- (k, v) -> v[0]
- );
+ System.out.println(">>> Perform the training to get the model.");
+ LinearRegressionModel mdl = trainer.fit(
+ ignite,
+ dataCache,
+ split.getTrainFilter(),
+ (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
+ (k, v) -> v[0]
+ );
- System.out.println(">>> Linear regression model: " + mdl);
+ System.out.println(">>> Linear regression model: " + mdl);
- System.out.println(">>> ---------------------------------");
- System.out.println(">>> | Prediction\t| Ground Truth\t|");
- System.out.println(">>> ---------------------------------");
+ System.out.println(">>> ---------------------------------");
+ System.out.println(">>> | Prediction\t| Ground Truth\t|");
+ System.out.println(">>> ---------------------------------");
- ScanQuery<Integer, double[]> qry = new ScanQuery<>();
- qry.setFilter(split.getTestFilter());
+ ScanQuery<Integer, double[]> qry = new ScanQuery<>();
+ qry.setFilter(split.getTestFilter());
- try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(qry)) {
- for (Cache.Entry<Integer, double[]> observation : observations) {
- double[] val = observation.getValue();
- double[] inputs = Arrays.copyOfRange(val, 1, val.length);
- double groundTruth = val[0];
+ try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(qry)) {
+ for (Cache.Entry<Integer, double[]> observation : observations) {
+ double[] val = observation.getValue();
+ double[] inputs = Arrays.copyOfRange(val, 1, val.length);
+ double groundTruth = val[0];
- double prediction = mdl.apply(new DenseVector(inputs));
+ double prediction = mdl.apply(new DenseVector(inputs));
- System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
- }
+ System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
}
+ }
- System.out.println(">>> ---------------------------------");
-
- System.out.println(">>> Linear regression model over cache based dataset usage example completed.");
- });
-
- igniteThread.start();
-
- igniteThread.join();
+ System.out.println(">>> ---------------------------------");
+ System.out.println(">>> Linear regression model over cache based dataset usage example completed.");
}
}
}
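The core of this example is that a single TrainTestDatasetSplitter yields two filters over the same cache: the train filter is handed to fit(...) and the test filter to the evaluation ScanQuery. The sketch below isolates just that pattern; the SplitSketch class and its assumption that the cache holds {label, feature1, feature2, ...} rows are illustrative, while the API calls themselves are the ones used in the example.

import java.util.Arrays;
import javax.cache.Cache;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.cache.query.QueryCursor;
import org.apache.ignite.cache.query.ScanQuery;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
import org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainer;
import org.apache.ignite.ml.regressions.linear.LinearRegressionModel;
import org.apache.ignite.ml.selection.split.TrainTestDatasetSplitter;
import org.apache.ignite.ml.selection.split.TrainTestSplit;

/** Sketch: train on ~75% of the cache entries, score on the remaining ~25%. */
final class SplitSketch {
    static void trainOnSplitAndEvaluate(Ignite ignite, IgniteCache<Integer, double[]> dataCache) {
        TrainTestSplit<Integer, double[]> split =
            new TrainTestDatasetSplitter<Integer, double[]>().split(0.75);

        // The train filter restricts fit(...) to the training part of the cache.
        LinearRegressionModel mdl = new LinearRegressionLSQRTrainer().fit(
            ignite,
            dataCache,
            split.getTrainFilter(),
            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
            (k, v) -> v[0]
        );

        // The test filter restricts the evaluation scan to the held-out part.
        ScanQuery<Integer, double[]> qry = new ScanQuery<>();
        qry.setFilter(split.getTestFilter());

        try (QueryCursor<Cache.Entry<Integer, double[]>> cursor = dataCache.query(qry)) {
            for (Cache.Entry<Integer, double[]> e : cursor) {
                double[] val = e.getValue();
                double prediction = mdl.apply(new DenseVector(Arrays.copyOfRange(val, 1, val.length)));
                System.out.printf(">>> prediction: %.4f, ground truth: %.4f%n", prediction, val[0]);
            }
        }
    }
}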
http://git-wip-us.apache.org/repos/asf/ignite/blob/609266fe/examples/src/main/java/org/apache/ignite/examples/ml/svm/binary/SVMBinaryClassificationExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/svm/binary/SVMBinaryClassificationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/svm/binary/SVMBinaryClassificationExample.java
index f71db2d..c219441 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/svm/binary/SVMBinaryClassificationExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/svm/binary/SVMBinaryClassificationExample.java
@@ -29,7 +29,6 @@ import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
import org.apache.ignite.ml.svm.SVMLinearBinaryClassificationModel;
import org.apache.ignite.ml.svm.SVMLinearBinaryClassificationTrainer;
-import org.apache.ignite.thread.IgniteThread;
/**
* Run SVM binary-class classification model ({@link SVMLinearBinaryClassificationModel}) over distributed dataset.
@@ -54,64 +53,58 @@ public class SVMBinaryClassificationExample {
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Ignite grid started.");
- IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
- SVMBinaryClassificationExample.class.getSimpleName(), () -> {
- IgniteCache<Integer, double[]> dataCache = new TestCache(ignite).fillCacheWith(data);
+ IgniteCache<Integer, double[]> dataCache = new TestCache(ignite).fillCacheWith(data);
- SVMLinearBinaryClassificationTrainer trainer = new SVMLinearBinaryClassificationTrainer();
+ SVMLinearBinaryClassificationTrainer trainer = new SVMLinearBinaryClassificationTrainer();
- SVMLinearBinaryClassificationModel mdl = trainer.fit(
- ignite,
- dataCache,
- (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
- (k, v) -> v[0]
- );
+ SVMLinearBinaryClassificationModel mdl = trainer.fit(
+ ignite,
+ dataCache,
+ (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
+ (k, v) -> v[0]
+ );
- System.out.println(">>> SVM model " + mdl);
+ System.out.println(">>> SVM model " + mdl);
- System.out.println(">>> ---------------------------------");
- System.out.println(">>> | Prediction\t| Ground Truth\t|");
- System.out.println(">>> ---------------------------------");
-
- int amountOfErrors = 0;
- int totalAmount = 0;
+ System.out.println(">>> ---------------------------------");
+ System.out.println(">>> | Prediction\t| Ground Truth\t|");
+ System.out.println(">>> ---------------------------------");
- // Build confusion matrix. See https://en.wikipedia.org/wiki/Confusion_matrix
- int[][] confusionMtx = {{0, 0}, {0, 0}};
+ int amountOfErrors = 0;
+ int totalAmount = 0;
- try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) {
- for (Cache.Entry<Integer, double[]> observation : observations) {
- double[] val = observation.getValue();
- double[] inputs = Arrays.copyOfRange(val, 1, val.length);
- double groundTruth = val[0];
+ // Build confusion matrix. See https://en.wikipedia.org/wiki/Confusion_matrix
+ int[][] confusionMtx = {{0, 0}, {0, 0}};
- double prediction = mdl.apply(new DenseVector(inputs));
+ try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) {
+ for (Cache.Entry<Integer, double[]> observation : observations) {
+ double[] val = observation.getValue();
+ double[] inputs = Arrays.copyOfRange(val, 1, val.length);
+ double groundTruth = val[0];
- totalAmount++;
- if(groundTruth != prediction)
- amountOfErrors++;
+ double prediction = mdl.apply(new DenseVector(inputs));
- int idx1 = prediction == 0.0 ? 0 : 1;
- int idx2 = groundTruth == 0.0 ? 0 : 1;
+ totalAmount++;
+ if(groundTruth != prediction)
+ amountOfErrors++;
- confusionMtx[idx1][idx2]++;
+ int idx1 = prediction == 0.0 ? 0 : 1;
+ int idx2 = groundTruth == 0.0 ? 0 : 1;
- System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
- }
+ confusionMtx[idx1][idx2]++;
- System.out.println(">>> ---------------------------------");
-
- System.out.println("\n>>> Absolute amount of errors " + amountOfErrors);
- System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double)totalAmount));
+ System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
}
- System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtx));
+ System.out.println(">>> ---------------------------------");
+
+ System.out.println("\n>>> Absolute amount of errors " + amountOfErrors);
+ System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double)totalAmount));
+ }
- System.out.println(">>> Linear regression model over cache based dataset usage example completed.");
- });
+ System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtx));
- igniteThread.start();
- igniteThread.join();
+ System.out.println(">>> Linear regression model over cache based dataset usage example completed.");
}
}
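For reference, the shape of the change applied to every example in this commit is the same: drop the IgniteThread wrapper and run the body directly inside try-with-resources. A minimal self-contained sketch of the before/after is below; IgniteThreadRemovalSketch and trainAndEvaluate(...) are hypothetical placeholders, while the IgniteThread constructor call matches the one removed above.

import org.apache.ignite.Ignite;
import org.apache.ignite.Ignition;
import org.apache.ignite.thread.IgniteThread;

public class IgniteThreadRemovalSketch {
    public static void main(String[] args) throws InterruptedException {
        try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
            // Before this commit: the example body was wrapped into a dedicated IgniteThread.
            IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
                IgniteThreadRemovalSketch.class.getSimpleName(), () -> trainAndEvaluate(ignite));
            igniteThread.start();
            igniteThread.join();

            // After this commit: the same body simply runs on the caller's thread.
            trainAndEvaluate(ignite);
        }
    }

    /** Hypothetical placeholder for a concrete example's training and evaluation code. */
    private static void trainAndEvaluate(Ignite ignite) {
        System.out.println(">>> training and evaluation would run here, on thread "
            + Thread.currentThread().getName());
    }
}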