You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@ignite.apache.org by ch...@apache.org on 2018/04/13 09:50:07 UTC
ignite git commit: IGNITE-7829: Adopt kNN regression example to the
new Partitioned Dataset
Repository: ignite
Updated Branches:
refs/heads/master 9be3357c4 -> 8550d61b6
IGNITE-7829: Adopt kNN regression example to the new Partitioned Dataset
this closes #3798
Project: http://git-wip-us.apache.org/repos/asf/ignite/repo
Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/8550d61b
Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/8550d61b
Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/8550d61b
Branch: refs/heads/master
Commit: 8550d61b6b39625579eb7f69f4d1218b78f7cc5b
Parents: 9be3357
Author: zaleslaw <za...@gmail.com>
Authored: Fri Apr 13 12:49:56 2018 +0300
Committer: YuriBabak <y....@gmail.com>
Committed: Fri Apr 13 12:49:56 2018 +0300
----------------------------------------------------------------------
.../ml/knn/KNNClassificationExample.java | 4 +-
.../examples/ml/knn/KNNRegressionExample.java | 310 +++++++++++++++++++
.../java/org/apache/ignite/ml/knn/KNNUtils.java | 10 +-
.../classification/KNNClassificationModel.java | 9 +-
.../ml/knn/partitions/KNNPartitionContext.java | 28 --
.../ignite/ml/knn/partitions/package-info.java | 22 --
.../ml/knn/regression/KNNRegressionModel.java | 7 +-
.../partition/LabelPartitionContext.java | 28 --
.../LabelPartitionDataBuilderOnHeap.java | 1 -
.../svm/SVMLinearBinaryClassificationModel.java | 3 +
.../SVMLinearBinaryClassificationTrainer.java | 9 +-
.../SVMLinearMultiClassClassificationModel.java | 3 +
...VMLinearMultiClassClassificationTrainer.java | 8 +-
.../ignite/ml/svm/SVMPartitionContext.java | 28 --
.../org/apache/ignite/ml/knn/BaseKNNTest.java | 89 ------
.../ignite/ml/knn/KNNClassificationTest.java | 110 +++----
.../apache/ignite/ml/knn/KNNRegressionTest.java | 104 +++----
.../ignite/ml/knn/LabeledDatasetHelper.java | 87 ++++++
.../ignite/ml/knn/LabeledDatasetTest.java | 2 +-
19 files changed, 536 insertions(+), 326 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/ignite/blob/8550d61b/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNClassificationExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNClassificationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNClassificationExample.java
index 39a8431..15375a1 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNClassificationExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNClassificationExample.java
@@ -80,7 +80,7 @@ public class KNNClassificationExample {
double prediction = knnMdl.apply(new DenseLocalOnHeapVector(inputs));
totalAmount++;
- if(groundTruth != prediction)
+ if (groundTruth != prediction)
amountOfErrors++;
System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
@@ -89,7 +89,7 @@ public class KNNClassificationExample {
System.out.println(">>> ---------------------------------");
System.out.println("\n>>> Absolute amount of errors " + amountOfErrors);
- System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double)totalAmount));
+ System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double) totalAmount));
}
});
http://git-wip-us.apache.org/repos/asf/ignite/blob/8550d61b/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNRegressionExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNRegressionExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNRegressionExample.java
new file mode 100644
index 0000000..76a07cd
--- /dev/null
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNRegressionExample.java
@@ -0,0 +1,310 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.examples.ml.knn;
+
+import java.util.Arrays;
+import java.util.UUID;
+import javax.cache.Cache;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.Ignition;
+import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
+import org.apache.ignite.cache.query.QueryCursor;
+import org.apache.ignite.cache.query.ScanQuery;
+import org.apache.ignite.configuration.CacheConfiguration;
+import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
+import org.apache.ignite.ml.knn.classification.KNNClassificationTrainer;
+import org.apache.ignite.ml.knn.classification.KNNStrategy;
+import org.apache.ignite.ml.knn.regression.KNNRegressionModel;
+import org.apache.ignite.ml.knn.regression.KNNRegressionTrainer;
+import org.apache.ignite.ml.math.distances.ManhattanDistance;
+import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
+import org.apache.ignite.thread.IgniteThread;
+
+/**
+ * Run kNN regression trainer over distributed dataset.
+ *
+ * @see KNNRegressionTrainer
+ */
+public class KNNRegressionExample {
+ /** Run example. */
+ public static void main(String[] args) throws InterruptedException {
+ System.out.println();
+ System.out.println(">>> kNN regression over cached dataset usage example started.");
+ // Start ignite grid.
+ try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
+ System.out.println(">>> Ignite grid started.");
+
+ IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
+ KNNRegressionExample.class.getSimpleName(), () -> {
+ IgniteCache<Integer, double[]> dataCache = getTestCache(ignite);
+
+ KNNRegressionTrainer trainer = new KNNRegressionTrainer();
+
+ KNNRegressionModel knnMdl = (KNNRegressionModel) trainer.fit(
+ new CacheBasedDatasetBuilder<>(ignite, dataCache),
+ (k, v) -> Arrays.copyOfRange(v, 1, v.length),
+ (k, v) -> v[0]
+ ).withK(5)
+ .withDistanceMeasure(new ManhattanDistance())
+ .withStrategy(KNNStrategy.WEIGHTED);
+
+ int totalAmount = 0;
+ // Calculate mean squared error (MSE)
+ double mse = 0.0;
+ // Calculate mean absolute error (MAE)
+ double mae = 0.0;
+
+ try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) {
+ for (Cache.Entry<Integer, double[]> observation : observations) {
+ double[] val = observation.getValue();
+ double[] inputs = Arrays.copyOfRange(val, 1, val.length);
+ double groundTruth = val[0];
+
+ double prediction = knnMdl.apply(new DenseLocalOnHeapVector(inputs));
+
+ mse += Math.pow(prediction - groundTruth, 2.0);
+ mae += Math.abs(prediction - groundTruth);
+
+ totalAmount++;
+ }
+
+ mse = mse / totalAmount;
+ System.out.println("\n>>> Mean squared error (MSE) " + mse);
+
+ mae = mae / totalAmount;
+ System.out.println("\n>>> Mean absolute error (MAE) " + mae);
+ }
+ });
+
+ igniteThread.start();
+ igniteThread.join();
+ }
+ }
+
+ /**
+ * Fills cache with data and returns it.
+ *
+ * @param ignite Ignite instance.
+ * @return Filled Ignite Cache.
+ */
+ private static IgniteCache<Integer, double[]> getTestCache(Ignite ignite) {
+ CacheConfiguration<Integer, double[]> cacheConfiguration = new CacheConfiguration<>();
+ cacheConfiguration.setName("TEST_" + UUID.randomUUID());
+ cacheConfiguration.setAffinity(new RendezvousAffinityFunction(false, 10));
+
+ IgniteCache<Integer, double[]> cache = ignite.createCache(cacheConfiguration);
+
+ for (int i = 0; i < data.length; i++)
+ cache.put(i, data[i]);
+
+ return cache;
+ }
+
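+ /** The dataset: the first value in each row is the label (regression target), the remaining six values are the features. */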
+ private static final double[][] data = {
+ {199, 125, 256, 6000, 256, 16, 128},
+ {253, 29, 8000, 32000, 32, 8, 32},
+ {132, 29, 8000, 16000, 32, 8, 16},
+ {290, 26, 8000, 32000, 64, 8, 32},
+ {381, 23, 16000, 32000, 64, 16, 32},
+ {749, 23, 16000, 64000, 64, 16, 32},
+ {1238, 23, 32000, 64000, 128, 32, 64},
+ {23, 400, 1000, 3000, 0, 1, 2},
+ {24, 400, 512, 3500, 4, 1, 6},
+ {70, 60, 2000, 8000, 65, 1, 8},
+ {117, 50, 4000, 16000, 65, 1, 8},
+ {15, 350, 64, 64, 0, 1, 4},
+ {64, 200, 512, 16000, 0, 4, 32},
+ {23, 167, 524, 2000, 8, 4, 15},
+ {29, 143, 512, 5000, 0, 7, 32},
+ {22, 143, 1000, 2000, 0, 5, 16},
+ {124, 110, 5000, 5000, 142, 8, 64},
+ {35, 143, 1500, 6300, 0, 5, 32},
+ {39, 143, 3100, 6200, 0, 5, 20},
+ {40, 143, 2300, 6200, 0, 6, 64},
+ {45, 110, 3100, 6200, 0, 6, 64},
+ {28, 320, 128, 6000, 0, 1, 12},
+ {21, 320, 512, 2000, 4, 1, 3},
+ {28, 320, 256, 6000, 0, 1, 6},
+ {22, 320, 256, 3000, 4, 1, 3},
+ {28, 320, 512, 5000, 4, 1, 5},
+ {27, 320, 256, 5000, 4, 1, 6},
+ {102, 25, 1310, 2620, 131, 12, 24},
+ {74, 50, 2620, 10480, 30, 12, 24},
+ {138, 56, 5240, 20970, 30, 12, 24},
+ {136, 64, 5240, 20970, 30, 12, 24},
+ {23, 50, 500, 2000, 8, 1, 4},
+ {29, 50, 1000, 4000, 8, 1, 5},
+ {44, 50, 2000, 8000, 8, 1, 5},
+ {30, 50, 1000, 4000, 8, 3, 5},
+ {41, 50, 1000, 8000, 8, 3, 5},
+ {74, 50, 2000, 16000, 8, 3, 5},
+ {54, 133, 1000, 12000, 9, 3, 12},
+ {41, 133, 1000, 8000, 9, 3, 12},
+ {18, 810, 512, 512, 8, 1, 1},
+ {28, 810, 1000, 5000, 0, 1, 1},
+ {36, 320, 512, 8000, 4, 1, 5},
+ {38, 200, 512, 8000, 8, 1, 8},
+ {34, 700, 384, 8000, 0, 1, 1},
+ {19, 700, 256, 2000, 0, 1, 1},
+ {72, 140, 1000, 16000, 16, 1, 3},
+ {36, 200, 1000, 8000, 0, 1, 2},
+ {30, 110, 1000, 4000, 16, 1, 2},
+ {56, 110, 1000, 12000, 16, 1, 2},
+ {42, 220, 1000, 8000, 16, 1, 2},
+ {34, 800, 256, 8000, 0, 1, 4},
+ {19, 125, 512, 1000, 0, 8, 20},
+ {75, 75, 2000, 8000, 64, 1, 38},
+ {113, 75, 2000, 16000, 64, 1, 38},
+ {157, 75, 2000, 16000, 128, 1, 38},
+ {18, 90, 256, 1000, 0, 3, 10},
+ {20, 105, 256, 2000, 0, 3, 10},
+ {28, 105, 1000, 4000, 0, 3, 24},
+ {33, 105, 2000, 4000, 8, 3, 19},
+ {47, 75, 2000, 8000, 8, 3, 24},
+ {54, 75, 3000, 8000, 8, 3, 48},
+ {20, 175, 256, 2000, 0, 3, 24},
+ {23, 300, 768, 3000, 0, 6, 24},
+ {25, 300, 768, 3000, 6, 6, 24},
+ {52, 300, 768, 12000, 6, 6, 24},
+ {27, 300, 768, 4500, 0, 1, 24},
+ {50, 300, 384, 12000, 6, 1, 24},
+ {18, 300, 192, 768, 6, 6, 24},
+ {53, 180, 768, 12000, 6, 1, 31},
+ {23, 330, 1000, 3000, 0, 2, 4},
+ {30, 300, 1000, 4000, 8, 3, 64},
+ {73, 300, 1000, 16000, 8, 2, 112},
+ {20, 330, 1000, 2000, 0, 1, 2},
+ {25, 330, 1000, 4000, 0, 3, 6},
+ {28, 140, 2000, 4000, 0, 3, 6},
+ {29, 140, 2000, 4000, 0, 4, 8},
+ {32, 140, 2000, 4000, 8, 1, 20},
+ {175, 140, 2000, 32000, 32, 1, 20},
+ {57, 140, 2000, 8000, 32, 1, 54},
+ {181, 140, 2000, 32000, 32, 1, 54},
+ {32, 140, 2000, 4000, 8, 1, 20},
+ {82, 57, 4000, 16000, 1, 6, 12},
+ {171, 57, 4000, 24000, 64, 12, 16},
+ {361, 26, 16000, 32000, 64, 16, 24},
+ {350, 26, 16000, 32000, 64, 8, 24},
+ {220, 26, 8000, 32000, 0, 8, 24},
+ {113, 26, 8000, 16000, 0, 8, 16},
+ {15, 480, 96, 512, 0, 1, 1},
+ {21, 203, 1000, 2000, 0, 1, 5},
+ {35, 115, 512, 6000, 16, 1, 6},
+ {18, 1100, 512, 1500, 0, 1, 1},
+ {20, 1100, 768, 2000, 0, 1, 1},
+ {20, 600, 768, 2000, 0, 1, 1},
+ {28, 400, 2000, 4000, 0, 1, 1},
+ {45, 400, 4000, 8000, 0, 1, 1},
+ {18, 900, 1000, 1000, 0, 1, 2},
+ {17, 900, 512, 1000, 0, 1, 2},
+ {26, 900, 1000, 4000, 4, 1, 2},
+ {28, 900, 1000, 4000, 8, 1, 2},
+ {28, 900, 2000, 4000, 0, 3, 6},
+ {31, 225, 2000, 4000, 8, 3, 6},
+ {42, 180, 2000, 8000, 8, 1, 6},
+ {76, 185, 2000, 16000, 16, 1, 6},
+ {76, 180, 2000, 16000, 16, 1, 6},
+ {26, 225, 1000, 4000, 2, 3, 6},
+ {59, 25, 2000, 12000, 8, 1, 4},
+ {65, 25, 2000, 12000, 16, 3, 5},
+ {101, 17, 4000, 16000, 8, 6, 12},
+ {116, 17, 4000, 16000, 32, 6, 12},
+ {18, 1500, 768, 1000, 0, 0, 0},
+ {20, 1500, 768, 2000, 0, 0, 0},
+ {20, 800, 768, 2000, 0, 0, 0},
+ {30, 50, 2000, 4000, 0, 3, 6},
+ {44, 50, 2000, 8000, 8, 3, 6},
+ {82, 50, 2000, 16000, 24, 1, 6},
+ {128, 50, 8000, 16000, 48, 1, 10},
+ {37, 100, 1000, 8000, 0, 2, 6},
+ {46, 100, 1000, 8000, 24, 2, 6},
+ {46, 100, 1000, 8000, 24, 3, 6},
+ {80, 50, 2000, 16000, 12, 3, 16},
+ {88, 50, 2000, 16000, 24, 6, 16},
+ {33, 150, 512, 4000, 0, 8, 128},
+ {46, 115, 2000, 8000, 16, 1, 3},
+ {29, 115, 2000, 4000, 2, 1, 5},
+ {53, 92, 2000, 8000, 32, 1, 6},
+ {41, 92, 2000, 8000, 4, 1, 6},
+ {86, 75, 4000, 16000, 16, 1, 6},
+ {95, 60, 4000, 16000, 32, 1, 6},
+ {107, 60, 2000, 16000, 64, 5, 8},
+ {117, 60, 4000, 16000, 64, 5, 8},
+ {119, 50, 4000, 16000, 64, 5, 10},
+ {120, 72, 4000, 16000, 64, 8, 16},
+ {48, 72, 2000, 8000, 16, 6, 8},
+ {126, 40, 8000, 16000, 32, 8, 16},
+ {266, 40, 8000, 32000, 64, 8, 24},
+ {270, 35, 8000, 32000, 64, 8, 24},
+ {426, 38, 16000, 32000, 128, 16, 32},
+ {151, 48, 4000, 24000, 32, 8, 24},
+ {267, 38, 8000, 32000, 64, 8, 24},
+ {603, 30, 16000, 32000, 256, 16, 24},
+ {19, 112, 1000, 1000, 0, 1, 4},
+ {21, 84, 1000, 2000, 0, 1, 6},
+ {26, 56, 1000, 4000, 0, 1, 6},
+ {35, 56, 2000, 6000, 0, 1, 8},
+ {41, 56, 2000, 8000, 0, 1, 8},
+ {47, 56, 4000, 8000, 0, 1, 8},
+ {62, 56, 4000, 12000, 0, 1, 8},
+ {78, 56, 4000, 16000, 0, 1, 8},
+ {80, 38, 4000, 8000, 32, 16, 32},
+ {142, 38, 8000, 16000, 64, 4, 8},
+ {281, 38, 8000, 24000, 160, 4, 8},
+ {190, 38, 4000, 16000, 128, 16, 32},
+ {21, 200, 1000, 2000, 0, 1, 2},
+ {25, 200, 1000, 4000, 0, 1, 4},
+ {67, 200, 2000, 8000, 64, 1, 5},
+ {24, 250, 512, 4000, 0, 1, 7},
+ {24, 250, 512, 4000, 0, 4, 7},
+ {64, 250, 1000, 16000, 1, 1, 8},
+ {25, 160, 512, 4000, 2, 1, 5},
+ {20, 160, 512, 2000, 2, 3, 8},
+ {29, 160, 1000, 4000, 8, 1, 14},
+ {43, 160, 1000, 8000, 16, 1, 14},
+ {53, 160, 2000, 8000, 32, 1, 13},
+ {19, 240, 512, 1000, 8, 1, 3},
+ {22, 240, 512, 2000, 8, 1, 5},
+ {31, 105, 2000, 4000, 8, 3, 8},
+ {41, 105, 2000, 6000, 16, 6, 16},
+ {47, 105, 2000, 8000, 16, 4, 14},
+ {99, 52, 4000, 16000, 32, 4, 12},
+ {67, 70, 4000, 12000, 8, 6, 8},
+ {81, 59, 4000, 12000, 32, 6, 12},
+ {149, 59, 8000, 16000, 64, 12, 24},
+ {183, 26, 8000, 24000, 32, 8, 16},
+ {275, 26, 8000, 32000, 64, 12, 16},
+ {382, 26, 8000, 32000, 128, 24, 32},
+ {56, 116, 2000, 8000, 32, 5, 28},
+ {182, 50, 2000, 32000, 24, 6, 26},
+ {227, 50, 2000, 32000, 48, 26, 52},
+ {341, 50, 2000, 32000, 112, 52, 104},
+ {360, 50, 4000, 32000, 112, 52, 104},
+ {919, 30, 8000, 64000, 96, 12, 176},
+ {978, 30, 8000, 64000, 128, 12, 176},
+ {24, 180, 262, 4000, 0, 1, 3},
+ {37, 124, 1000, 8000, 0, 1, 8},
+ {50, 98, 1000, 8000, 32, 2, 8},
+ {41, 125, 2000, 8000, 0, 2, 14},
+ {47, 480, 512, 8000, 32, 0, 0},
+ {25, 480, 1000, 4000, 0, 0, 0}
+ };
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/8550d61b/modules/ml/src/main/java/org/apache/ignite/ml/knn/KNNUtils.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/KNNUtils.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/KNNUtils.java
index 88fa70f..716eb52 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/knn/KNNUtils.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/knn/KNNUtils.java
@@ -20,7 +20,7 @@ package org.apache.ignite.ml.knn;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.PartitionDataBuilder;
-import org.apache.ignite.ml.knn.partitions.KNNPartitionContext;
+import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.structures.LabeledDataset;
import org.apache.ignite.ml.structures.LabeledVector;
@@ -39,18 +39,18 @@ public class KNNUtils {
* @param lbExtractor Label extractor.
* @return Dataset.
*/
- @Nullable public static <K, V> Dataset<KNNPartitionContext, LabeledDataset<Double, LabeledVector>> buildDataset(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
- PartitionDataBuilder<K, V, KNNPartitionContext, LabeledDataset<Double, LabeledVector>> partDataBuilder
+ @Nullable public static <K, V> Dataset<EmptyContext, LabeledDataset<Double, LabeledVector>> buildDataset(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
+ PartitionDataBuilder<K, V, EmptyContext, LabeledDataset<Double, LabeledVector>> partDataBuilder
= new LabeledDatasetPartitionDataBuilderOnHeap<>(
featureExtractor,
lbExtractor
);
- Dataset<KNNPartitionContext, LabeledDataset<Double, LabeledVector>> dataset = null;
+ Dataset<EmptyContext, LabeledDataset<Double, LabeledVector>> dataset = null;
if (datasetBuilder != null) {
dataset = datasetBuilder.build(
- (upstream, upstreamSize) -> new KNNPartitionContext(),
+ (upstream, upstreamSize) -> new EmptyContext(),
partDataBuilder
);
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/8550d61b/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationModel.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationModel.java
index 373f822..693b81d 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationModel.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationModel.java
@@ -32,7 +32,7 @@ import org.apache.ignite.ml.Exportable;
import org.apache.ignite.ml.Exporter;
import org.apache.ignite.ml.Model;
import org.apache.ignite.ml.dataset.Dataset;
-import org.apache.ignite.ml.knn.partitions.KNNPartitionContext;
+import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.math.Vector;
import org.apache.ignite.ml.math.distances.DistanceMeasure;
import org.apache.ignite.ml.math.distances.EuclideanDistance;
@@ -44,6 +44,9 @@ import org.jetbrains.annotations.NotNull;
* kNN algorithm model to solve multi-class classification task.
*/
public class KNNClassificationModel<K, V> implements Model<Vector, Double>, Exportable<KNNModelFormat> {
+ /** */
+ private static final long serialVersionUID = -127386523291350345L;
+
/** Amount of nearest neighbors. */
protected int k = 5;
@@ -54,13 +57,13 @@ public class KNNClassificationModel<K, V> implements Model<Vector, Double>, Expo
protected KNNStrategy stgy = KNNStrategy.SIMPLE;
/** Dataset. */
- private Dataset<KNNPartitionContext, LabeledDataset<Double, LabeledVector>> dataset;
+ private Dataset<EmptyContext, LabeledDataset<Double, LabeledVector>> dataset;
/**
* Builds the model via prepared dataset.
* @param dataset Specially prepared object to run algorithm over it.
*/
- public KNNClassificationModel(Dataset<KNNPartitionContext, LabeledDataset<Double, LabeledVector>> dataset) {
+ public KNNClassificationModel(Dataset<EmptyContext, LabeledDataset<Double, LabeledVector>> dataset) {
this.dataset = dataset;
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/8550d61b/modules/ml/src/main/java/org/apache/ignite/ml/knn/partitions/KNNPartitionContext.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/partitions/KNNPartitionContext.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/partitions/KNNPartitionContext.java
deleted file mode 100644
index 0081612..0000000
--- a/modules/ml/src/main/java/org/apache/ignite/ml/knn/partitions/KNNPartitionContext.java
+++ /dev/null
@@ -1,28 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.knn.partitions;
-
-import java.io.Serializable;
-
-/**
- * Partition context of the kNN classification algorithm.
- */
-public class KNNPartitionContext implements Serializable {
- /** */
- private static final long serialVersionUID = -7212307112344430126L;
-}
http://git-wip-us.apache.org/repos/asf/ignite/blob/8550d61b/modules/ml/src/main/java/org/apache/ignite/ml/knn/partitions/package-info.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/partitions/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/partitions/package-info.java
deleted file mode 100644
index 951a849..0000000
--- a/modules/ml/src/main/java/org/apache/ignite/ml/knn/partitions/package-info.java
+++ /dev/null
@@ -1,22 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-/**
- * <!-- Package description. -->
- * Contains helper classes for kNN classification algorithms.
- */
-package org.apache.ignite.ml.knn.partitions;
http://git-wip-us.apache.org/repos/asf/ignite/blob/8550d61b/modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/KNNRegressionModel.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/KNNRegressionModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/KNNRegressionModel.java
index cabc143..f5def43 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/KNNRegressionModel.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/KNNRegressionModel.java
@@ -17,8 +17,8 @@
package org.apache.ignite.ml.knn.regression;
import org.apache.ignite.ml.dataset.Dataset;
+import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.knn.classification.KNNClassificationModel;
-import org.apache.ignite.ml.knn.partitions.KNNPartitionContext;
import org.apache.ignite.ml.math.Vector;
import org.apache.ignite.ml.math.exceptions.UnsupportedOperationException;
import org.apache.ignite.ml.structures.LabeledDataset;
@@ -38,11 +38,14 @@ import java.util.List;
* </ul>
*/
public class KNNRegressionModel<K,V> extends KNNClassificationModel<K,V> {
+ /** */
+ private static final long serialVersionUID = -721836321291120543L;
+
/**
* Builds the model via prepared dataset.
* @param dataset Specially prepared object to run algorithm over it.
*/
- public KNNRegressionModel(Dataset<KNNPartitionContext, LabeledDataset<Double, LabeledVector>> dataset) {
+ public KNNRegressionModel(Dataset<EmptyContext, LabeledDataset<Double, LabeledVector>> dataset) {
super(dataset);
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/8550d61b/modules/ml/src/main/java/org/apache/ignite/ml/structures/partition/LabelPartitionContext.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/structures/partition/LabelPartitionContext.java b/modules/ml/src/main/java/org/apache/ignite/ml/structures/partition/LabelPartitionContext.java
deleted file mode 100644
index 1069ff8..0000000
--- a/modules/ml/src/main/java/org/apache/ignite/ml/structures/partition/LabelPartitionContext.java
+++ /dev/null
@@ -1,28 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.structures.partition;
-
-import java.io.Serializable;
-
-/**
- * Base partition context.
- */
-public class LabelPartitionContext implements Serializable {
- /** */
- private static final long serialVersionUID = -7412302212344430126L;
-}
http://git-wip-us.apache.org/repos/asf/ignite/blob/8550d61b/modules/ml/src/main/java/org/apache/ignite/ml/structures/partition/LabelPartitionDataBuilderOnHeap.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/structures/partition/LabelPartitionDataBuilderOnHeap.java b/modules/ml/src/main/java/org/apache/ignite/ml/structures/partition/LabelPartitionDataBuilderOnHeap.java
index 14c053e..4fba028 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/structures/partition/LabelPartitionDataBuilderOnHeap.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/structures/partition/LabelPartitionDataBuilderOnHeap.java
@@ -22,7 +22,6 @@ import java.util.Iterator;
import org.apache.ignite.ml.dataset.PartitionDataBuilder;
import org.apache.ignite.ml.dataset.UpstreamEntry;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
-import org.apache.ignite.ml.structures.LabeledDataset;
/**
* Partition data builder that builds {@link LabelPartitionDataOnHeap}.
http://git-wip-us.apache.org/repos/asf/ignite/blob/8550d61b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationModel.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationModel.java
index dace8c6..f806fb8 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationModel.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationModel.java
@@ -28,6 +28,9 @@ import org.apache.ignite.ml.math.Vector;
* Base class for SVM linear classification model.
*/
public class SVMLinearBinaryClassificationModel implements Model<Vector, Double>, Exportable<SVMLinearBinaryClassificationModel>, Serializable {
+ /** */
+ private static final long serialVersionUID = -996984622291440226L;
+
/** Output label format. -1 and +1 for false value and raw distances from the separating hyperplane otherwise. */
private boolean isKeepingRawLabels = false;
http://git-wip-us.apache.org/repos/asf/ignite/blob/8550d61b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java
index 7f11e20..d56848c 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java
@@ -18,6 +18,7 @@
package org.apache.ignite.ml.svm;
import java.util.concurrent.ThreadLocalRandom;
+import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.structures.partition.LabeledDatasetPartitionDataBuilderOnHeap;
import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;
import org.apache.ignite.ml.dataset.Dataset;
@@ -59,15 +60,15 @@ public class SVMLinearBinaryClassificationTrainer implements SingleLabelDatasetT
assert datasetBuilder != null;
- PartitionDataBuilder<K, V, SVMPartitionContext, LabeledDataset<Double, LabeledVector>> partDataBuilder = new LabeledDatasetPartitionDataBuilderOnHeap<>(
+ PartitionDataBuilder<K, V, EmptyContext, LabeledDataset<Double, LabeledVector>> partDataBuilder = new LabeledDatasetPartitionDataBuilderOnHeap<>(
featureExtractor,
lbExtractor
);
Vector weights;
- try(Dataset<SVMPartitionContext, LabeledDataset<Double, LabeledVector>> dataset = datasetBuilder.build(
- (upstream, upstreamSize) -> new SVMPartitionContext(),
+ try(Dataset<EmptyContext, LabeledDataset<Double, LabeledVector>> dataset = datasetBuilder.build(
+ (upstream, upstreamSize) -> new EmptyContext(),
partDataBuilder
)) {
final int cols = dataset.compute(data -> data.colSize(), (a, b) -> a == null ? b : a);
@@ -90,7 +91,7 @@ public class SVMLinearBinaryClassificationTrainer implements SingleLabelDatasetT
}
/** */
- private Vector calculateUpdates(Vector weights, Dataset<SVMPartitionContext, LabeledDataset<Double, LabeledVector>> dataset) {
+ private Vector calculateUpdates(Vector weights, Dataset<EmptyContext, LabeledDataset<Double, LabeledVector>> dataset) {
return dataset.compute(data -> {
Vector copiedWeights = weights.copy();
Vector deltaWeights = initializeWeightsWithZeros(weights.size());
http://git-wip-us.apache.org/repos/asf/ignite/blob/8550d61b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationModel.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationModel.java
index 5879ef0..bbec791 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationModel.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationModel.java
@@ -29,6 +29,9 @@ import org.apache.ignite.ml.math.Vector;
/** Base class for multi-classification model for set of SVM classifiers. */
public class SVMLinearMultiClassClassificationModel implements Model<Vector, Double>, Exportable<SVMLinearMultiClassClassificationModel>, Serializable {
+ /** */
+ private static final long serialVersionUID = -667986511191350227L;
+
/** List of models associated with each class. */
private Map<Double, SVMLinearBinaryClassificationModel> models;
http://git-wip-us.apache.org/repos/asf/ignite/blob/8550d61b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java
index 88c342d..4e081c6 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java
@@ -24,12 +24,12 @@ import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
+import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.PartitionDataBuilder;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
-import org.apache.ignite.ml.structures.partition.LabelPartitionContext;
import org.apache.ignite.ml.structures.partition.LabelPartitionDataBuilderOnHeap;
import org.apache.ignite.ml.structures.partition.LabelPartitionDataOnHeap;
@@ -89,12 +89,12 @@ public class SVMLinearMultiClassClassificationTrainer
private <K, V> List<Double> extractClassLabels(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Double> lbExtractor) {
assert datasetBuilder != null;
- PartitionDataBuilder<K, V, LabelPartitionContext, LabelPartitionDataOnHeap> partDataBuilder = new LabelPartitionDataBuilderOnHeap<>(lbExtractor);
+ PartitionDataBuilder<K, V, EmptyContext, LabelPartitionDataOnHeap> partDataBuilder = new LabelPartitionDataBuilderOnHeap<>(lbExtractor);
List<Double> res = new ArrayList<>();
- try (Dataset<LabelPartitionContext, LabelPartitionDataOnHeap> dataset = datasetBuilder.build(
- (upstream, upstreamSize) -> new LabelPartitionContext(),
+ try (Dataset<EmptyContext, LabelPartitionDataOnHeap> dataset = datasetBuilder.build(
+ (upstream, upstreamSize) -> new EmptyContext(),
partDataBuilder
)) {
final Set<Double> clsLabels = dataset.compute(data -> {
http://git-wip-us.apache.org/repos/asf/ignite/blob/8550d61b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMPartitionContext.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMPartitionContext.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMPartitionContext.java
deleted file mode 100644
index 0aee0fb..0000000
--- a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMPartitionContext.java
+++ /dev/null
@@ -1,28 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.svm;
-
-import java.io.Serializable;
-
-/**
- * Partition context of the SVM classification algorithm.
- */
-public class SVMPartitionContext implements Serializable {
- /** */
- private static final long serialVersionUID = -7212307112344430126L;
-}
http://git-wip-us.apache.org/repos/asf/ignite/blob/8550d61b/modules/ml/src/test/java/org/apache/ignite/ml/knn/BaseKNNTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/BaseKNNTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/BaseKNNTest.java
deleted file mode 100644
index aeac2cf..0000000
--- a/modules/ml/src/test/java/org/apache/ignite/ml/knn/BaseKNNTest.java
+++ /dev/null
@@ -1,89 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.knn;
-
-import java.io.IOException;
-import java.net.URISyntaxException;
-import java.nio.file.Path;
-import java.nio.file.Paths;
-import org.apache.ignite.Ignite;
-import org.apache.ignite.ml.structures.LabeledDataset;
-import org.apache.ignite.ml.structures.preprocessing.LabeledDatasetLoader;
-import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest;
-
-/**
- * Base class for decision trees test.
- */
-public class BaseKNNTest extends GridCommonAbstractTest {
- /** Count of nodes. */
- private static final int NODE_COUNT = 4;
-
- /** Separator. */
- private static final String SEPARATOR = "\t";
-
- /** Grid instance. */
- protected Ignite ignite;
-
- /**
- * Default constructor.
- */
- public BaseKNNTest() {
- super(false);
- }
-
- /**
- * {@inheritDoc}
- */
- @Override protected void beforeTest() throws Exception {
- ignite = grid(NODE_COUNT);
- }
-
- /** {@inheritDoc} */
- @Override protected void beforeTestsStarted() throws Exception {
- for (int i = 1; i <= NODE_COUNT; i++)
- startGrid(i);
- }
-
- /** {@inheritDoc} */
- @Override protected void afterTestsStopped() throws Exception {
- stopAllGrids();
- }
-
- /**
- * Loads labeled dataset from file with .txt extension.
- *
- * @param rsrcPath path to dataset.
- * @return null if path is incorrect.
- */
- LabeledDataset loadDatasetFromTxt(String rsrcPath, boolean isFallOnBadData) {
- try {
- Path path = Paths.get(this.getClass().getClassLoader().getResource(rsrcPath).toURI());
- try {
- return LabeledDatasetLoader.loadFromTxtFile(path, SEPARATOR, false, isFallOnBadData);
- }
- catch (IOException e) {
- e.printStackTrace();
- }
- }
- catch (URISyntaxException e) {
- e.printStackTrace();
- return null;
- }
- return null;
- }
-}
http://git-wip-us.apache.org/repos/asf/ignite/blob/8550d61b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java
index b27fcba..0877fc0 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java
@@ -17,31 +17,35 @@
package org.apache.ignite.ml.knn;
-import org.apache.ignite.internal.util.IgniteUtils;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
+import org.junit.Assert;
+import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
import org.apache.ignite.ml.knn.classification.KNNClassificationModel;
import org.apache.ignite.ml.knn.classification.KNNClassificationTrainer;
import org.apache.ignite.ml.knn.classification.KNNStrategy;
import org.apache.ignite.ml.math.Vector;
import org.apache.ignite.ml.math.distances.EuclideanDistance;
import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
-
-import java.util.Arrays;
-import java.util.HashMap;
-import java.util.Map;
+import org.junit.Test;
/** Tests behaviour of KNNClassificationTest. */
-public class KNNClassificationTest extends BaseKNNTest {
+public class KNNClassificationTest {
+ /** Precision in test checks. */
+ private static final double PRECISION = 1e-2;
+
/** */
- public void testBinaryClassificationTest() {
- IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+ @Test
+ public void binaryClassificationTest() {
Map<Integer, double[]> data = new HashMap<>();
- data.put(0, new double[] {1.0, 1.0, 1.0});
- data.put(1, new double[] {1.0, 2.0, 1.0});
- data.put(2, new double[] {2.0, 1.0, 1.0});
- data.put(3, new double[] {-1.0, -1.0, 2.0});
- data.put(4, new double[] {-1.0, -2.0, 2.0});
- data.put(5, new double[] {-2.0, -1.0, 2.0});
+ data.put(0, new double[]{1.0, 1.0, 1.0});
+ data.put(1, new double[]{1.0, 2.0, 1.0});
+ data.put(2, new double[]{2.0, 1.0, 1.0});
+ data.put(3, new double[]{-1.0, -1.0, 2.0});
+ data.put(4, new double[]{-1.0, -2.0, 2.0});
+ data.put(5, new double[]{-2.0, -1.0, 2.0});
KNNClassificationTrainer trainer = new KNNClassificationTrainer();
@@ -54,23 +58,23 @@ public class KNNClassificationTest extends BaseKNNTest {
.withDistanceMeasure(new EuclideanDistance())
.withStrategy(KNNStrategy.SIMPLE);
- Vector firstVector = new DenseLocalOnHeapVector(new double[] {2.0, 2.0});
- assertEquals(knnMdl.apply(firstVector), 1.0);
- Vector secondVector = new DenseLocalOnHeapVector(new double[] {-2.0, -2.0});
- assertEquals(knnMdl.apply(secondVector), 2.0);
+ Vector firstVector = new DenseLocalOnHeapVector(new double[]{2.0, 2.0});
+ Assert.assertEquals(knnMdl.apply(firstVector), 1.0, PRECISION);
+ Vector secondVector = new DenseLocalOnHeapVector(new double[]{-2.0, -2.0});
+ Assert.assertEquals(knnMdl.apply(secondVector), 2.0, PRECISION);
}
/** */
- public void testBinaryClassificationWithSmallestKTest() {
- IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
-
+ @Test
+ public void binaryClassificationWithSmallestKTest() {
Map<Integer, double[]> data = new HashMap<>();
- data.put(0, new double[] {1.0, 1.0, 1.0});
- data.put(1, new double[] {1.0, 2.0, 1.0});
- data.put(2, new double[] {2.0, 1.0, 1.0});
- data.put(3, new double[] {-1.0, -1.0, 2.0});
- data.put(4, new double[] {-1.0, -2.0, 2.0});
- data.put(5, new double[] {-2.0, -1.0, 2.0});
+
+ data.put(0, new double[]{1.0, 1.0, 1.0});
+ data.put(1, new double[]{1.0, 2.0, 1.0});
+ data.put(2, new double[]{2.0, 1.0, 1.0});
+ data.put(3, new double[]{-1.0, -1.0, 2.0});
+ data.put(4, new double[]{-1.0, -2.0, 2.0});
+ data.put(5, new double[]{-2.0, -1.0, 2.0});
KNNClassificationTrainer trainer = new KNNClassificationTrainer();
@@ -83,23 +87,23 @@ public class KNNClassificationTest extends BaseKNNTest {
.withDistanceMeasure(new EuclideanDistance())
.withStrategy(KNNStrategy.SIMPLE);
- Vector firstVector = new DenseLocalOnHeapVector(new double[] {2.0, 2.0});
- assertEquals(knnMdl.apply(firstVector), 1.0);
- Vector secondVector = new DenseLocalOnHeapVector(new double[] {-2.0, -2.0});
- assertEquals(knnMdl.apply(secondVector), 2.0);
+ Vector firstVector = new DenseLocalOnHeapVector(new double[]{2.0, 2.0});
+ Assert.assertEquals(knnMdl.apply(firstVector), 1.0, PRECISION);
+ Vector secondVector = new DenseLocalOnHeapVector(new double[]{-2.0, -2.0});
+ Assert.assertEquals(knnMdl.apply(secondVector), 2.0, PRECISION);
}
/** */
- public void testBinaryClassificationFarPointsWithSimpleStrategy() {
- IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
-
+ @Test
+ public void binaryClassificationFarPointsWithSimpleStrategy() {
Map<Integer, double[]> data = new HashMap<>();
- data.put(0, new double[] {10.0, 10.0, 1.0});
- data.put(1, new double[] {10.0, 20.0, 1.0});
- data.put(2, new double[] {-1, -1, 1.0});
- data.put(3, new double[] {-2, -2, 2.0});
- data.put(4, new double[] {-1.0, -2.0, 2.0});
- data.put(5, new double[] {-2.0, -1.0, 2.0});
+
+ data.put(0, new double[]{10.0, 10.0, 1.0});
+ data.put(1, new double[]{10.0, 20.0, 1.0});
+ data.put(2, new double[]{-1, -1, 1.0});
+ data.put(3, new double[]{-2, -2, 2.0});
+ data.put(4, new double[]{-1.0, -2.0, 2.0});
+ data.put(5, new double[]{-2.0, -1.0, 2.0});
KNNClassificationTrainer trainer = new KNNClassificationTrainer();
@@ -112,21 +116,21 @@ public class KNNClassificationTest extends BaseKNNTest {
.withDistanceMeasure(new EuclideanDistance())
.withStrategy(KNNStrategy.SIMPLE);
- Vector vector = new DenseLocalOnHeapVector(new double[] {-1.01, -1.01});
- assertEquals(knnMdl.apply(vector), 2.0);
+ Vector vector = new DenseLocalOnHeapVector(new double[]{-1.01, -1.01});
+ Assert.assertEquals(knnMdl.apply(vector), 2.0, PRECISION);
}
/** */
- public void testBinaryClassificationFarPointsWithWeightedStrategy() {
- IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
-
+ @Test
+ public void binaryClassificationFarPointsWithWeightedStrategy() {
Map<Integer, double[]> data = new HashMap<>();
- data.put(0, new double[] {10.0, 10.0, 1.0});
- data.put(1, new double[] {10.0, 20.0, 1.0});
- data.put(2, new double[] {-1, -1, 1.0});
- data.put(3, new double[] {-2, -2, 2.0});
- data.put(4, new double[] {-1.0, -2.0, 2.0});
- data.put(5, new double[] {-2.0, -1.0, 2.0});
+
+ data.put(0, new double[]{10.0, 10.0, 1.0});
+ data.put(1, new double[]{10.0, 20.0, 1.0});
+ data.put(2, new double[]{-1, -1, 1.0});
+ data.put(3, new double[]{-2, -2, 2.0});
+ data.put(4, new double[]{-1.0, -2.0, 2.0});
+ data.put(5, new double[]{-2.0, -1.0, 2.0});
KNNClassificationTrainer trainer = new KNNClassificationTrainer();
@@ -139,7 +143,7 @@ public class KNNClassificationTest extends BaseKNNTest {
.withDistanceMeasure(new EuclideanDistance())
.withStrategy(KNNStrategy.WEIGHTED);
- Vector vector = new DenseLocalOnHeapVector(new double[] {-1.01, -1.01});
- assertEquals(knnMdl.apply(vector), 1.0);
+ Vector vector = new DenseLocalOnHeapVector(new double[]{-1.01, -1.01});
+ Assert.assertEquals(knnMdl.apply(vector), 1.0, PRECISION);
}
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/8550d61b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java
index 66dbca9..ce9cae5 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java
@@ -17,7 +17,6 @@
package org.apache.ignite.ml.knn;
-import org.apache.ignite.internal.util.IgniteUtils;
import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
import org.apache.ignite.ml.knn.classification.KNNStrategy;
import org.apache.ignite.ml.knn.regression.KNNRegressionModel;
@@ -30,28 +29,23 @@ import org.junit.Assert;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
+import org.junit.Test;
/**
* Tests for {@link KNNRegressionTrainer}.
*/
-public class KNNRegressionTest extends BaseKNNTest {
+public class KNNRegressionTest {
/** */
- private double[] y;
-
- /** */
- private double[][] x;
-
- /** */
- public void testSimpleRegressionWithOneNeighbour() {
- IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
-
+ @Test
+ public void simpleRegressionWithOneNeighbour() {
Map<Integer, double[]> data = new HashMap<>();
- data.put(0, new double[] {11.0, 0, 0, 0, 0, 0});
- data.put(1, new double[] {12.0, 2.0, 0, 0, 0, 0});
- data.put(2, new double[] {13.0, 0, 3.0, 0, 0, 0});
- data.put(3, new double[] {14.0, 0, 0, 4.0, 0, 0});
- data.put(4, new double[] {15.0, 0, 0, 0, 5.0, 0});
- data.put(5, new double[] {16.0, 0, 0, 0, 0, 6.0});
+
+ data.put(0, new double[]{11.0, 0, 0, 0, 0, 0});
+ data.put(1, new double[]{12.0, 2.0, 0, 0, 0, 0});
+ data.put(2, new double[]{13.0, 0, 3.0, 0, 0, 0});
+ data.put(3, new double[]{14.0, 0, 0, 4.0, 0, 0});
+ data.put(4, new double[]{15.0, 0, 0, 0, 5.0, 0});
+ data.put(5, new double[]{16.0, 0, 0, 0, 0, 6.0});
KNNRegressionTrainer trainer = new KNNRegressionTrainer();
@@ -63,32 +57,31 @@ public class KNNRegressionTest extends BaseKNNTest {
.withDistanceMeasure(new EuclideanDistance())
.withStrategy(KNNStrategy.SIMPLE);
- Vector vector = new DenseLocalOnHeapVector(new double[] {0, 0, 0, 5.0, 0.0});
+ Vector vector = new DenseLocalOnHeapVector(new double[]{0, 0, 0, 5.0, 0.0});
System.out.println(knnMdl.apply(vector));
Assert.assertEquals(15, knnMdl.apply(vector), 1E-12);
}
/** */
- public void testLongly() {
-
- IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
-
+ @Test
+ public void longly() {
Map<Integer, double[]> data = new HashMap<>();
- data.put(0, new double[] {60323, 83.0, 234289, 2356, 1590, 107608, 1947});
- data.put(1, new double[] {61122, 88.5, 259426, 2325, 1456, 108632, 1948});
- data.put(2, new double[] {60171, 88.2, 258054, 3682, 1616, 109773, 1949});
- data.put(3, new double[] {61187, 89.5, 284599, 3351, 1650, 110929, 1950});
- data.put(4, new double[] {63221, 96.2, 328975, 2099, 3099, 112075, 1951});
- data.put(5, new double[] {63639, 98.1, 346999, 1932, 3594, 113270, 1952});
- data.put(6, new double[] {64989, 99.0, 365385, 1870, 3547, 115094, 1953});
- data.put(7, new double[] {63761, 100.0, 363112, 3578, 3350, 116219, 1954});
- data.put(8, new double[] {66019, 101.2, 397469, 2904, 3048, 117388, 1955});
- data.put(9, new double[] {68169, 108.4, 442769, 2936, 2798, 120445, 1957});
- data.put(10, new double[] {66513, 110.8, 444546, 4681, 2637, 121950, 1958});
- data.put(11, new double[] {68655, 112.6, 482704, 3813, 2552, 123366, 1959});
- data.put(12, new double[] {69564, 114.2, 502601, 3931, 2514, 125368, 1960});
- data.put(13, new double[] {69331, 115.7, 518173, 4806, 2572, 127852, 1961});
- data.put(14, new double[] {70551, 116.9, 554894, 4007, 2827, 130081, 1962});
+
+ data.put(0, new double[]{60323, 83.0, 234289, 2356, 1590, 107608, 1947});
+ data.put(1, new double[]{61122, 88.5, 259426, 2325, 1456, 108632, 1948});
+ data.put(2, new double[]{60171, 88.2, 258054, 3682, 1616, 109773, 1949});
+ data.put(3, new double[]{61187, 89.5, 284599, 3351, 1650, 110929, 1950});
+ data.put(4, new double[]{63221, 96.2, 328975, 2099, 3099, 112075, 1951});
+ data.put(5, new double[]{63639, 98.1, 346999, 1932, 3594, 113270, 1952});
+ data.put(6, new double[]{64989, 99.0, 365385, 1870, 3547, 115094, 1953});
+ data.put(7, new double[]{63761, 100.0, 363112, 3578, 3350, 116219, 1954});
+ data.put(8, new double[]{66019, 101.2, 397469, 2904, 3048, 117388, 1955});
+ data.put(9, new double[]{68169, 108.4, 442769, 2936, 2798, 120445, 1957});
+ data.put(10, new double[]{66513, 110.8, 444546, 4681, 2637, 121950, 1958});
+ data.put(11, new double[]{68655, 112.6, 482704, 3813, 2552, 123366, 1959});
+ data.put(12, new double[]{69564, 114.2, 502601, 3931, 2514, 125368, 1960});
+ data.put(13, new double[]{69331, 115.7, 518173, 4806, 2572, 127852, 1961});
+ data.put(14, new double[]{70551, 116.9, 554894, 4007, 2827, 130081, 1962});
KNNRegressionTrainer trainer = new KNNRegressionTrainer();
@@ -100,31 +93,30 @@ public class KNNRegressionTest extends BaseKNNTest {
.withDistanceMeasure(new EuclideanDistance())
.withStrategy(KNNStrategy.SIMPLE);
- Vector vector = new DenseLocalOnHeapVector(new double[] {104.6, 419180, 2822, 2857, 118734, 1956});
+ Vector vector = new DenseLocalOnHeapVector(new double[]{104.6, 419180, 2822, 2857, 118734, 1956});
System.out.println(knnMdl.apply(vector));
Assert.assertEquals(67857, knnMdl.apply(vector), 2000);
}
/** */
public void testLonglyWithWeightedStrategy() {
- IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
-
Map<Integer, double[]> data = new HashMap<>();
- data.put(0, new double[] {60323, 83.0, 234289, 2356, 1590, 107608, 1947});
- data.put(1, new double[] {61122, 88.5, 259426, 2325, 1456, 108632, 1948});
- data.put(2, new double[] {60171, 88.2, 258054, 3682, 1616, 109773, 1949});
- data.put(3, new double[] {61187, 89.5, 284599, 3351, 1650, 110929, 1950});
- data.put(4, new double[] {63221, 96.2, 328975, 2099, 3099, 112075, 1951});
- data.put(5, new double[] {63639, 98.1, 346999, 1932, 3594, 113270, 1952});
- data.put(6, new double[] {64989, 99.0, 365385, 1870, 3547, 115094, 1953});
- data.put(7, new double[] {63761, 100.0, 363112, 3578, 3350, 116219, 1954});
- data.put(8, new double[] {66019, 101.2, 397469, 2904, 3048, 117388, 1955});
- data.put(9, new double[] {68169, 108.4, 442769, 2936, 2798, 120445, 1957});
- data.put(10, new double[] {66513, 110.8, 444546, 4681, 2637, 121950, 1958});
- data.put(11, new double[] {68655, 112.6, 482704, 3813, 2552, 123366, 1959});
- data.put(12, new double[] {69564, 114.2, 502601, 3931, 2514, 125368, 1960});
- data.put(13, new double[] {69331, 115.7, 518173, 4806, 2572, 127852, 1961});
- data.put(14, new double[] {70551, 116.9, 554894, 4007, 2827, 130081, 1962});
+
+ data.put(0, new double[]{60323, 83.0, 234289, 2356, 1590, 107608, 1947});
+ data.put(1, new double[]{61122, 88.5, 259426, 2325, 1456, 108632, 1948});
+ data.put(2, new double[]{60171, 88.2, 258054, 3682, 1616, 109773, 1949});
+ data.put(3, new double[]{61187, 89.5, 284599, 3351, 1650, 110929, 1950});
+ data.put(4, new double[]{63221, 96.2, 328975, 2099, 3099, 112075, 1951});
+ data.put(5, new double[]{63639, 98.1, 346999, 1932, 3594, 113270, 1952});
+ data.put(6, new double[]{64989, 99.0, 365385, 1870, 3547, 115094, 1953});
+ data.put(7, new double[]{63761, 100.0, 363112, 3578, 3350, 116219, 1954});
+ data.put(8, new double[]{66019, 101.2, 397469, 2904, 3048, 117388, 1955});
+ data.put(9, new double[]{68169, 108.4, 442769, 2936, 2798, 120445, 1957});
+ data.put(10, new double[]{66513, 110.8, 444546, 4681, 2637, 121950, 1958});
+ data.put(11, new double[]{68655, 112.6, 482704, 3813, 2552, 123366, 1959});
+ data.put(12, new double[]{69564, 114.2, 502601, 3931, 2514, 125368, 1960});
+ data.put(13, new double[]{69331, 115.7, 518173, 4806, 2572, 127852, 1961});
+ data.put(14, new double[]{70551, 116.9, 554894, 4007, 2827, 130081, 1962});
KNNRegressionTrainer trainer = new KNNRegressionTrainer();
@@ -136,7 +128,7 @@ public class KNNRegressionTest extends BaseKNNTest {
.withDistanceMeasure(new EuclideanDistance())
.withStrategy(KNNStrategy.SIMPLE);
- Vector vector = new DenseLocalOnHeapVector(new double[] {104.6, 419180, 2822, 2857, 118734, 1956});
+ Vector vector = new DenseLocalOnHeapVector(new double[]{104.6, 419180, 2822, 2857, 118734, 1956});
System.out.println(knnMdl.apply(vector));
Assert.assertEquals(67857, knnMdl.apply(vector), 2000);
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/8550d61b/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetHelper.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetHelper.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetHelper.java
new file mode 100644
index 0000000..a25b303
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetHelper.java
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.knn;
+
+import java.io.IOException;
+import java.net.URISyntaxException;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.ml.structures.LabeledDataset;
+import org.apache.ignite.ml.structures.preprocessing.LabeledDatasetLoader;
+import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest;
+
+/**
+ * Base class for labeled dataset tests: starts the test grid nodes and loads labeled datasets from resource files.
+ */
+public class LabeledDatasetHelper extends GridCommonAbstractTest {
+ /** Count of nodes. */
+ private static final int NODE_COUNT = 4;
+
+ /** Separator. */
+ private static final String SEPARATOR = "\t";
+
+ /** Grid instance. */
+ protected Ignite ignite;
+
+ /**
+ * Default constructor.
+ */
+ public LabeledDatasetHelper() {
+ super(false);
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override protected void beforeTest() throws Exception {
+ ignite = grid(NODE_COUNT);
+ }
+
+ /** {@inheritDoc} */
+ @Override protected void beforeTestsStarted() throws Exception {
+ for (int i = 1; i <= NODE_COUNT; i++)
+ startGrid(i);
+ }
+
+ /** {@inheritDoc} */
+ @Override protected void afterTestsStopped() throws Exception {
+ stopAllGrids();
+ }
+
+ /**
+ * Loads labeled dataset from file with .txt extension.
+ *
+ * @param rsrcPath path to dataset.
+ * @return Loaded dataset, or {@code null} if reading the file fails or the resource URI is malformed.
+ */
+ LabeledDataset loadDatasetFromTxt(String rsrcPath, boolean isFallOnBadData) {
+ try {
+ Path path = Paths.get(this.getClass().getClassLoader().getResource(rsrcPath).toURI());
+ try {
+ return LabeledDatasetLoader.loadFromTxtFile(path, SEPARATOR, false, isFallOnBadData);
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
+ } catch (URISyntaxException e) {
+ e.printStackTrace();
+ return null;
+ }
+ return null;
+ }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/8550d61b/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetTest.java
index cdd5dc4..77d40a6 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetTest.java
@@ -34,7 +34,7 @@ import org.apache.ignite.ml.structures.LabeledVector;
import org.apache.ignite.ml.structures.preprocessing.LabeledDatasetLoader;
/** Tests behaviour of KNNClassificationTest. */
-public class LabeledDatasetTest extends BaseKNNTest implements ExternalizableTest<LabeledDataset> {
+public class LabeledDatasetTest extends LabeledDatasetHelper implements ExternalizableTest<LabeledDataset> {
/** */
private static final String KNN_IRIS_TXT = "datasets/knn/iris.txt";