Posted to commits@ignite.apache.org by sb...@apache.org on 2018/10/30 06:08:58 UTC
[07/28] ignite git commit: IGNITE-9910: [ML] Move the static copy-pasted datasets from examples to special Util class
IGNITE-9910: [ML] Move the static copy-pasted datasets from examples to special Util class
this closes #5028
Project: http://git-wip-us.apache.org/repos/asf/ignite/repo
Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/370cd3e1
Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/370cd3e1
Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/370cd3e1
Branch: refs/heads/ignite-627
Commit: 370cd3e1d60237e4a238c1c789cddbd1164e57e6
Parents: c7449f6
Author: zaleslaw <za...@gmail.com>
Authored: Fri Oct 26 16:06:42 2018 +0300
Committer: Yury Babak <yb...@gridgain.com>
Committed: Fri Oct 26 16:06:42 2018 +0300
----------------------------------------------------------------------
.../clustering/KMeansClusterizationExample.java | 150 +----
.../ml/knn/ANNClassificationExample.java | 2 +-
.../ml/knn/KNNClassificationExample.java | 183 +------
.../examples/ml/knn/KNNRegressionExample.java | 221 +-------
.../examples/ml/nn/MLPTrainerExample.java | 2 +-
.../LinearRegressionLSQRTrainerExample.java | 86 +--
...ssionLSQRTrainerWithMinMaxScalerExample.java | 85 +--
.../LinearRegressionSGDTrainerExample.java | 86 +--
.../LogisticRegressionSGDTrainerExample.java | 132 +----
...gressionMultiClassClassificationExample.java | 158 +-----
.../ml/selection/cv/CrossValidationExample.java | 2 +-
.../split/TrainTestDatasetSplitterExample.java | 90 +--
.../binary/SVMBinaryClassificationExample.java | 132 +----
.../SVMMultiClassClassificationExample.java | 189 ++-----
...ecisionTreeClassificationTrainerExample.java | 2 +-
.../DecisionTreeRegressionTrainerExample.java | 2 +-
.../GDBOnTreesClassificationTrainerExample.java | 2 +-
.../RandomForestClassificationExample.java | 216 +-------
.../RandomForestRegressionExample.java | 543 +------------------
.../examples/ml/util/MLSandboxDatasets.java | 87 +++
.../ignite/examples/ml/util/SandboxMLCache.java | 144 +++++
.../ignite/examples/ml/util/TestCache.java | 77 ---
.../datasets/boston_housing_dataset.txt | 505 +++++++++++++++++
.../resources/datasets/cleared_machines.csv | 209 +++++++
.../resources/datasets/cleared_machines.txt | 209 -------
.../resources/datasets/glass_identification.csv | 116 ++++
.../main/resources/datasets/mortalitydata.csv | 53 ++
.../resources/datasets/two_classed_iris.csv | 100 ++++
examples/src/main/resources/datasets/wine.txt | 178 ++++++
.../ml/knn/regression/KNNRegressionModel.java | 2 +
.../math/primitives/vector/AbstractVector.java | 9 +
.../ml/math/primitives/vector/Vector.java | 9 +
.../vector/impl/DelegatingVector.java | 5 +
33 files changed, 1643 insertions(+), 2343 deletions(-)
----------------------------------------------------------------------
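The refactoring repeated in every hunk below is the same: the inlined double[][] arrays and the old TestCache helper are replaced by SandboxMLCache, which loads one of the bundled MLSandboxDatasets CSV files into an IgniteCache<Integer, Vector>, and the feature/label extractors switch from array slicing to Vector.copyOfRange / Vector.get. A minimal sketch of the resulting pattern, assembled from the hunks below (the class name SandboxDatasetUsageSketch and the particular dataset and trainer are illustrative only, not part of the commit):

    import java.io.FileNotFoundException;
    import org.apache.ignite.Ignite;
    import org.apache.ignite.IgniteCache;
    import org.apache.ignite.Ignition;
    import org.apache.ignite.examples.ml.util.MLSandboxDatasets;
    import org.apache.ignite.examples.ml.util.SandboxMLCache;
    import org.apache.ignite.ml.clustering.kmeans.KMeansModel;
    import org.apache.ignite.ml.clustering.kmeans.KMeansTrainer;
    import org.apache.ignite.ml.math.primitives.vector.Vector;

    public class SandboxDatasetUsageSketch {
        public static void main(String[] args) throws FileNotFoundException {
            try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
                // Load a bundled CSV dataset into a cache of Vector rows (was: double[] rows from TestCache).
                IgniteCache<Integer, Vector> dataCache = new SandboxMLCache(ignite)
                    .fillCacheWith(MLSandboxDatasets.TWO_CLASSED_IRIS);

                // Column 0 is the (erased) class label; the remaining columns are features.
                KMeansModel mdl = new KMeansTrainer().withSeed(7867L).fit(
                    ignite,
                    dataCache,
                    (k, v) -> v.copyOfRange(1, v.size()),  // feature extractor
                    (k, v) -> v.get(0)                     // label extractor
                );

                System.out.println(">>> Trained model: " + mdl);
            }
        }
    }

Keeping the datasets as files under resources/datasets and loading them through one utility is what lets the examples drop the hundreds of lines of copy-pasted literals visible in the deletions below.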
http://git-wip-us.apache.org/repos/asf/ignite/blob/370cd3e1/examples/src/main/java/org/apache/ignite/examples/ml/clustering/KMeansClusterizationExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/clustering/KMeansClusterizationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/clustering/KMeansClusterizationExample.java
index 567775b..3c8eeaa 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/clustering/KMeansClusterizationExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/clustering/KMeansClusterizationExample.java
@@ -17,19 +17,19 @@
package org.apache.ignite.examples.ml.clustering;
-import java.util.Arrays;
+import java.io.FileNotFoundException;
import javax.cache.Cache;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
import org.apache.ignite.cache.query.QueryCursor;
import org.apache.ignite.cache.query.ScanQuery;
-import org.apache.ignite.examples.ml.util.TestCache;
+import org.apache.ignite.examples.ml.util.MLSandboxDatasets;
+import org.apache.ignite.examples.ml.util.SandboxMLCache;
import org.apache.ignite.ml.clustering.kmeans.KMeansModel;
import org.apache.ignite.ml.clustering.kmeans.KMeansTrainer;
import org.apache.ignite.ml.math.Tracer;
-import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
-import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
/**
* Run KMeans clustering algorithm ({@link KMeansTrainer}) over distributed dataset.
@@ -47,14 +47,15 @@ import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
*/
public class KMeansClusterizationExample {
/** Run example. */
- public static void main(String[] args) throws InterruptedException {
+ public static void main(String[] args) throws FileNotFoundException {
System.out.println();
System.out.println(">>> KMeans clustering algorithm over cached dataset usage example started.");
// Start ignite grid.
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Ignite grid started.");
- IgniteCache<Integer, double[]> dataCache = new TestCache(ignite).fillCacheWith(data);
+ IgniteCache<Integer, Vector> dataCache = new SandboxMLCache(ignite)
+ .fillCacheWith(MLSandboxDatasets.TWO_CLASSED_IRIS);
KMeansTrainer trainer = new KMeansTrainer()
.withSeed(7867L);
@@ -62,8 +63,8 @@ public class KMeansClusterizationExample {
KMeansModel mdl = trainer.fit(
ignite,
dataCache,
- (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
- (k, v) -> v[0]
+ (k, v) -> v.copyOfRange(1, v.size()),
+ (k, v) -> v.get(0)
);
System.out.println(">>> KMeans centroids");
@@ -71,139 +72,24 @@ public class KMeansClusterizationExample {
Tracer.showAscii(mdl.getCenters()[1]);
System.out.println(">>>");
- System.out.println(">>> -----------------------------------");
- System.out.println(">>> | Predicted cluster\t| Real Label\t|");
- System.out.println(">>> -----------------------------------");
+ System.out.println(">>> --------------------------------------------");
+ System.out.println(">>> | Predicted cluster\t| Erased class label\t|");
+ System.out.println(">>> --------------------------------------------");
- int amountOfErrors = 0;
- int totalAmount = 0;
+ try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) {
+ for (Cache.Entry<Integer, Vector> observation : observations) {
+ Vector val = observation.getValue();
+ Vector inputs = val.copyOfRange(1, val.size());
+ double groundTruth = val.get(0);
- try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) {
- for (Cache.Entry<Integer, double[]> observation : observations) {
- double[] val = observation.getValue();
- double[] inputs = Arrays.copyOfRange(val, 1, val.length);
- double groundTruth = val[0];
-
- double prediction = mdl.apply(new DenseVector(inputs));
-
- totalAmount++;
- if (groundTruth != prediction)
- amountOfErrors++;
+ double prediction = mdl.apply(inputs);
System.out.printf(">>> | %.4f\t\t\t| %.4f\t\t|\n", prediction, groundTruth);
}
System.out.println(">>> ---------------------------------");
-
- System.out.println("\n>>> Absolute amount of errors " + amountOfErrors);
- System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double)totalAmount));
-
System.out.println(">>> KMeans clustering algorithm over cached dataset usage example completed.");
}
}
}
-
- /** The Iris dataset. */
- private static final double[][] data = {
- {0, 5.1, 3.5, 1.4, 0.2},
- {0, 4.9, 3, 1.4, 0.2},
- {0, 4.7, 3.2, 1.3, 0.2},
- {0, 4.6, 3.1, 1.5, 0.2},
- {0, 5, 3.6, 1.4, 0.2},
- {0, 5.4, 3.9, 1.7, 0.4},
- {0, 4.6, 3.4, 1.4, 0.3},
- {0, 5, 3.4, 1.5, 0.2},
- {0, 4.4, 2.9, 1.4, 0.2},
- {0, 4.9, 3.1, 1.5, 0.1},
- {0, 5.4, 3.7, 1.5, 0.2},
- {0, 4.8, 3.4, 1.6, 0.2},
- {0, 4.8, 3, 1.4, 0.1},
- {0, 4.3, 3, 1.1, 0.1},
- {0, 5.8, 4, 1.2, 0.2},
- {0, 5.7, 4.4, 1.5, 0.4},
- {0, 5.4, 3.9, 1.3, 0.4},
- {0, 5.1, 3.5, 1.4, 0.3},
- {0, 5.7, 3.8, 1.7, 0.3},
- {0, 5.1, 3.8, 1.5, 0.3},
- {0, 5.4, 3.4, 1.7, 0.2},
- {0, 5.1, 3.7, 1.5, 0.4},
- {0, 4.6, 3.6, 1, 0.2},
- {0, 5.1, 3.3, 1.7, 0.5},
- {0, 4.8, 3.4, 1.9, 0.2},
- {0, 5, 3, 1.6, 0.2},
- {0, 5, 3.4, 1.6, 0.4},
- {0, 5.2, 3.5, 1.5, 0.2},
- {0, 5.2, 3.4, 1.4, 0.2},
- {0, 4.7, 3.2, 1.6, 0.2},
- {0, 4.8, 3.1, 1.6, 0.2},
- {0, 5.4, 3.4, 1.5, 0.4},
- {0, 5.2, 4.1, 1.5, 0.1},
- {0, 5.5, 4.2, 1.4, 0.2},
- {0, 4.9, 3.1, 1.5, 0.1},
- {0, 5, 3.2, 1.2, 0.2},
- {0, 5.5, 3.5, 1.3, 0.2},
- {0, 4.9, 3.1, 1.5, 0.1},
- {0, 4.4, 3, 1.3, 0.2},
- {0, 5.1, 3.4, 1.5, 0.2},
- {0, 5, 3.5, 1.3, 0.3},
- {0, 4.5, 2.3, 1.3, 0.3},
- {0, 4.4, 3.2, 1.3, 0.2},
- {0, 5, 3.5, 1.6, 0.6},
- {0, 5.1, 3.8, 1.9, 0.4},
- {0, 4.8, 3, 1.4, 0.3},
- {0, 5.1, 3.8, 1.6, 0.2},
- {0, 4.6, 3.2, 1.4, 0.2},
- {0, 5.3, 3.7, 1.5, 0.2},
- {0, 5, 3.3, 1.4, 0.2},
- {1, 7, 3.2, 4.7, 1.4},
- {1, 6.4, 3.2, 4.5, 1.5},
- {1, 6.9, 3.1, 4.9, 1.5},
- {1, 5.5, 2.3, 4, 1.3},
- {1, 6.5, 2.8, 4.6, 1.5},
- {1, 5.7, 2.8, 4.5, 1.3},
- {1, 6.3, 3.3, 4.7, 1.6},
- {1, 4.9, 2.4, 3.3, 1},
- {1, 6.6, 2.9, 4.6, 1.3},
- {1, 5.2, 2.7, 3.9, 1.4},
- {1, 5, 2, 3.5, 1},
- {1, 5.9, 3, 4.2, 1.5},
- {1, 6, 2.2, 4, 1},
- {1, 6.1, 2.9, 4.7, 1.4},
- {1, 5.6, 2.9, 3.6, 1.3},
- {1, 6.7, 3.1, 4.4, 1.4},
- {1, 5.6, 3, 4.5, 1.5},
- {1, 5.8, 2.7, 4.1, 1},
- {1, 6.2, 2.2, 4.5, 1.5},
- {1, 5.6, 2.5, 3.9, 1.1},
- {1, 5.9, 3.2, 4.8, 1.8},
- {1, 6.1, 2.8, 4, 1.3},
- {1, 6.3, 2.5, 4.9, 1.5},
- {1, 6.1, 2.8, 4.7, 1.2},
- {1, 6.4, 2.9, 4.3, 1.3},
- {1, 6.6, 3, 4.4, 1.4},
- {1, 6.8, 2.8, 4.8, 1.4},
- {1, 6.7, 3, 5, 1.7},
- {1, 6, 2.9, 4.5, 1.5},
- {1, 5.7, 2.6, 3.5, 1},
- {1, 5.5, 2.4, 3.8, 1.1},
- {1, 5.5, 2.4, 3.7, 1},
- {1, 5.8, 2.7, 3.9, 1.2},
- {1, 6, 2.7, 5.1, 1.6},
- {1, 5.4, 3, 4.5, 1.5},
- {1, 6, 3.4, 4.5, 1.6},
- {1, 6.7, 3.1, 4.7, 1.5},
- {1, 6.3, 2.3, 4.4, 1.3},
- {1, 5.6, 3, 4.1, 1.3},
- {1, 5.5, 2.5, 4, 1.3},
- {1, 5.5, 2.6, 4.4, 1.2},
- {1, 6.1, 3, 4.6, 1.4},
- {1, 5.8, 2.6, 4, 1.2},
- {1, 5, 2.3, 3.3, 1},
- {1, 5.6, 2.7, 4.2, 1.3},
- {1, 5.7, 3, 4.2, 1.2},
- {1, 5.7, 2.9, 4.2, 1.3},
- {1, 6.2, 2.9, 4.3, 1.3},
- {1, 5.1, 2.5, 3, 1.1},
- {1, 5.7, 2.8, 4.1, 1.3},
- };
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/370cd3e1/examples/src/main/java/org/apache/ignite/examples/ml/knn/ANNClassificationExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/knn/ANNClassificationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/knn/ANNClassificationExample.java
index c9490fc..419eccb 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/knn/ANNClassificationExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/knn/ANNClassificationExample.java
@@ -51,7 +51,7 @@ import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
*/
public class ANNClassificationExample {
/** Run example. */
- public static void main(String[] args) throws InterruptedException {
+ public static void main(String[] args) {
System.out.println();
System.out.println(">>> ANN multi-class classification algorithm over cached dataset usage example started.");
// Start ignite grid.
http://git-wip-us.apache.org/repos/asf/ignite/blob/370cd3e1/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNClassificationExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNClassificationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNClassificationExample.java
index 5cbb2ad..31ecdac 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNClassificationExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNClassificationExample.java
@@ -17,20 +17,20 @@
package org.apache.ignite.examples.ml.knn;
-import java.util.Arrays;
+import java.io.FileNotFoundException;
import javax.cache.Cache;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
import org.apache.ignite.cache.query.QueryCursor;
import org.apache.ignite.cache.query.ScanQuery;
-import org.apache.ignite.examples.ml.util.TestCache;
+import org.apache.ignite.examples.ml.util.MLSandboxDatasets;
+import org.apache.ignite.examples.ml.util.SandboxMLCache;
import org.apache.ignite.ml.knn.NNClassificationModel;
import org.apache.ignite.ml.knn.classification.KNNClassificationTrainer;
import org.apache.ignite.ml.knn.classification.NNStrategy;
import org.apache.ignite.ml.math.distances.EuclideanDistance;
-import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
-import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
/**
* Run kNN multi-class classification trainer ({@link KNNClassificationTrainer}) over distributed dataset.
@@ -48,22 +48,23 @@ import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
*/
public class KNNClassificationExample {
/** Run example. */
- public static void main(String[] args) throws InterruptedException {
+ public static void main(String[] args) throws FileNotFoundException {
System.out.println();
System.out.println(">>> kNN multi-class classification algorithm over cached dataset usage example started.");
// Start ignite grid.
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Ignite grid started.");
- IgniteCache<Integer, double[]> dataCache = new TestCache(ignite).fillCacheWith(data);
+ IgniteCache<Integer, Vector> dataCache = new SandboxMLCache(ignite)
+ .fillCacheWith(MLSandboxDatasets.IRIS);
KNNClassificationTrainer trainer = new KNNClassificationTrainer();
NNClassificationModel knnMdl = trainer.fit(
ignite,
dataCache,
- (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
- (k, v) -> v[0]
+ (k, v) -> v.copyOfRange(1, v.size()),
+ (k, v) -> v.get(0)
).withK(3)
.withDistanceMeasure(new EuclideanDistance())
.withStrategy(NNStrategy.WEIGHTED);
@@ -75,13 +76,13 @@ public class KNNClassificationExample {
int amountOfErrors = 0;
int totalAmount = 0;
- try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) {
- for (Cache.Entry<Integer, double[]> observation : observations) {
- double[] val = observation.getValue();
- double[] inputs = Arrays.copyOfRange(val, 1, val.length);
- double groundTruth = val[0];
+ try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) {
+ for (Cache.Entry<Integer, Vector> observation : observations) {
+ Vector val = observation.getValue();
+ Vector inputs = val.copyOfRange(1, val.size());
+ double groundTruth = val.get(0);
- double prediction = knnMdl.apply(new DenseVector(inputs));
+ double prediction = knnMdl.apply(inputs);
totalAmount++;
if (groundTruth != prediction)
@@ -99,158 +100,4 @@ public class KNNClassificationExample {
}
}
}
-
- /** The Iris dataset. */
- private static final double[][] data = {
- {1, 5.1, 3.5, 1.4, 0.2},
- {1, 4.9, 3, 1.4, 0.2},
- {1, 4.7, 3.2, 1.3, 0.2},
- {1, 4.6, 3.1, 1.5, 0.2},
- {1, 5, 3.6, 1.4, 0.2},
- {1, 5.4, 3.9, 1.7, 0.4},
- {1, 4.6, 3.4, 1.4, 0.3},
- {1, 5, 3.4, 1.5, 0.2},
- {1, 4.4, 2.9, 1.4, 0.2},
- {1, 4.9, 3.1, 1.5, 0.1},
- {1, 5.4, 3.7, 1.5, 0.2},
- {1, 4.8, 3.4, 1.6, 0.2},
- {1, 4.8, 3, 1.4, 0.1},
- {1, 4.3, 3, 1.1, 0.1},
- {1, 5.8, 4, 1.2, 0.2},
- {1, 5.7, 4.4, 1.5, 0.4},
- {1, 5.4, 3.9, 1.3, 0.4},
- {1, 5.1, 3.5, 1.4, 0.3},
- {1, 5.7, 3.8, 1.7, 0.3},
- {1, 5.1, 3.8, 1.5, 0.3},
- {1, 5.4, 3.4, 1.7, 0.2},
- {1, 5.1, 3.7, 1.5, 0.4},
- {1, 4.6, 3.6, 1, 0.2},
- {1, 5.1, 3.3, 1.7, 0.5},
- {1, 4.8, 3.4, 1.9, 0.2},
- {1, 5, 3, 1.6, 0.2},
- {1, 5, 3.4, 1.6, 0.4},
- {1, 5.2, 3.5, 1.5, 0.2},
- {1, 5.2, 3.4, 1.4, 0.2},
- {1, 4.7, 3.2, 1.6, 0.2},
- {1, 4.8, 3.1, 1.6, 0.2},
- {1, 5.4, 3.4, 1.5, 0.4},
- {1, 5.2, 4.1, 1.5, 0.1},
- {1, 5.5, 4.2, 1.4, 0.2},
- {1, 4.9, 3.1, 1.5, 0.1},
- {1, 5, 3.2, 1.2, 0.2},
- {1, 5.5, 3.5, 1.3, 0.2},
- {1, 4.9, 3.1, 1.5, 0.1},
- {1, 4.4, 3, 1.3, 0.2},
- {1, 5.1, 3.4, 1.5, 0.2},
- {1, 5, 3.5, 1.3, 0.3},
- {1, 4.5, 2.3, 1.3, 0.3},
- {1, 4.4, 3.2, 1.3, 0.2},
- {1, 5, 3.5, 1.6, 0.6},
- {1, 5.1, 3.8, 1.9, 0.4},
- {1, 4.8, 3, 1.4, 0.3},
- {1, 5.1, 3.8, 1.6, 0.2},
- {1, 4.6, 3.2, 1.4, 0.2},
- {1, 5.3, 3.7, 1.5, 0.2},
- {1, 5, 3.3, 1.4, 0.2},
- {2, 7, 3.2, 4.7, 1.4},
- {2, 6.4, 3.2, 4.5, 1.5},
- {2, 6.9, 3.1, 4.9, 1.5},
- {2, 5.5, 2.3, 4, 1.3},
- {2, 6.5, 2.8, 4.6, 1.5},
- {2, 5.7, 2.8, 4.5, 1.3},
- {2, 6.3, 3.3, 4.7, 1.6},
- {2, 4.9, 2.4, 3.3, 1},
- {2, 6.6, 2.9, 4.6, 1.3},
- {2, 5.2, 2.7, 3.9, 1.4},
- {2, 5, 2, 3.5, 1},
- {2, 5.9, 3, 4.2, 1.5},
- {2, 6, 2.2, 4, 1},
- {2, 6.1, 2.9, 4.7, 1.4},
- {2, 5.6, 2.9, 3.6, 1.3},
- {2, 6.7, 3.1, 4.4, 1.4},
- {2, 5.6, 3, 4.5, 1.5},
- {2, 5.8, 2.7, 4.1, 1},
- {2, 6.2, 2.2, 4.5, 1.5},
- {2, 5.6, 2.5, 3.9, 1.1},
- {2, 5.9, 3.2, 4.8, 1.8},
- {2, 6.1, 2.8, 4, 1.3},
- {2, 6.3, 2.5, 4.9, 1.5},
- {2, 6.1, 2.8, 4.7, 1.2},
- {2, 6.4, 2.9, 4.3, 1.3},
- {2, 6.6, 3, 4.4, 1.4},
- {2, 6.8, 2.8, 4.8, 1.4},
- {2, 6.7, 3, 5, 1.7},
- {2, 6, 2.9, 4.5, 1.5},
- {2, 5.7, 2.6, 3.5, 1},
- {2, 5.5, 2.4, 3.8, 1.1},
- {2, 5.5, 2.4, 3.7, 1},
- {2, 5.8, 2.7, 3.9, 1.2},
- {2, 6, 2.7, 5.1, 1.6},
- {2, 5.4, 3, 4.5, 1.5},
- {2, 6, 3.4, 4.5, 1.6},
- {2, 6.7, 3.1, 4.7, 1.5},
- {2, 6.3, 2.3, 4.4, 1.3},
- {2, 5.6, 3, 4.1, 1.3},
- {2, 5.5, 2.5, 4, 1.3},
- {2, 5.5, 2.6, 4.4, 1.2},
- {2, 6.1, 3, 4.6, 1.4},
- {2, 5.8, 2.6, 4, 1.2},
- {2, 5, 2.3, 3.3, 1},
- {2, 5.6, 2.7, 4.2, 1.3},
- {2, 5.7, 3, 4.2, 1.2},
- {2, 5.7, 2.9, 4.2, 1.3},
- {2, 6.2, 2.9, 4.3, 1.3},
- {2, 5.1, 2.5, 3, 1.1},
- {2, 5.7, 2.8, 4.1, 1.3},
- {3, 6.3, 3.3, 6, 2.5},
- {3, 5.8, 2.7, 5.1, 1.9},
- {3, 7.1, 3, 5.9, 2.1},
- {3, 6.3, 2.9, 5.6, 1.8},
- {3, 6.5, 3, 5.8, 2.2},
- {3, 7.6, 3, 6.6, 2.1},
- {3, 4.9, 2.5, 4.5, 1.7},
- {3, 7.3, 2.9, 6.3, 1.8},
- {3, 6.7, 2.5, 5.8, 1.8},
- {3, 7.2, 3.6, 6.1, 2.5},
- {3, 6.5, 3.2, 5.1, 2},
- {3, 6.4, 2.7, 5.3, 1.9},
- {3, 6.8, 3, 5.5, 2.1},
- {3, 5.7, 2.5, 5, 2},
- {3, 5.8, 2.8, 5.1, 2.4},
- {3, 6.4, 3.2, 5.3, 2.3},
- {3, 6.5, 3, 5.5, 1.8},
- {3, 7.7, 3.8, 6.7, 2.2},
- {3, 7.7, 2.6, 6.9, 2.3},
- {3, 6, 2.2, 5, 1.5},
- {3, 6.9, 3.2, 5.7, 2.3},
- {3, 5.6, 2.8, 4.9, 2},
- {3, 7.7, 2.8, 6.7, 2},
- {3, 6.3, 2.7, 4.9, 1.8},
- {3, 6.7, 3.3, 5.7, 2.1},
- {3, 7.2, 3.2, 6, 1.8},
- {3, 6.2, 2.8, 4.8, 1.8},
- {3, 6.1, 3, 4.9, 1.8},
- {3, 6.4, 2.8, 5.6, 2.1},
- {3, 7.2, 3, 5.8, 1.6},
- {3, 7.4, 2.8, 6.1, 1.9},
- {3, 7.9, 3.8, 6.4, 2},
- {3, 6.4, 2.8, 5.6, 2.2},
- {3, 6.3, 2.8, 5.1, 1.5},
- {3, 6.1, 2.6, 5.6, 1.4},
- {3, 7.7, 3, 6.1, 2.3},
- {3, 6.3, 3.4, 5.6, 2.4},
- {3, 6.4, 3.1, 5.5, 1.8},
- {3, 6, 3, 4.8, 1.8},
- {3, 6.9, 3.1, 5.4, 2.1},
- {3, 6.7, 3.1, 5.6, 2.4},
- {3, 6.9, 3.1, 5.1, 2.3},
- {3, 5.8, 2.7, 5.1, 1.9},
- {3, 6.8, 3.2, 5.9, 2.3},
- {3, 6.7, 3.3, 5.7, 2.5},
- {3, 6.7, 3, 5.2, 2.3},
- {3, 6.3, 2.5, 5, 1.9},
- {3, 6.5, 3, 5.2, 2},
- {3, 6.2, 3.4, 5.4, 2.3},
- {3, 5.9, 3, 5.1, 1.8}
- };
}
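The kNN hunks above keep the trainer configuration (withK, distance measure, NN strategy) and the accuracy loop; only the cache value type changes from double[] to Vector. A hedged sketch of that evaluation loop, lifted from the new side of the diff and wrapped in a hypothetical helper class (KnnAccuracySketch is not part of the commit):

    import javax.cache.Cache;
    import org.apache.ignite.IgniteCache;
    import org.apache.ignite.cache.query.QueryCursor;
    import org.apache.ignite.cache.query.ScanQuery;
    import org.apache.ignite.ml.knn.NNClassificationModel;
    import org.apache.ignite.ml.math.primitives.vector.Vector;

    /** Hypothetical helper; the loop body follows the new KNNClassificationExample code. */
    public class KnnAccuracySketch {
        public static double accuracy(IgniteCache<Integer, Vector> dataCache, NNClassificationModel knnMdl) {
            int amountOfErrors = 0;
            int totalAmount = 0;

            try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) {
                for (Cache.Entry<Integer, Vector> observation : observations) {
                    Vector val = observation.getValue();
                    Vector inputs = val.copyOfRange(1, val.size()); // feature columns
                    double groundTruth = val.get(0);                // label column

                    double prediction = knnMdl.apply(inputs);

                    totalAmount++;
                    if (groundTruth != prediction)
                        amountOfErrors++;
                }
            }

            return 1 - amountOfErrors / (double)totalAmount;
        }
    }

The same ScanQuery-over-the-cache pattern recurs in the regression and logistic-regression examples further down.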
http://git-wip-us.apache.org/repos/asf/ignite/blob/370cd3e1/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNRegressionExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNRegressionExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNRegressionExample.java
index 3969f0c..9917e80 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNRegressionExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNRegressionExample.java
@@ -17,20 +17,20 @@
package org.apache.ignite.examples.ml.knn;
-import java.util.Arrays;
+import java.io.FileNotFoundException;
import javax.cache.Cache;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
import org.apache.ignite.cache.query.QueryCursor;
import org.apache.ignite.cache.query.ScanQuery;
-import org.apache.ignite.examples.ml.util.TestCache;
+import org.apache.ignite.examples.ml.util.MLSandboxDatasets;
+import org.apache.ignite.examples.ml.util.SandboxMLCache;
import org.apache.ignite.ml.knn.classification.NNStrategy;
import org.apache.ignite.ml.knn.regression.KNNRegressionModel;
import org.apache.ignite.ml.knn.regression.KNNRegressionTrainer;
import org.apache.ignite.ml.math.distances.ManhattanDistance;
-import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
-import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
/**
* Run kNN regression trainer ({@link KNNRegressionTrainer}) over distributed dataset.
@@ -49,22 +49,23 @@ import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
*/
public class KNNRegressionExample {
/** Run example. */
- public static void main(String[] args) throws InterruptedException {
+ public static void main(String[] args) throws FileNotFoundException {
System.out.println();
System.out.println(">>> kNN regression over cached dataset usage example started.");
// Start ignite grid.
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Ignite grid started.");
- IgniteCache<Integer, double[]> dataCache = new TestCache(ignite).fillCacheWith(data);
+ IgniteCache<Integer, Vector> dataCache = new SandboxMLCache(ignite)
+ .fillCacheWith(MLSandboxDatasets.CLEARED_MACHINES);
KNNRegressionTrainer trainer = new KNNRegressionTrainer();
KNNRegressionModel knnMdl = (KNNRegressionModel) trainer.fit(
ignite,
dataCache,
- (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
- (k, v) -> v[0]
+ (k, v) -> v.copyOfRange(1, v.size()),
+ (k, v) -> v.get(0)
).withK(5)
.withDistanceMeasure(new ManhattanDistance())
.withStrategy(NNStrategy.WEIGHTED);
@@ -79,13 +80,13 @@ public class KNNRegressionExample {
// Calculate mean absolute error (MAE)
double mae = 0.0;
- try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) {
- for (Cache.Entry<Integer, double[]> observation : observations) {
- double[] val = observation.getValue();
- double[] inputs = Arrays.copyOfRange(val, 1, val.length);
- double groundTruth = val[0];
+ try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) {
+ for (Cache.Entry<Integer, Vector> observation : observations) {
+ Vector val = observation.getValue();
+ Vector inputs = val.copyOfRange(1, val.size());
+ double groundTruth = val.get(0);
- double prediction = knnMdl.apply(new DenseVector(inputs));
+ double prediction = knnMdl.apply(inputs);
mse += Math.pow(prediction - groundTruth, 2.0);
mae += Math.abs(prediction - groundTruth);
@@ -107,196 +108,4 @@ public class KNNRegressionExample {
}
}
}
-
- /** The Iris dataset. */
- private static final double[][] data = {
- {199, 125, 256, 6000, 256, 16, 128},
- {253, 29, 8000, 32000, 32, 8, 32},
- {132, 29, 8000, 16000, 32, 8, 16},
- {290, 26, 8000, 32000, 64, 8, 32},
- {381, 23, 16000, 32000, 64, 16, 32},
- {749, 23, 16000, 64000, 64, 16, 32},
- {1238, 23, 32000, 64000, 128, 32, 64},
- {23, 400, 1000, 3000, 0, 1, 2},
- {24, 400, 512, 3500, 4, 1, 6},
- {70, 60, 2000, 8000, 65, 1, 8},
- {117, 50, 4000, 16000, 65, 1, 8},
- {15, 350, 64, 64, 0, 1, 4},
- {64, 200, 512, 16000, 0, 4, 32},
- {23, 167, 524, 2000, 8, 4, 15},
- {29, 143, 512, 5000, 0, 7, 32},
- {22, 143, 1000, 2000, 0, 5, 16},
- {124, 110, 5000, 5000, 142, 8, 64},
- {35, 143, 1500, 6300, 0, 5, 32},
- {39, 143, 3100, 6200, 0, 5, 20},
- {40, 143, 2300, 6200, 0, 6, 64},
- {45, 110, 3100, 6200, 0, 6, 64},
- {28, 320, 128, 6000, 0, 1, 12},
- {21, 320, 512, 2000, 4, 1, 3},
- {28, 320, 256, 6000, 0, 1, 6},
- {22, 320, 256, 3000, 4, 1, 3},
- {28, 320, 512, 5000, 4, 1, 5},
- {27, 320, 256, 5000, 4, 1, 6},
- {102, 25, 1310, 2620, 131, 12, 24},
- {74, 50, 2620, 10480, 30, 12, 24},
- {138, 56, 5240, 20970, 30, 12, 24},
- {136, 64, 5240, 20970, 30, 12, 24},
- {23, 50, 500, 2000, 8, 1, 4},
- {29, 50, 1000, 4000, 8, 1, 5},
- {44, 50, 2000, 8000, 8, 1, 5},
- {30, 50, 1000, 4000, 8, 3, 5},
- {41, 50, 1000, 8000, 8, 3, 5},
- {74, 50, 2000, 16000, 8, 3, 5},
- {54, 133, 1000, 12000, 9, 3, 12},
- {41, 133, 1000, 8000, 9, 3, 12},
- {18, 810, 512, 512, 8, 1, 1},
- {28, 810, 1000, 5000, 0, 1, 1},
- {36, 320, 512, 8000, 4, 1, 5},
- {38, 200, 512, 8000, 8, 1, 8},
- {34, 700, 384, 8000, 0, 1, 1},
- {19, 700, 256, 2000, 0, 1, 1},
- {72, 140, 1000, 16000, 16, 1, 3},
- {36, 200, 1000, 8000, 0, 1, 2},
- {30, 110, 1000, 4000, 16, 1, 2},
- {56, 110, 1000, 12000, 16, 1, 2},
- {42, 220, 1000, 8000, 16, 1, 2},
- {34, 800, 256, 8000, 0, 1, 4},
- {19, 125, 512, 1000, 0, 8, 20},
- {75, 75, 2000, 8000, 64, 1, 38},
- {113, 75, 2000, 16000, 64, 1, 38},
- {157, 75, 2000, 16000, 128, 1, 38},
- {18, 90, 256, 1000, 0, 3, 10},
- {20, 105, 256, 2000, 0, 3, 10},
- {28, 105, 1000, 4000, 0, 3, 24},
- {33, 105, 2000, 4000, 8, 3, 19},
- {47, 75, 2000, 8000, 8, 3, 24},
- {54, 75, 3000, 8000, 8, 3, 48},
- {20, 175, 256, 2000, 0, 3, 24},
- {23, 300, 768, 3000, 0, 6, 24},
- {25, 300, 768, 3000, 6, 6, 24},
- {52, 300, 768, 12000, 6, 6, 24},
- {27, 300, 768, 4500, 0, 1, 24},
- {50, 300, 384, 12000, 6, 1, 24},
- {18, 300, 192, 768, 6, 6, 24},
- {53, 180, 768, 12000, 6, 1, 31},
- {23, 330, 1000, 3000, 0, 2, 4},
- {30, 300, 1000, 4000, 8, 3, 64},
- {73, 300, 1000, 16000, 8, 2, 112},
- {20, 330, 1000, 2000, 0, 1, 2},
- {25, 330, 1000, 4000, 0, 3, 6},
- {28, 140, 2000, 4000, 0, 3, 6},
- {29, 140, 2000, 4000, 0, 4, 8},
- {32, 140, 2000, 4000, 8, 1, 20},
- {175, 140, 2000, 32000, 32, 1, 20},
- {57, 140, 2000, 8000, 32, 1, 54},
- {181, 140, 2000, 32000, 32, 1, 54},
- {32, 140, 2000, 4000, 8, 1, 20},
- {82, 57, 4000, 16000, 1, 6, 12},
- {171, 57, 4000, 24000, 64, 12, 16},
- {361, 26, 16000, 32000, 64, 16, 24},
- {350, 26, 16000, 32000, 64, 8, 24},
- {220, 26, 8000, 32000, 0, 8, 24},
- {113, 26, 8000, 16000, 0, 8, 16},
- {15, 480, 96, 512, 0, 1, 1},
- {21, 203, 1000, 2000, 0, 1, 5},
- {35, 115, 512, 6000, 16, 1, 6},
- {18, 1100, 512, 1500, 0, 1, 1},
- {20, 1100, 768, 2000, 0, 1, 1},
- {20, 600, 768, 2000, 0, 1, 1},
- {28, 400, 2000, 4000, 0, 1, 1},
- {45, 400, 4000, 8000, 0, 1, 1},
- {18, 900, 1000, 1000, 0, 1, 2},
- {17, 900, 512, 1000, 0, 1, 2},
- {26, 900, 1000, 4000, 4, 1, 2},
- {28, 900, 1000, 4000, 8, 1, 2},
- {28, 900, 2000, 4000, 0, 3, 6},
- {31, 225, 2000, 4000, 8, 3, 6},
- {42, 180, 2000, 8000, 8, 1, 6},
- {76, 185, 2000, 16000, 16, 1, 6},
- {76, 180, 2000, 16000, 16, 1, 6},
- {26, 225, 1000, 4000, 2, 3, 6},
- {59, 25, 2000, 12000, 8, 1, 4},
- {65, 25, 2000, 12000, 16, 3, 5},
- {101, 17, 4000, 16000, 8, 6, 12},
- {116, 17, 4000, 16000, 32, 6, 12},
- {18, 1500, 768, 1000, 0, 0, 0},
- {20, 1500, 768, 2000, 0, 0, 0},
- {20, 800, 768, 2000, 0, 0, 0},
- {30, 50, 2000, 4000, 0, 3, 6},
- {44, 50, 2000, 8000, 8, 3, 6},
- {82, 50, 2000, 16000, 24, 1, 6},
- {128, 50, 8000, 16000, 48, 1, 10},
- {37, 100, 1000, 8000, 0, 2, 6},
- {46, 100, 1000, 8000, 24, 2, 6},
- {46, 100, 1000, 8000, 24, 3, 6},
- {80, 50, 2000, 16000, 12, 3, 16},
- {88, 50, 2000, 16000, 24, 6, 16},
- {33, 150, 512, 4000, 0, 8, 128},
- {46, 115, 2000, 8000, 16, 1, 3},
- {29, 115, 2000, 4000, 2, 1, 5},
- {53, 92, 2000, 8000, 32, 1, 6},
- {41, 92, 2000, 8000, 4, 1, 6},
- {86, 75, 4000, 16000, 16, 1, 6},
- {95, 60, 4000, 16000, 32, 1, 6},
- {107, 60, 2000, 16000, 64, 5, 8},
- {117, 60, 4000, 16000, 64, 5, 8},
- {119, 50, 4000, 16000, 64, 5, 10},
- {120, 72, 4000, 16000, 64, 8, 16},
- {48, 72, 2000, 8000, 16, 6, 8},
- {126, 40, 8000, 16000, 32, 8, 16},
- {266, 40, 8000, 32000, 64, 8, 24},
- {270, 35, 8000, 32000, 64, 8, 24},
- {426, 38, 16000, 32000, 128, 16, 32},
- {151, 48, 4000, 24000, 32, 8, 24},
- {267, 38, 8000, 32000, 64, 8, 24},
- {603, 30, 16000, 32000, 256, 16, 24},
- {19, 112, 1000, 1000, 0, 1, 4},
- {21, 84, 1000, 2000, 0, 1, 6},
- {26, 56, 1000, 4000, 0, 1, 6},
- {35, 56, 2000, 6000, 0, 1, 8},
- {41, 56, 2000, 8000, 0, 1, 8},
- {47, 56, 4000, 8000, 0, 1, 8},
- {62, 56, 4000, 12000, 0, 1, 8},
- {78, 56, 4000, 16000, 0, 1, 8},
- {80, 38, 4000, 8000, 32, 16, 32},
- {142, 38, 8000, 16000, 64, 4, 8},
- {281, 38, 8000, 24000, 160, 4, 8},
- {190, 38, 4000, 16000, 128, 16, 32},
- {21, 200, 1000, 2000, 0, 1, 2},
- {25, 200, 1000, 4000, 0, 1, 4},
- {67, 200, 2000, 8000, 64, 1, 5},
- {24, 250, 512, 4000, 0, 1, 7},
- {24, 250, 512, 4000, 0, 4, 7},
- {64, 250, 1000, 16000, 1, 1, 8},
- {25, 160, 512, 4000, 2, 1, 5},
- {20, 160, 512, 2000, 2, 3, 8},
- {29, 160, 1000, 4000, 8, 1, 14},
- {43, 160, 1000, 8000, 16, 1, 14},
- {53, 160, 2000, 8000, 32, 1, 13},
- {19, 240, 512, 1000, 8, 1, 3},
- {22, 240, 512, 2000, 8, 1, 5},
- {31, 105, 2000, 4000, 8, 3, 8},
- {41, 105, 2000, 6000, 16, 6, 16},
- {47, 105, 2000, 8000, 16, 4, 14},
- {99, 52, 4000, 16000, 32, 4, 12},
- {67, 70, 4000, 12000, 8, 6, 8},
- {81, 59, 4000, 12000, 32, 6, 12},
- {149, 59, 8000, 16000, 64, 12, 24},
- {183, 26, 8000, 24000, 32, 8, 16},
- {275, 26, 8000, 32000, 64, 12, 16},
- {382, 26, 8000, 32000, 128, 24, 32},
- {56, 116, 2000, 8000, 32, 5, 28},
- {182, 50, 2000, 32000, 24, 6, 26},
- {227, 50, 2000, 32000, 48, 26, 52},
- {341, 50, 2000, 32000, 112, 52, 104},
- {360, 50, 4000, 32000, 112, 52, 104},
- {919, 30, 8000, 64000, 96, 12, 176},
- {978, 30, 8000, 64000, 128, 12, 176},
- {24, 180, 262, 4000, 0, 1, 3},
- {37, 124, 1000, 8000, 0, 1, 8},
- {50, 98, 1000, 8000, 32, 2, 8},
- {41, 125, 2000, 8000, 0, 2, 14},
- {47, 480, 512, 8000, 32, 0, 0},
- {25, 480, 1000, 4000, 0, 0, 0}
- };
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/370cd3e1/examples/src/main/java/org/apache/ignite/examples/ml/nn/MLPTrainerExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/nn/MLPTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/nn/MLPTrainerExample.java
index 6d5745e..dc67aa1 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/nn/MLPTrainerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/nn/MLPTrainerExample.java
@@ -61,7 +61,7 @@ public class MLPTrainerExample {
*
* @param args Command line arguments, none required.
*/
- public static void main(String[] args) throws InterruptedException {
+ public static void main(String[] args) {
// IMPL NOTE based on MLPGroupTrainerTest#testXOR
System.out.println(">>> Distributed multilayer perceptron example started.");
http://git-wip-us.apache.org/repos/asf/ignite/blob/370cd3e1/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerExample.java
index 862a37f..aeb7a0d 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerExample.java
@@ -17,16 +17,16 @@
package org.apache.ignite.examples.ml.regression.linear;
-import java.util.Arrays;
+import java.io.FileNotFoundException;
import javax.cache.Cache;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
import org.apache.ignite.cache.query.QueryCursor;
import org.apache.ignite.cache.query.ScanQuery;
-import org.apache.ignite.examples.ml.util.TestCache;
-import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
-import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
+import org.apache.ignite.examples.ml.util.MLSandboxDatasets;
+import org.apache.ignite.examples.ml.util.SandboxMLCache;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainer;
import org.apache.ignite.ml.regressions.linear.LinearRegressionModel;
@@ -44,72 +44,16 @@ import org.apache.ignite.ml.regressions.linear.LinearRegressionModel;
* You can change the test data used in this example and re-run it to explore this algorithm further.</p>
*/
public class LinearRegressionLSQRTrainerExample {
- /** */
- private static final double[][] data = {
- {8, 78, 284, 9.100000381, 109},
- {9.300000191, 68, 433, 8.699999809, 144},
- {7.5, 70, 739, 7.199999809, 113},
- {8.899999619, 96, 1792, 8.899999619, 97},
- {10.19999981, 74, 477, 8.300000191, 206},
- {8.300000191, 111, 362, 10.89999962, 124},
- {8.800000191, 77, 671, 10, 152},
- {8.800000191, 168, 636, 9.100000381, 162},
- {10.69999981, 82, 329, 8.699999809, 150},
- {11.69999981, 89, 634, 7.599999905, 134},
- {8.5, 149, 631, 10.80000019, 292},
- {8.300000191, 60, 257, 9.5, 108},
- {8.199999809, 96, 284, 8.800000191, 111},
- {7.900000095, 83, 603, 9.5, 182},
- {10.30000019, 130, 686, 8.699999809, 129},
- {7.400000095, 145, 345, 11.19999981, 158},
- {9.600000381, 112, 1357, 9.699999809, 186},
- {9.300000191, 131, 544, 9.600000381, 177},
- {10.60000038, 80, 205, 9.100000381, 127},
- {9.699999809, 130, 1264, 9.199999809, 179},
- {11.60000038, 140, 688, 8.300000191, 80},
- {8.100000381, 154, 354, 8.399999619, 103},
- {9.800000191, 118, 1632, 9.399999619, 101},
- {7.400000095, 94, 348, 9.800000191, 117},
- {9.399999619, 119, 370, 10.39999962, 88},
- {11.19999981, 153, 648, 9.899999619, 78},
- {9.100000381, 116, 366, 9.199999809, 102},
- {10.5, 97, 540, 10.30000019, 95},
- {11.89999962, 176, 680, 8.899999619, 80},
- {8.399999619, 75, 345, 9.600000381, 92},
- {5, 134, 525, 10.30000019, 126},
- {9.800000191, 161, 870, 10.39999962, 108},
- {9.800000191, 111, 669, 9.699999809, 77},
- {10.80000019, 114, 452, 9.600000381, 60},
- {10.10000038, 142, 430, 10.69999981, 71},
- {10.89999962, 238, 822, 10.30000019, 86},
- {9.199999809, 78, 190, 10.69999981, 93},
- {8.300000191, 196, 867, 9.600000381, 106},
- {7.300000191, 125, 969, 10.5, 162},
- {9.399999619, 82, 499, 7.699999809, 95},
- {9.399999619, 125, 925, 10.19999981, 91},
- {9.800000191, 129, 353, 9.899999619, 52},
- {3.599999905, 84, 288, 8.399999619, 110},
- {8.399999619, 183, 718, 10.39999962, 69},
- {10.80000019, 119, 540, 9.199999809, 57},
- {10.10000038, 180, 668, 13, 106},
- {9, 82, 347, 8.800000191, 40},
- {10, 71, 345, 9.199999809, 50},
- {11.30000019, 118, 463, 7.800000191, 35},
- {11.30000019, 121, 728, 8.199999809, 86},
- {12.80000019, 68, 383, 7.400000095, 57},
- {10, 112, 316, 10.39999962, 57},
- {6.699999809, 109, 388, 8.899999619, 94}
- };
-
/** Run example. */
- public static void main(String[] args) throws InterruptedException {
+ public static void main(String[] args) throws FileNotFoundException {
System.out.println();
System.out.println(">>> Linear regression model over cache based dataset usage example started.");
// Start ignite grid.
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Ignite grid started.");
- IgniteCache<Integer, double[]> dataCache = new TestCache(ignite).fillCacheWith(data);
+ IgniteCache<Integer, Vector> dataCache = new SandboxMLCache(ignite)
+ .fillCacheWith(MLSandboxDatasets.MORTALITY_DATA);
System.out.println(">>> Create new linear regression trainer object.");
LinearRegressionLSQRTrainer trainer = new LinearRegressionLSQRTrainer();
@@ -118,8 +62,8 @@ public class LinearRegressionLSQRTrainerExample {
LinearRegressionModel mdl = trainer.fit(
ignite,
dataCache,
- (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
- (k, v) -> v[0]
+ (k, v) -> v.copyOfRange(1, v.size()),
+ (k, v) -> v.get(0)
);
System.out.println(">>> Linear regression model: " + mdl);
@@ -128,13 +72,13 @@ public class LinearRegressionLSQRTrainerExample {
System.out.println(">>> | Prediction\t| Ground Truth\t|");
System.out.println(">>> ---------------------------------");
- try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) {
- for (Cache.Entry<Integer, double[]> observation : observations) {
- double[] val = observation.getValue();
- double[] inputs = Arrays.copyOfRange(val, 1, val.length);
- double groundTruth = val[0];
+ try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) {
+ for (Cache.Entry<Integer, Vector> observation : observations) {
+ Vector val = observation.getValue();
+ Vector inputs = val.copyOfRange(1, val.size());
+ double groundTruth = val.get(0);
- double prediction = mdl.apply(new DenseVector(inputs));
+ double prediction = mdl.apply(inputs);
System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/370cd3e1/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerWithMinMaxScalerExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerWithMinMaxScalerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerWithMinMaxScalerExample.java
index 5692cb3..873cefb 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerWithMinMaxScalerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerWithMinMaxScalerExample.java
@@ -17,17 +17,17 @@
package org.apache.ignite.examples.ml.regression.linear;
-import java.util.Arrays;
+import java.io.FileNotFoundException;
import javax.cache.Cache;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
import org.apache.ignite.cache.query.QueryCursor;
import org.apache.ignite.cache.query.ScanQuery;
-import org.apache.ignite.examples.ml.util.TestCache;
+import org.apache.ignite.examples.ml.util.MLSandboxDatasets;
+import org.apache.ignite.examples.ml.util.SandboxMLCache;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
-import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerPreprocessor;
import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerTrainer;
import org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainer;
@@ -50,84 +50,25 @@ import org.apache.ignite.ml.regressions.linear.LinearRegressionModel;
* You can change the test data used in this example and re-run it to explore this algorithm further.</p>
*/
public class LinearRegressionLSQRTrainerWithMinMaxScalerExample {
- /** */
- private static final double[][] data = {
- {8, 78, 284, 9.100000381, 109},
- {9.300000191, 68, 433, 8.699999809, 144},
- {7.5, 70, 739, 7.199999809, 113},
- {8.899999619, 96, 1792, 8.899999619, 97},
- {10.19999981, 74, 477, 8.300000191, 206},
- {8.300000191, 111, 362, 10.89999962, 124},
- {8.800000191, 77, 671, 10, 152},
- {8.800000191, 168, 636, 9.100000381, 162},
- {10.69999981, 82, 329, 8.699999809, 150},
- {11.69999981, 89, 634, 7.599999905, 134},
- {8.5, 149, 631, 10.80000019, 292},
- {8.300000191, 60, 257, 9.5, 108},
- {8.199999809, 96, 284, 8.800000191, 111},
- {7.900000095, 83, 603, 9.5, 182},
- {10.30000019, 130, 686, 8.699999809, 129},
- {7.400000095, 145, 345, 11.19999981, 158},
- {9.600000381, 112, 1357, 9.699999809, 186},
- {9.300000191, 131, 544, 9.600000381, 177},
- {10.60000038, 80, 205, 9.100000381, 127},
- {9.699999809, 130, 1264, 9.199999809, 179},
- {11.60000038, 140, 688, 8.300000191, 80},
- {8.100000381, 154, 354, 8.399999619, 103},
- {9.800000191, 118, 1632, 9.399999619, 101},
- {7.400000095, 94, 348, 9.800000191, 117},
- {9.399999619, 119, 370, 10.39999962, 88},
- {11.19999981, 153, 648, 9.899999619, 78},
- {9.100000381, 116, 366, 9.199999809, 102},
- {10.5, 97, 540, 10.30000019, 95},
- {11.89999962, 176, 680, 8.899999619, 80},
- {8.399999619, 75, 345, 9.600000381, 92},
- {5, 134, 525, 10.30000019, 126},
- {9.800000191, 161, 870, 10.39999962, 108},
- {9.800000191, 111, 669, 9.699999809, 77},
- {10.80000019, 114, 452, 9.600000381, 60},
- {10.10000038, 142, 430, 10.69999981, 71},
- {10.89999962, 238, 822, 10.30000019, 86},
- {9.199999809, 78, 190, 10.69999981, 93},
- {8.300000191, 196, 867, 9.600000381, 106},
- {7.300000191, 125, 969, 10.5, 162},
- {9.399999619, 82, 499, 7.699999809, 95},
- {9.399999619, 125, 925, 10.19999981, 91},
- {9.800000191, 129, 353, 9.899999619, 52},
- {3.599999905, 84, 288, 8.399999619, 110},
- {8.399999619, 183, 718, 10.39999962, 69},
- {10.80000019, 119, 540, 9.199999809, 57},
- {10.10000038, 180, 668, 13, 106},
- {9, 82, 347, 8.800000191, 40},
- {10, 71, 345, 9.199999809, 50},
- {11.30000019, 118, 463, 7.800000191, 35},
- {11.30000019, 121, 728, 8.199999809, 86},
- {12.80000019, 68, 383, 7.400000095, 57},
- {10, 112, 316, 10.39999962, 57},
- {6.699999809, 109, 388, 8.899999619, 94}
- };
-
/** Run example. */
- public static void main(String[] args) throws InterruptedException {
+ public static void main(String[] args) throws FileNotFoundException {
System.out.println();
- System.out.println(">>> Linear regression model with minmaxscaling preprocessor over cached dataset usage example started.");
+ System.out.println(">>> Linear regression model with Min Max Scaling preprocessor over cached dataset usage example started.");
// Start ignite grid.
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Ignite grid started.");
- IgniteCache<Integer, Vector> dataCache = new TestCache(ignite).getVectors(data);
+ IgniteCache<Integer, Vector> dataCache = new SandboxMLCache(ignite)
+ .fillCacheWith(MLSandboxDatasets.MORTALITY_DATA);
- System.out.println(">>> Create new minmaxscaling trainer object.");
- MinMaxScalerTrainer<Integer, Vector> normalizationTrainer = new MinMaxScalerTrainer<>();
+ System.out.println(">>> Create new MinMaxScaler trainer object.");
+ MinMaxScalerTrainer<Integer, Vector> minMaxScalerTrainer = new MinMaxScalerTrainer<>();
- System.out.println(">>> Perform the training to get the minmaxscaling preprocessor.");
- IgniteBiFunction<Integer, Vector, Vector> preprocessor = normalizationTrainer.fit(
+ System.out.println(">>> Perform the training to get the MinMaxScaler preprocessor.");
+ IgniteBiFunction<Integer, Vector, Vector> preprocessor = minMaxScalerTrainer.fit(
ignite,
dataCache,
- (k, v) -> {
- double[] arr = v.asArray();
- return VectorUtils.of(Arrays.copyOfRange(arr, 1, arr.length));
- }
+ (k, v) -> v.copyOfRange(1, v.size())
);
System.out.println(">>> Create new linear regression trainer object.");
@@ -155,7 +96,7 @@ public class LinearRegressionLSQRTrainerWithMinMaxScalerExample {
}
System.out.println(">>> ---------------------------------");
- System.out.println(">>> Linear regression model with minmaxscaling preprocessor over cache based dataset usage example completed.");
+ System.out.println(">>> Linear regression model with MinMaxScaler preprocessor over cache based dataset usage example completed.");
}
}
}
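With the cache now holding Vector values, the MinMaxScaler example fits its preprocessor directly on the rows, so the extractor collapses to a single copyOfRange call. A small sketch of just that step, again wrapped in a hypothetical helper class (MinMaxScalerSketch is illustrative only):

    import org.apache.ignite.Ignite;
    import org.apache.ignite.IgniteCache;
    import org.apache.ignite.ml.math.functions.IgniteBiFunction;
    import org.apache.ignite.ml.math.primitives.vector.Vector;
    import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerTrainer;

    /** Hypothetical wrapper around the preprocessor-fitting step shown in the hunk above. */
    public class MinMaxScalerSketch {
        public static IgniteBiFunction<Integer, Vector, Vector> fitPreprocessor(Ignite ignite,
            IgniteCache<Integer, Vector> dataCache) {
            MinMaxScalerTrainer<Integer, Vector> minMaxScalerTrainer = new MinMaxScalerTrainer<>();

            // Column 0 is the regression target, so only the feature columns 1..n-1 are scaled.
            return minMaxScalerTrainer.fit(
                ignite,
                dataCache,
                (k, v) -> v.copyOfRange(1, v.size())
            );
        }
    }

The resulting preprocessor is then passed to the linear regression trainer in place of the plain feature extractor, as in the rest of that example.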
http://git-wip-us.apache.org/repos/asf/ignite/blob/370cd3e1/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionSGDTrainerExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionSGDTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionSGDTrainerExample.java
index 1e9bd5a..1dad08b 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionSGDTrainerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionSGDTrainerExample.java
@@ -17,16 +17,16 @@
package org.apache.ignite.examples.ml.regression.linear;
-import java.util.Arrays;
+import java.io.FileNotFoundException;
import javax.cache.Cache;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
import org.apache.ignite.cache.query.QueryCursor;
import org.apache.ignite.cache.query.ScanQuery;
-import org.apache.ignite.examples.ml.util.TestCache;
-import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
-import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
+import org.apache.ignite.examples.ml.util.MLSandboxDatasets;
+import org.apache.ignite.examples.ml.util.SandboxMLCache;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.nn.UpdatesStrategy;
import org.apache.ignite.ml.optimization.updatecalculators.RPropParameterUpdate;
import org.apache.ignite.ml.optimization.updatecalculators.RPropUpdateCalculator;
@@ -49,72 +49,16 @@ import org.apache.ignite.ml.regressions.linear.LinearRegressionSGDTrainer;
* You can change the test data used in this example and re-run it to explore this algorithm further.</p>
*/
public class LinearRegressionSGDTrainerExample {
- /** */
- private static final double[][] data = {
- {8, 78, 284, 9.100000381, 109},
- {9.300000191, 68, 433, 8.699999809, 144},
- {7.5, 70, 739, 7.199999809, 113},
- {8.899999619, 96, 1792, 8.899999619, 97},
- {10.19999981, 74, 477, 8.300000191, 206},
- {8.300000191, 111, 362, 10.89999962, 124},
- {8.800000191, 77, 671, 10, 152},
- {8.800000191, 168, 636, 9.100000381, 162},
- {10.69999981, 82, 329, 8.699999809, 150},
- {11.69999981, 89, 634, 7.599999905, 134},
- {8.5, 149, 631, 10.80000019, 292},
- {8.300000191, 60, 257, 9.5, 108},
- {8.199999809, 96, 284, 8.800000191, 111},
- {7.900000095, 83, 603, 9.5, 182},
- {10.30000019, 130, 686, 8.699999809, 129},
- {7.400000095, 145, 345, 11.19999981, 158},
- {9.600000381, 112, 1357, 9.699999809, 186},
- {9.300000191, 131, 544, 9.600000381, 177},
- {10.60000038, 80, 205, 9.100000381, 127},
- {9.699999809, 130, 1264, 9.199999809, 179},
- {11.60000038, 140, 688, 8.300000191, 80},
- {8.100000381, 154, 354, 8.399999619, 103},
- {9.800000191, 118, 1632, 9.399999619, 101},
- {7.400000095, 94, 348, 9.800000191, 117},
- {9.399999619, 119, 370, 10.39999962, 88},
- {11.19999981, 153, 648, 9.899999619, 78},
- {9.100000381, 116, 366, 9.199999809, 102},
- {10.5, 97, 540, 10.30000019, 95},
- {11.89999962, 176, 680, 8.899999619, 80},
- {8.399999619, 75, 345, 9.600000381, 92},
- {5, 134, 525, 10.30000019, 126},
- {9.800000191, 161, 870, 10.39999962, 108},
- {9.800000191, 111, 669, 9.699999809, 77},
- {10.80000019, 114, 452, 9.600000381, 60},
- {10.10000038, 142, 430, 10.69999981, 71},
- {10.89999962, 238, 822, 10.30000019, 86},
- {9.199999809, 78, 190, 10.69999981, 93},
- {8.300000191, 196, 867, 9.600000381, 106},
- {7.300000191, 125, 969, 10.5, 162},
- {9.399999619, 82, 499, 7.699999809, 95},
- {9.399999619, 125, 925, 10.19999981, 91},
- {9.800000191, 129, 353, 9.899999619, 52},
- {3.599999905, 84, 288, 8.399999619, 110},
- {8.399999619, 183, 718, 10.39999962, 69},
- {10.80000019, 119, 540, 9.199999809, 57},
- {10.10000038, 180, 668, 13, 106},
- {9, 82, 347, 8.800000191, 40},
- {10, 71, 345, 9.199999809, 50},
- {11.30000019, 118, 463, 7.800000191, 35},
- {11.30000019, 121, 728, 8.199999809, 86},
- {12.80000019, 68, 383, 7.400000095, 57},
- {10, 112, 316, 10.39999962, 57},
- {6.699999809, 109, 388, 8.899999619, 94}
- };
-
/** Run example. */
- public static void main(String[] args) throws InterruptedException {
+ public static void main(String[] args) throws FileNotFoundException {
System.out.println();
System.out.println(">>> Linear regression model over sparse distributed matrix API usage example started.");
// Start ignite grid.
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Ignite grid started.");
- IgniteCache<Integer, double[]> dataCache = new TestCache(ignite).fillCacheWith(data);
+ IgniteCache<Integer, Vector> dataCache = new SandboxMLCache(ignite)
+ .fillCacheWith(MLSandboxDatasets.MORTALITY_DATA);
System.out.println(">>> Create new linear regression trainer object.");
LinearRegressionSGDTrainer<?> trainer = new LinearRegressionSGDTrainer<>(new UpdatesStrategy<>(
@@ -127,8 +71,8 @@ public class LinearRegressionSGDTrainerExample {
LinearRegressionModel mdl = trainer.fit(
ignite,
dataCache,
- (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
- (k, v) -> v[0]
+ (k, v) -> v.copyOfRange(1, v.size()),
+ (k, v) -> v.get(0)
);
System.out.println(">>> Linear regression model: " + mdl);
@@ -137,13 +81,13 @@ public class LinearRegressionSGDTrainerExample {
System.out.println(">>> | Prediction\t| Ground Truth\t|");
System.out.println(">>> ---------------------------------");
- try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) {
- for (Cache.Entry<Integer, double[]> observation : observations) {
- double[] val = observation.getValue();
- double[] inputs = Arrays.copyOfRange(val, 1, val.length);
- double groundTruth = val[0];
+ try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) {
+ for (Cache.Entry<Integer, Vector> observation : observations) {
+ Vector val = observation.getValue();
+ Vector inputs = val.copyOfRange(1, val.size());
+ double groundTruth = val.get(0);
- double prediction = mdl.apply(new DenseVector(inputs));
+ double prediction = mdl.apply(inputs);
System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/370cd3e1/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerExample.java
index 15330d0..52ee330 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerExample.java
@@ -17,6 +17,7 @@
package org.apache.ignite.examples.ml.regression.logistic.binary;
+import java.io.FileNotFoundException;
import java.util.Arrays;
import javax.cache.Cache;
import org.apache.ignite.Ignite;
@@ -24,9 +25,9 @@ import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
import org.apache.ignite.cache.query.QueryCursor;
import org.apache.ignite.cache.query.ScanQuery;
-import org.apache.ignite.examples.ml.util.TestCache;
-import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
-import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
+import org.apache.ignite.examples.ml.util.MLSandboxDatasets;
+import org.apache.ignite.examples.ml.util.SandboxMLCache;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.nn.UpdatesStrategy;
import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate;
import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;
@@ -50,14 +51,15 @@ import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionSGDT
*/
public class LogisticRegressionSGDTrainerExample {
/** Run example. */
- public static void main(String[] args) throws InterruptedException {
+ public static void main(String[] args) throws FileNotFoundException {
System.out.println();
System.out.println(">>> Logistic regression model over partitioned dataset usage example started.");
// Start ignite grid.
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Ignite grid started.");
- IgniteCache<Integer, double[]> dataCache = new TestCache(ignite).fillCacheWith(data);
+ IgniteCache<Integer, Vector> dataCache = new SandboxMLCache(ignite)
+ .fillCacheWith(MLSandboxDatasets.TWO_CLASSED_IRIS);
System.out.println(">>> Create new logistic regression trainer object.");
LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>()
@@ -75,8 +77,8 @@ public class LogisticRegressionSGDTrainerExample {
LogisticRegressionModel mdl = trainer.fit(
ignite,
dataCache,
- (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
- (k, v) -> v[0]
+ (k, v) -> v.copyOfRange(1, v.size()),
+ (k, v) -> v.get(0)
);
System.out.println(">>> Logistic regression model: " + mdl);
@@ -87,13 +89,13 @@ public class LogisticRegressionSGDTrainerExample {
// Build confusion matrix. See https://en.wikipedia.org/wiki/Confusion_matrix
int[][] confusionMtx = {{0, 0}, {0, 0}};
- try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) {
- for (Cache.Entry<Integer, double[]> observation : observations) {
- double[] val = observation.getValue();
- double[] inputs = Arrays.copyOfRange(val, 1, val.length);
- double groundTruth = val[0];
+ try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) {
+ for (Cache.Entry<Integer, Vector> observation : observations) {
+ Vector val = observation.getValue();
+ Vector inputs = val.copyOfRange(1, val.size());
+ double groundTruth = val.get(0);
- double prediction = mdl.apply(new DenseVector(inputs));
+ double prediction = mdl.apply(inputs);
totalAmount++;
if(groundTruth != prediction)
@@ -119,108 +121,4 @@ public class LogisticRegressionSGDTrainerExample {
System.out.println(">>> Logistic regression model over partitioned dataset usage example completed.");
}
}
-
- /** The 1st and 2nd classes from the Iris dataset. */
- private static final double[][] data = {
- {0, 5.1, 3.5, 1.4, 0.2},
- {0, 4.9, 3, 1.4, 0.2},
- {0, 4.7, 3.2, 1.3, 0.2},
- {0, 4.6, 3.1, 1.5, 0.2},
- {0, 5, 3.6, 1.4, 0.2},
- {0, 5.4, 3.9, 1.7, 0.4},
- {0, 4.6, 3.4, 1.4, 0.3},
- {0, 5, 3.4, 1.5, 0.2},
- {0, 4.4, 2.9, 1.4, 0.2},
- {0, 4.9, 3.1, 1.5, 0.1},
- {0, 5.4, 3.7, 1.5, 0.2},
- {0, 4.8, 3.4, 1.6, 0.2},
- {0, 4.8, 3, 1.4, 0.1},
- {0, 4.3, 3, 1.1, 0.1},
- {0, 5.8, 4, 1.2, 0.2},
- {0, 5.7, 4.4, 1.5, 0.4},
- {0, 5.4, 3.9, 1.3, 0.4},
- {0, 5.1, 3.5, 1.4, 0.3},
- {0, 5.7, 3.8, 1.7, 0.3},
- {0, 5.1, 3.8, 1.5, 0.3},
- {0, 5.4, 3.4, 1.7, 0.2},
- {0, 5.1, 3.7, 1.5, 0.4},
- {0, 4.6, 3.6, 1, 0.2},
- {0, 5.1, 3.3, 1.7, 0.5},
- {0, 4.8, 3.4, 1.9, 0.2},
- {0, 5, 3, 1.6, 0.2},
- {0, 5, 3.4, 1.6, 0.4},
- {0, 5.2, 3.5, 1.5, 0.2},
- {0, 5.2, 3.4, 1.4, 0.2},
- {0, 4.7, 3.2, 1.6, 0.2},
- {0, 4.8, 3.1, 1.6, 0.2},
- {0, 5.4, 3.4, 1.5, 0.4},
- {0, 5.2, 4.1, 1.5, 0.1},
- {0, 5.5, 4.2, 1.4, 0.2},
- {0, 4.9, 3.1, 1.5, 0.1},
- {0, 5, 3.2, 1.2, 0.2},
- {0, 5.5, 3.5, 1.3, 0.2},
- {0, 4.9, 3.1, 1.5, 0.1},
- {0, 4.4, 3, 1.3, 0.2},
- {0, 5.1, 3.4, 1.5, 0.2},
- {0, 5, 3.5, 1.3, 0.3},
- {0, 4.5, 2.3, 1.3, 0.3},
- {0, 4.4, 3.2, 1.3, 0.2},
- {0, 5, 3.5, 1.6, 0.6},
- {0, 5.1, 3.8, 1.9, 0.4},
- {0, 4.8, 3, 1.4, 0.3},
- {0, 5.1, 3.8, 1.6, 0.2},
- {0, 4.6, 3.2, 1.4, 0.2},
- {0, 5.3, 3.7, 1.5, 0.2},
- {0, 5, 3.3, 1.4, 0.2},
- {1, 7, 3.2, 4.7, 1.4},
- {1, 6.4, 3.2, 4.5, 1.5},
- {1, 6.9, 3.1, 4.9, 1.5},
- {1, 5.5, 2.3, 4, 1.3},
- {1, 6.5, 2.8, 4.6, 1.5},
- {1, 5.7, 2.8, 4.5, 1.3},
- {1, 6.3, 3.3, 4.7, 1.6},
- {1, 4.9, 2.4, 3.3, 1},
- {1, 6.6, 2.9, 4.6, 1.3},
- {1, 5.2, 2.7, 3.9, 1.4},
- {1, 5, 2, 3.5, 1},
- {1, 5.9, 3, 4.2, 1.5},
- {1, 6, 2.2, 4, 1},
- {1, 6.1, 2.9, 4.7, 1.4},
- {1, 5.6, 2.9, 3.6, 1.3},
- {1, 6.7, 3.1, 4.4, 1.4},
- {1, 5.6, 3, 4.5, 1.5},
- {1, 5.8, 2.7, 4.1, 1},
- {1, 6.2, 2.2, 4.5, 1.5},
- {1, 5.6, 2.5, 3.9, 1.1},
- {1, 5.9, 3.2, 4.8, 1.8},
- {1, 6.1, 2.8, 4, 1.3},
- {1, 6.3, 2.5, 4.9, 1.5},
- {1, 6.1, 2.8, 4.7, 1.2},
- {1, 6.4, 2.9, 4.3, 1.3},
- {1, 6.6, 3, 4.4, 1.4},
- {1, 6.8, 2.8, 4.8, 1.4},
- {1, 6.7, 3, 5, 1.7},
- {1, 6, 2.9, 4.5, 1.5},
- {1, 5.7, 2.6, 3.5, 1},
- {1, 5.5, 2.4, 3.8, 1.1},
- {1, 5.5, 2.4, 3.7, 1},
- {1, 5.8, 2.7, 3.9, 1.2},
- {1, 6, 2.7, 5.1, 1.6},
- {1, 5.4, 3, 4.5, 1.5},
- {1, 6, 3.4, 4.5, 1.6},
- {1, 6.7, 3.1, 4.7, 1.5},
- {1, 6.3, 2.3, 4.4, 1.3},
- {1, 5.6, 3, 4.1, 1.3},
- {1, 5.5, 2.5, 4, 1.3},
- {1, 5.5, 2.6, 4.4, 1.2},
- {1, 6.1, 3, 4.6, 1.4},
- {1, 5.8, 2.6, 4, 1.2},
- {1, 5, 2.3, 3.3, 1},
- {1, 5.6, 2.7, 4.2, 1.3},
- {1, 5.7, 3, 4.2, 1.2},
- {1, 5.7, 2.9, 4.2, 1.3},
- {1, 6.2, 2.9, 4.3, 1.3},
- {1, 5.1, 2.5, 3, 1.1},
- {1, 5.7, 2.8, 4.1, 1.3},
- };
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/370cd3e1/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/multiclass/LogRegressionMultiClassClassificationExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/multiclass/LogRegressionMultiClassClassificationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/multiclass/LogRegressionMultiClassClassificationExample.java
index ff2761a..962fdac 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/multiclass/LogRegressionMultiClassClassificationExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/multiclass/LogRegressionMultiClassClassificationExample.java
@@ -17,6 +17,7 @@
package org.apache.ignite.examples.ml.regression.logistic.multiclass;
+import java.io.FileNotFoundException;
import java.util.Arrays;
import javax.cache.Cache;
import org.apache.ignite.Ignite;
@@ -24,11 +25,10 @@ import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
import org.apache.ignite.cache.query.QueryCursor;
import org.apache.ignite.cache.query.ScanQuery;
-import org.apache.ignite.examples.ml.util.TestCache;
+import org.apache.ignite.examples.ml.util.MLSandboxDatasets;
+import org.apache.ignite.examples.ml.util.SandboxMLCache;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
-import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
-import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
import org.apache.ignite.ml.nn.UpdatesStrategy;
import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate;
import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;
@@ -54,14 +54,15 @@ import org.apache.ignite.ml.regressions.logistic.multiclass.LogRegressionMultiCl
*/
public class LogRegressionMultiClassClassificationExample {
/** Run example. */
- public static void main(String[] args) throws InterruptedException {
+ public static void main(String[] args) throws FileNotFoundException {
System.out.println();
System.out.println(">>> Logistic Regression Multi-class classification model over cached dataset usage example started.");
// Start ignite grid.
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Ignite grid started.");
- IgniteCache<Integer, Vector> dataCache = new TestCache(ignite).getVectors(data);
+ IgniteCache<Integer, Vector> dataCache = new SandboxMLCache(ignite)
+ .fillCacheWith(MLSandboxDatasets.GLASS_IDENTIFICATION);
LogRegressionMultiClassTrainer<?> trainer = new LogRegressionMultiClassTrainer<>()
.withUpdatesStgy(new UpdatesStrategy<>(
@@ -77,10 +78,7 @@ public class LogRegressionMultiClassClassificationExample {
LogRegressionMultiClassModel mdl = trainer.fit(
ignite,
dataCache,
- (k, v) -> {
- double[] arr = v.asArray();
- return VectorUtils.of(Arrays.copyOfRange(arr, 1, arr.length));
- },
+ (k, v) -> v.copyOfRange(1, v.size()),
(k, v) -> v.get(0)
);
@@ -92,10 +90,7 @@ public class LogRegressionMultiClassClassificationExample {
IgniteBiFunction<Integer, Vector, Vector> preprocessor = normalizationTrainer.fit(
ignite,
dataCache,
- (k, v) -> {
- double[] arr = v.asArray();
- return VectorUtils.of(Arrays.copyOfRange(arr, 1, arr.length));
- }
+ (k, v) -> v.copyOfRange(1, v.size())
);
LogRegressionMultiClassModel mdlWithNormalization = trainer.fit(
@@ -105,7 +100,7 @@ public class LogRegressionMultiClassClassificationExample {
(k, v) -> v.get(0)
);
- System.out.println(">>> Logistic Regression Multi-class model with minmaxscaling");
+ System.out.println(">>> Logistic Regression Multi-class model with normalization");
System.out.println(mdlWithNormalization.toString());
System.out.println(">>> ----------------------------------------------------------------");
@@ -122,12 +117,12 @@ public class LogRegressionMultiClassClassificationExample {
try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) {
for (Cache.Entry<Integer, Vector> observation : observations) {
- double[] val = observation.getValue().asArray();
- double[] inputs = Arrays.copyOfRange(val, 1, val.length);
- double groundTruth = val[0];
+ Vector val = observation.getValue();
+ Vector inputs = val.copyOfRange(1, val.size());
+ double groundTruth = val.get(0);
- double prediction = mdl.apply(new DenseVector(inputs));
- double predictionWithNormalization = mdlWithNormalization.apply(new DenseVector(inputs));
+ double prediction = mdl.apply(inputs);
+ double predictionWithNormalization = mdlWithNormalization.apply(inputs);
totalAmount++;
@@ -140,7 +135,7 @@ public class LogRegressionMultiClassClassificationExample {
confusionMtx[idx1][idx2]++;
- // Collect data for model with minmaxscaling
+ // Collect data for model with normalization
if(groundTruth != predictionWithNormalization)
amountOfErrorsWithNormalization++;
@@ -166,127 +161,4 @@ public class LogRegressionMultiClassClassificationExample {
}
}
}
-
- /** The preprocessed Glass dataset from the Machine Learning Repository https://archive.ics.uci.edu/ml/datasets/Glass+Identification
- * There are 3 classes with labels: 1 {building_windows_float_processed}, 3 {vehicle_windows_float_processed}, 7 {headlamps}.
- * Feature names: 'Na-Sodium', 'Mg-Magnesium', 'Al-Aluminum', 'Ba-Barium', 'Fe-Iron'.
- */
- private static final double[][] data = {
- {1, 1.52101, 4.49, 1.10, 0.00, 0.00},
- {1, 1.51761, 3.60, 1.36, 0.00, 0.00},
- {1, 1.51618, 3.55, 1.54, 0.00, 0.00},
- {1, 1.51766, 3.69, 1.29, 0.00, 0.00},
- {1, 1.51742, 3.62, 1.24, 0.00, 0.00},
- {1, 1.51596, 3.61, 1.62, 0.00, 0.26},
- {1, 1.51743, 3.60, 1.14, 0.00, 0.00},
- {1, 1.51756, 3.61, 1.05, 0.00, 0.00},
- {1, 1.51918, 3.58, 1.37, 0.00, 0.00},
- {1, 1.51755, 3.60, 1.36, 0.00, 0.11},
- {1, 1.51571, 3.46, 1.56, 0.00, 0.24},
- {1, 1.51763, 3.66, 1.27, 0.00, 0.00},
- {1, 1.51589, 3.43, 1.40, 0.00, 0.24},
- {1, 1.51748, 3.56, 1.27, 0.00, 0.17},
- {1, 1.51763, 3.59, 1.31, 0.00, 0.00},
- {1, 1.51761, 3.54, 1.23, 0.00, 0.00},
- {1, 1.51784, 3.67, 1.16, 0.00, 0.00},
- {1, 1.52196, 3.85, 0.89, 0.00, 0.00},
- {1, 1.51911, 3.73, 1.18, 0.00, 0.00},
- {1, 1.51735, 3.54, 1.69, 0.00, 0.07},
- {1, 1.51750, 3.55, 1.49, 0.00, 0.19},
- {1, 1.51966, 3.75, 0.29, 0.00, 0.00},
- {1, 1.51736, 3.62, 1.29, 0.00, 0.00},
- {1, 1.51751, 3.57, 1.35, 0.00, 0.00},
- {1, 1.51720, 3.50, 1.15, 0.00, 0.00},
- {1, 1.51764, 3.54, 1.21, 0.00, 0.00},
- {1, 1.51793, 3.48, 1.41, 0.00, 0.00},
- {1, 1.51721, 3.48, 1.33, 0.00, 0.00},
- {1, 1.51768, 3.52, 1.43, 0.00, 0.00},
- {1, 1.51784, 3.49, 1.28, 0.00, 0.00},
- {1, 1.51768, 3.56, 1.30, 0.00, 0.14},
- {1, 1.51747, 3.50, 1.14, 0.00, 0.00},
- {1, 1.51775, 3.48, 1.23, 0.09, 0.22},
- {1, 1.51753, 3.47, 1.38, 0.00, 0.06},
- {1, 1.51783, 3.54, 1.34, 0.00, 0.00},
- {1, 1.51567, 3.45, 1.21, 0.00, 0.00},
- {1, 1.51909, 3.53, 1.32, 0.11, 0.00},
- {1, 1.51797, 3.48, 1.35, 0.00, 0.00},
- {1, 1.52213, 3.82, 0.47, 0.00, 0.00},
- {1, 1.52213, 3.82, 0.47, 0.00, 0.00},
- {1, 1.51793, 3.50, 1.12, 0.00, 0.00},
- {1, 1.51755, 3.42, 1.20, 0.00, 0.00},
- {1, 1.51779, 3.39, 1.33, 0.00, 0.00},
- {1, 1.52210, 3.84, 0.72, 0.00, 0.00},
- {1, 1.51786, 3.43, 1.19, 0.00, 0.30},
- {1, 1.51900, 3.48, 1.35, 0.00, 0.00},
- {1, 1.51869, 3.37, 1.18, 0.00, 0.16},
- {1, 1.52667, 3.70, 0.71, 0.00, 0.10},
- {1, 1.52223, 3.77, 0.79, 0.00, 0.00},
- {1, 1.51898, 3.35, 1.23, 0.00, 0.00},
- {1, 1.52320, 3.72, 0.51, 0.00, 0.16},
- {1, 1.51926, 3.33, 1.28, 0.00, 0.11},
- {1, 1.51808, 2.87, 1.19, 0.00, 0.00},
- {1, 1.51837, 2.84, 1.28, 0.00, 0.00},
- {1, 1.51778, 2.81, 1.29, 0.00, 0.09},
- {1, 1.51769, 2.71, 1.29, 0.00, 0.24},
- {1, 1.51215, 3.47, 1.12, 0.00, 0.31},
- {1, 1.51824, 3.48, 1.29, 0.00, 0.00},
- {1, 1.51754, 3.74, 1.17, 0.00, 0.00},
- {1, 1.51754, 3.66, 1.19, 0.00, 0.11},
- {1, 1.51905, 3.62, 1.11, 0.00, 0.00},
- {1, 1.51977, 3.58, 1.32, 0.69, 0.00},
- {1, 1.52172, 3.86, 0.88, 0.00, 0.11},
- {1, 1.52227, 3.81, 0.78, 0.00, 0.00},
- {1, 1.52172, 3.74, 0.90, 0.00, 0.07},
- {1, 1.52099, 3.59, 1.12, 0.00, 0.00},
- {1, 1.52152, 3.65, 0.87, 0.00, 0.17},
- {1, 1.52152, 3.65, 0.87, 0.00, 0.17},
- {1, 1.52152, 3.58, 0.90, 0.00, 0.16},
- {1, 1.52300, 3.58, 0.82, 0.00, 0.03},
- {3, 1.51769, 3.66, 1.11, 0.00, 0.00},
- {3, 1.51610, 3.53, 1.34, 0.00, 0.00},
- {3, 1.51670, 3.57, 1.38, 0.00, 0.10},
- {3, 1.51643, 3.52, 1.35, 0.00, 0.00},
- {3, 1.51665, 3.45, 1.76, 0.00, 0.17},
- {3, 1.52127, 3.90, 0.83, 0.00, 0.00},
- {3, 1.51779, 3.65, 0.65, 0.00, 0.00},
- {3, 1.51610, 3.40, 1.22, 0.00, 0.00},
- {3, 1.51694, 3.58, 1.31, 0.00, 0.00},
- {3, 1.51646, 3.40, 1.26, 0.00, 0.00},
- {3, 1.51655, 3.39, 1.28, 0.00, 0.00},
- {3, 1.52121, 3.76, 0.58, 0.00, 0.00},
- {3, 1.51776, 3.41, 1.52, 0.00, 0.00},
- {3, 1.51796, 3.36, 1.63, 0.00, 0.09},
- {3, 1.51832, 3.34, 1.54, 0.00, 0.00},
- {3, 1.51934, 3.54, 0.75, 0.15, 0.24},
- {3, 1.52211, 3.78, 0.91, 0.00, 0.37},
- {7, 1.51131, 3.20, 1.81, 1.19, 0.00},
- {7, 1.51838, 3.26, 2.22, 1.63, 0.00},
- {7, 1.52315, 3.34, 1.23, 0.00, 0.00},
- {7, 1.52247, 2.20, 2.06, 0.00, 0.00},
- {7, 1.52365, 1.83, 1.31, 1.68, 0.00},
- {7, 1.51613, 1.78, 1.79, 0.76, 0.00},
- {7, 1.51602, 0.00, 2.38, 0.64, 0.09},
- {7, 1.51623, 0.00, 2.79, 0.40, 0.09},
- {7, 1.51719, 0.00, 2.00, 1.59, 0.08},
- {7, 1.51683, 0.00, 1.98, 1.57, 0.07},
- {7, 1.51545, 0.00, 2.68, 0.61, 0.05},
- {7, 1.51556, 0.00, 2.54, 0.81, 0.01},
- {7, 1.51727, 0.00, 2.34, 0.66, 0.00},
- {7, 1.51531, 0.00, 2.66, 0.64, 0.00},
- {7, 1.51609, 0.00, 2.51, 0.53, 0.00},
- {7, 1.51508, 0.00, 2.25, 0.63, 0.00},
- {7, 1.51653, 0.00, 1.19, 0.00, 0.00},
- {7, 1.51514, 0.00, 2.42, 0.56, 0.00},
- {7, 1.51658, 0.00, 1.99, 1.71, 0.00},
- {7, 1.51617, 0.00, 2.27, 0.67, 0.00},
- {7, 1.51732, 0.00, 1.80, 1.55, 0.00},
- {7, 1.51645, 0.00, 1.87, 1.38, 0.00},
- {7, 1.51831, 0.00, 1.82, 2.88, 0.00},
- {7, 1.51640, 0.00, 2.74, 0.54, 0.00},
- {7, 1.51623, 0.00, 2.88, 1.06, 0.00},
- {7, 1.51685, 0.00, 1.99, 1.59, 0.00},
- {7, 1.52065, 0.00, 2.02, 1.64, 0.00},
- {7, 1.51651, 0.00, 1.94, 1.57, 0.00},
- {7, 1.51711, 0.00, 2.08, 1.67, 0.00},
- };
}
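Note: the other change repeated throughout these hunks is that feature and label extraction now works on Vector directly instead of unwrapping it into double[] and re-wrapping it. A side-by-side sketch of the two styles follows; the "old" lambda is reconstructed from the removed lines, so treat both as illustrative rather than exact copies of the final code.

import java.util.Arrays;

import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;

/** Illustrative only: old vs. new extractor style used in the refactored examples. */
public class ExtractorStyleSketch {
    /** Old style: copy the Vector into an array, slice it, then wrap it back into a Vector. */
    static final IgniteBiFunction<Integer, Vector, Vector> OLD_FEATURE_EXTRACTOR = (k, v) -> {
        double[] arr = v.asArray();
        return VectorUtils.of(Arrays.copyOfRange(arr, 1, arr.length));
    };

    /** New style: slice the features directly with Vector.copyOfRange(from, to). */
    static final IgniteBiFunction<Integer, Vector, Vector> FEATURE_EXTRACTOR =
        (k, v) -> v.copyOfRange(1, v.size());

    /** Label extractor: the first column of each row is the class label. */
    static final IgniteBiFunction<Integer, Vector, Double> LB_EXTRACTOR = (k, v) -> v.get(0);
}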
http://git-wip-us.apache.org/repos/asf/ignite/blob/370cd3e1/examples/src/main/java/org/apache/ignite/examples/ml/selection/cv/CrossValidationExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/selection/cv/CrossValidationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/selection/cv/CrossValidationExample.java
index 25ce156..552bcd2 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/selection/cv/CrossValidationExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/selection/cv/CrossValidationExample.java
@@ -45,7 +45,7 @@ public class CrossValidationExample {
*
* @param args Command line arguments, none required.
*/
- public static void main(String... args) throws InterruptedException {
+ public static void main(String... args) {
System.out.println(">>> Cross validation score calculator example started.");
// Start ignite grid.
http://git-wip-us.apache.org/repos/asf/ignite/blob/370cd3e1/examples/src/main/java/org/apache/ignite/examples/ml/selection/split/TrainTestDatasetSplitterExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/selection/split/TrainTestDatasetSplitterExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/selection/split/TrainTestDatasetSplitterExample.java
index 8b104f5..4bfd993 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/selection/split/TrainTestDatasetSplitterExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/selection/split/TrainTestDatasetSplitterExample.java
@@ -17,16 +17,16 @@
package org.apache.ignite.examples.ml.selection.split;
-import java.util.Arrays;
+import java.io.FileNotFoundException;
import javax.cache.Cache;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
import org.apache.ignite.cache.query.QueryCursor;
import org.apache.ignite.cache.query.ScanQuery;
-import org.apache.ignite.examples.ml.util.TestCache;
-import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
-import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
+import org.apache.ignite.examples.ml.util.MLSandboxDatasets;
+import org.apache.ignite.examples.ml.util.SandboxMLCache;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainer;
import org.apache.ignite.ml.regressions.linear.LinearRegressionModel;
import org.apache.ignite.ml.selection.split.TrainTestDatasetSplitter;
@@ -47,78 +47,22 @@ import org.apache.ignite.ml.selection.split.TrainTestSplit;
* further.</p>
*/
public class TrainTestDatasetSplitterExample {
- /** */
- private static final double[][] data = {
- {8, 78, 284, 9.100000381, 109},
- {9.300000191, 68, 433, 8.699999809, 144},
- {7.5, 70, 739, 7.199999809, 113},
- {8.899999619, 96, 1792, 8.899999619, 97},
- {10.19999981, 74, 477, 8.300000191, 206},
- {8.300000191, 111, 362, 10.89999962, 124},
- {8.800000191, 77, 671, 10, 152},
- {8.800000191, 168, 636, 9.100000381, 162},
- {10.69999981, 82, 329, 8.699999809, 150},
- {11.69999981, 89, 634, 7.599999905, 134},
- {8.5, 149, 631, 10.80000019, 292},
- {8.300000191, 60, 257, 9.5, 108},
- {8.199999809, 96, 284, 8.800000191, 111},
- {7.900000095, 83, 603, 9.5, 182},
- {10.30000019, 130, 686, 8.699999809, 129},
- {7.400000095, 145, 345, 11.19999981, 158},
- {9.600000381, 112, 1357, 9.699999809, 186},
- {9.300000191, 131, 544, 9.600000381, 177},
- {10.60000038, 80, 205, 9.100000381, 127},
- {9.699999809, 130, 1264, 9.199999809, 179},
- {11.60000038, 140, 688, 8.300000191, 80},
- {8.100000381, 154, 354, 8.399999619, 103},
- {9.800000191, 118, 1632, 9.399999619, 101},
- {7.400000095, 94, 348, 9.800000191, 117},
- {9.399999619, 119, 370, 10.39999962, 88},
- {11.19999981, 153, 648, 9.899999619, 78},
- {9.100000381, 116, 366, 9.199999809, 102},
- {10.5, 97, 540, 10.30000019, 95},
- {11.89999962, 176, 680, 8.899999619, 80},
- {8.399999619, 75, 345, 9.600000381, 92},
- {5, 134, 525, 10.30000019, 126},
- {9.800000191, 161, 870, 10.39999962, 108},
- {9.800000191, 111, 669, 9.699999809, 77},
- {10.80000019, 114, 452, 9.600000381, 60},
- {10.10000038, 142, 430, 10.69999981, 71},
- {10.89999962, 238, 822, 10.30000019, 86},
- {9.199999809, 78, 190, 10.69999981, 93},
- {8.300000191, 196, 867, 9.600000381, 106},
- {7.300000191, 125, 969, 10.5, 162},
- {9.399999619, 82, 499, 7.699999809, 95},
- {9.399999619, 125, 925, 10.19999981, 91},
- {9.800000191, 129, 353, 9.899999619, 52},
- {3.599999905, 84, 288, 8.399999619, 110},
- {8.399999619, 183, 718, 10.39999962, 69},
- {10.80000019, 119, 540, 9.199999809, 57},
- {10.10000038, 180, 668, 13, 106},
- {9, 82, 347, 8.800000191, 40},
- {10, 71, 345, 9.199999809, 50},
- {11.30000019, 118, 463, 7.800000191, 35},
- {11.30000019, 121, 728, 8.199999809, 86},
- {12.80000019, 68, 383, 7.400000095, 57},
- {10, 112, 316, 10.39999962, 57},
- {6.699999809, 109, 388, 8.899999619, 94}
- };
-
/** Run example. */
- public static void main(String[] args) throws InterruptedException {
+ public static void main(String[] args) throws FileNotFoundException {
System.out.println();
System.out.println(">>> Linear regression model over cache based dataset usage example started.");
// Start ignite grid.
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Ignite grid started.");
- IgniteCache<Integer, double[]> dataCache = new TestCache(ignite).fillCacheWith(data);
+ IgniteCache<Integer, Vector> dataCache = new SandboxMLCache(ignite)
+ .fillCacheWith(MLSandboxDatasets.MORTALITY_DATA);
System.out.println(">>> Create new linear regression trainer object.");
LinearRegressionLSQRTrainer trainer = new LinearRegressionLSQRTrainer();
System.out.println(">>> Create new training dataset splitter object.");
- TrainTestSplit<Integer, double[]> split = new TrainTestDatasetSplitter<Integer, double[]>()
+ TrainTestSplit<Integer, Vector> split = new TrainTestDatasetSplitter<Integer, Vector>()
.split(0.75);
System.out.println(">>> Perform the training to get the model.");
@@ -126,8 +70,8 @@ public class TrainTestDatasetSplitterExample {
ignite,
dataCache,
split.getTrainFilter(),
- (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
- (k, v) -> v[0]
+ (k, v) -> v.copyOfRange(1, v.size()),
+ (k, v) -> v.get(0)
);
System.out.println(">>> Linear regression model: " + mdl);
@@ -136,16 +80,16 @@ public class TrainTestDatasetSplitterExample {
System.out.println(">>> | Prediction\t| Ground Truth\t|");
System.out.println(">>> ---------------------------------");
- ScanQuery<Integer, double[]> qry = new ScanQuery<>();
+ ScanQuery<Integer, Vector> qry = new ScanQuery<>();
qry.setFilter(split.getTestFilter());
- try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(qry)) {
- for (Cache.Entry<Integer, double[]> observation : observations) {
- double[] val = observation.getValue();
- double[] inputs = Arrays.copyOfRange(val, 1, val.length);
- double groundTruth = val[0];
+ try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(qry)) {
+ for (Cache.Entry<Integer, Vector> observation : observations) {
+ Vector val = observation.getValue();
+ Vector inputs = val.copyOfRange(1, val.size());
+ double groundTruth = val.get(0);
- double prediction = mdl.apply(new DenseVector(inputs));
+ double prediction = mdl.apply(inputs);
System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
}