You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@ignite.apache.org by ch...@apache.org on 2018/04/02 09:26:47 UTC
[1/2] ignite git commit: IGNITE-7702: Adopt kNN classification to the
new datasets
Repository: ignite
Updated Branches:
refs/heads/master d8fc15bc2 -> 43d055767
http://git-wip-us.apache.org/repos/asf/ignite/blob/43d05576/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMPartitionDataBuilderOnHeap.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMPartitionDataBuilderOnHeap.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMPartitionDataBuilderOnHeap.java
deleted file mode 100644
index ba1b82a..0000000
--- a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMPartitionDataBuilderOnHeap.java
+++ /dev/null
@@ -1,88 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.svm;
-
-import java.io.Serializable;
-import java.util.Iterator;
-import org.apache.ignite.ml.dataset.PartitionDataBuilder;
-import org.apache.ignite.ml.dataset.UpstreamEntry;
-import org.apache.ignite.ml.math.functions.IgniteBiFunction;
-import org.apache.ignite.ml.structures.LabeledDataset;
-import org.apache.ignite.ml.structures.LabeledVector;
-
-/**
- * SVM partition data builder that builds {@link LabeledDataset}.
- *
- * @param <K> Type of a key in <tt>upstream</tt> data.
- * @param <V> Type of a value in <tt>upstream</tt> data.
- * @param <C> Type of a partition <tt>context</tt>.
- */
-public class SVMPartitionDataBuilderOnHeap<K, V, C extends Serializable>
- implements PartitionDataBuilder<K, V, C, LabeledDataset<Double, LabeledVector>> {
- /** */
- private static final long serialVersionUID = -7820760153954269227L;
-
- /** Extractor of X matrix row. */
- private final IgniteBiFunction<K, V, double[]> xExtractor;
-
- /** Extractor of Y vector value. */
- private final IgniteBiFunction<K, V, Double> yExtractor;
-
- /**
- * Constructs a new instance of SVM partition data builder.
- *
- * @param xExtractor Extractor of X matrix row.
- * @param yExtractor Extractor of Y vector value.
- */
- public SVMPartitionDataBuilderOnHeap(IgniteBiFunction<K, V, double[]> xExtractor,
- IgniteBiFunction<K, V, Double> yExtractor) {
- this.xExtractor = xExtractor;
- this.yExtractor = yExtractor;
- }
-
- /** {@inheritDoc} */
- @Override public LabeledDataset<Double, LabeledVector> build(Iterator<UpstreamEntry<K, V>> upstreamData,
- long upstreamDataSize, C ctx) {
-
- int xCols = -1;
- double[][] x = null;
- double[] y = new double[Math.toIntExact(upstreamDataSize)];
-
- int ptr = 0;
-
- while (upstreamData.hasNext()) {
- UpstreamEntry<K, V> entry = upstreamData.next();
- double[] row = xExtractor.apply(entry.getKey(), entry.getValue());
-
- if (xCols < 0) {
- xCols = row.length;
- x = new double[Math.toIntExact(upstreamDataSize)][xCols];
- }
- else
- assert row.length == xCols : "X extractor must return exactly " + xCols + " columns";
-
- x[ptr] = row;
-
- y[ptr] = yExtractor.apply(entry.getKey(), entry.getValue());
-
- ptr++;
- }
-
- return new LabeledDataset<>(x, y);
- }
-}
http://git-wip-us.apache.org/repos/asf/ignite/blob/43d05576/modules/ml/src/main/java/org/apache/ignite/ml/svm/multi/LabelPartitionContext.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/svm/multi/LabelPartitionContext.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/multi/LabelPartitionContext.java
deleted file mode 100644
index 6c5e3da..0000000
--- a/modules/ml/src/main/java/org/apache/ignite/ml/svm/multi/LabelPartitionContext.java
+++ /dev/null
@@ -1,28 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.svm.multi;
-
-import java.io.Serializable;
-
-/**
- * Partition context of the SVM classification algorithm.
- */
-public class LabelPartitionContext implements Serializable {
- /** */
- private static final long serialVersionUID = -7412302212344430126L;
-}
http://git-wip-us.apache.org/repos/asf/ignite/blob/43d05576/modules/ml/src/main/java/org/apache/ignite/ml/svm/multi/LabelPartitionDataBuilderOnHeap.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/svm/multi/LabelPartitionDataBuilderOnHeap.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/multi/LabelPartitionDataBuilderOnHeap.java
deleted file mode 100644
index f44835c..0000000
--- a/modules/ml/src/main/java/org/apache/ignite/ml/svm/multi/LabelPartitionDataBuilderOnHeap.java
+++ /dev/null
@@ -1,66 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.svm.multi;
-
-import java.io.Serializable;
-import java.util.Iterator;
-import org.apache.ignite.ml.dataset.PartitionDataBuilder;
-import org.apache.ignite.ml.dataset.UpstreamEntry;
-import org.apache.ignite.ml.math.functions.IgniteBiFunction;
-import org.apache.ignite.ml.structures.LabeledDataset;
-
-/**
- * SVM partition data builder that builds {@link LabeledDataset}.
- *
- * @param <K> Type of a key in <tt>upstream</tt> data.
- * @param <V> Type of a value in <tt>upstream</tt> data.
- * @param <C> Type of a partition <tt>context</tt>.
- */
-public class LabelPartitionDataBuilderOnHeap<K, V, C extends Serializable>
- implements PartitionDataBuilder<K, V, C, LabelPartitionDataOnHeap> {
- /** */
- private static final long serialVersionUID = -7820760153954269227L;
-
- /** Extractor of Y vector value. */
- private final IgniteBiFunction<K, V, Double> yExtractor;
-
- /**
- * Constructs a new instance of Label partition data builder.
- *
- * @param yExtractor Extractor of Y vector value.
- */
- public LabelPartitionDataBuilderOnHeap(IgniteBiFunction<K, V, Double> yExtractor) {
- this.yExtractor = yExtractor;
- }
-
- /** {@inheritDoc} */
- @Override public LabelPartitionDataOnHeap build(Iterator<UpstreamEntry<K, V>> upstreamData, long upstreamDataSize,
- C ctx) {
- double[] y = new double[Math.toIntExact(upstreamDataSize)];
-
- int ptr = 0;
- while (upstreamData.hasNext()) {
- UpstreamEntry<K, V> entry = upstreamData.next();
-
- y[ptr] = yExtractor.apply(entry.getKey(), entry.getValue());
-
- ptr++;
- }
- return new LabelPartitionDataOnHeap(y);
- }
-}
http://git-wip-us.apache.org/repos/asf/ignite/blob/43d05576/modules/ml/src/main/java/org/apache/ignite/ml/svm/multi/LabelPartitionDataOnHeap.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/svm/multi/LabelPartitionDataOnHeap.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/multi/LabelPartitionDataOnHeap.java
deleted file mode 100644
index 0bbf566..0000000
--- a/modules/ml/src/main/java/org/apache/ignite/ml/svm/multi/LabelPartitionDataOnHeap.java
+++ /dev/null
@@ -1,45 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.svm.multi;
-
-/**
- * On Heap partition data that keeps part of a labels.
- */
-public class LabelPartitionDataOnHeap implements AutoCloseable {
- /** Part of Y vector. */
- private final double[] y;
-
- /**
- * Constructs a new instance of linear system partition data.
- *
- * @param y Part of Y vector.
- */
- public LabelPartitionDataOnHeap(double[] y) {
- this.y = y;
- }
-
- /** */
- public double[] getY() {
- return y;
- }
-
- /** {@inheritDoc} */
- @Override public void close() {
- // Do nothing, GC will clean up.
- }
-}
http://git-wip-us.apache.org/repos/asf/ignite/blob/43d05576/modules/ml/src/main/java/org/apache/ignite/ml/svm/multi/package-info.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/svm/multi/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/multi/package-info.java
deleted file mode 100644
index 28afdea..0000000
--- a/modules/ml/src/main/java/org/apache/ignite/ml/svm/multi/package-info.java
+++ /dev/null
@@ -1,22 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-/**
- * <!-- Package description. -->
- * Contains API for Multi-SVM(support vector machines) algorithms.
- */
-package org.apache.ignite.ml.svm.multi;
http://git-wip-us.apache.org/repos/asf/ignite/blob/43d05576/modules/ml/src/test/java/org/apache/ignite/ml/LocalModelsTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/LocalModelsTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/LocalModelsTest.java
index 236d7e5..3f12bdc 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/LocalModelsTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/LocalModelsTest.java
@@ -23,14 +23,13 @@ import java.nio.file.Path;
import java.util.function.Function;
import org.apache.ignite.ml.clustering.KMeansLocalClusterer;
import org.apache.ignite.ml.clustering.KMeansModel;
-import org.apache.ignite.ml.knn.models.KNNModel;
-import org.apache.ignite.ml.knn.models.KNNModelFormat;
-import org.apache.ignite.ml.knn.models.KNNStrategy;
+import org.apache.ignite.ml.knn.classification.KNNClassificationModel;
+import org.apache.ignite.ml.knn.classification.KNNModelFormat;
+import org.apache.ignite.ml.knn.classification.KNNStrategy;
import org.apache.ignite.ml.math.distances.EuclideanDistance;
import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix;
import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
import org.apache.ignite.ml.regressions.linear.LinearRegressionModel;
-import org.apache.ignite.ml.structures.LabeledDataset;
import org.apache.ignite.ml.svm.SVMLinearBinaryClassificationModel;
import org.apache.ignite.ml.svm.SVMLinearMultiClassClassificationModel;
import org.junit.Assert;
@@ -96,7 +95,6 @@ public class LocalModelsTest {
});
}
-
/** */
@Test
public void importExportSVMMulticlassClassificationModelTest() throws IOException {
@@ -156,19 +154,10 @@ public class LocalModelsTest {
@Test
public void importExportKNNModelTest() throws IOException {
executeModelTest(mdlFilePath -> {
- double[][] mtx =
- new double[][] {
- {1.0, 1.0},
- {1.0, 2.0},
- {2.0, 1.0},
- {-1.0, -1.0},
- {-1.0, -2.0},
- {-2.0, -1.0}};
- double[] lbs = new double[] {1.0, 1.0, 1.0, 2.0, 2.0, 2.0};
-
- LabeledDataset training = new LabeledDataset(mtx, lbs);
-
- KNNModel mdl = new KNNModel(3, new EuclideanDistance(), KNNStrategy.SIMPLE, training);
+ KNNClassificationModel mdl = new KNNClassificationModel(null)
+ .withK(3)
+ .withDistanceMeasure(new EuclideanDistance())
+ .withStrategy(KNNStrategy.SIMPLE);
Exporter<KNNModelFormat, String> exporter = new FileExporter<>();
mdl.saveModel(exporter, mdlFilePath);
@@ -177,7 +166,10 @@ public class LocalModelsTest {
Assert.assertNotNull(load);
- KNNModel importedMdl = new KNNModel(load.getK(), load.getDistanceMeasure(), load.getStgy(), load.getTraining());
+ KNNClassificationModel importedMdl = new KNNClassificationModel(null)
+ .withK(load.getK())
+ .withDistanceMeasure(load.getDistanceMeasure())
+ .withStrategy(load.getStgy());
Assert.assertTrue("", mdl.equals(importedMdl));
http://git-wip-us.apache.org/repos/asf/ignite/blob/43d05576/modules/ml/src/test/java/org/apache/ignite/ml/knn/BaseKNNTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/BaseKNNTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/BaseKNNTest.java
index 1651588..aeac2cf 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/knn/BaseKNNTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/knn/BaseKNNTest.java
@@ -36,9 +36,6 @@ public class BaseKNNTest extends GridCommonAbstractTest {
/** Separator. */
private static final String SEPARATOR = "\t";
- /** Path to the Iris dataset. */
- static final String KNN_IRIS_TXT = "datasets/knn/iris.txt";
-
/** Grid instance. */
protected Ignite ignite;
http://git-wip-us.apache.org/repos/asf/ignite/blob/43d05576/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java
index 28af6fa..b5a4b54 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java
@@ -17,14 +17,17 @@
package org.apache.ignite.ml.knn;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
import org.apache.ignite.internal.util.IgniteUtils;
-import org.apache.ignite.ml.knn.models.KNNModel;
-import org.apache.ignite.ml.knn.models.KNNStrategy;
+import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
+import org.apache.ignite.ml.knn.classification.KNNClassificationModel;
+import org.apache.ignite.ml.knn.classification.KNNClassificationTrainer;
+import org.apache.ignite.ml.knn.classification.KNNStrategy;
import org.apache.ignite.ml.math.Vector;
import org.apache.ignite.ml.math.distances.EuclideanDistance;
-import org.apache.ignite.ml.math.exceptions.knn.SmallTrainingDatasetSizeException;
import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
-import org.apache.ignite.ml.structures.LabeledDataset;
/** Tests behaviour of KNNClassificationTest. */
public class KNNClassificationTest extends BaseKNNTest {
@@ -32,19 +35,24 @@ public class KNNClassificationTest extends BaseKNNTest {
public void testBinaryClassificationTest() {
IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
- double[][] mtx =
- new double[][] {
- {1.0, 1.0},
- {1.0, 2.0},
- {2.0, 1.0},
- {-1.0, -1.0},
- {-1.0, -2.0},
- {-2.0, -1.0}};
- double[] lbs = new double[] {1.0, 1.0, 1.0, 2.0, 2.0, 2.0};
+ Map<Integer, double[]> data = new HashMap<>();
+ data.put(0, new double[] {1.0, 1.0, 1.0});
+ data.put(1, new double[] {1.0, 2.0, 1.0});
+ data.put(2, new double[] {2.0, 1.0, 1.0});
+ data.put(3, new double[] {-1.0, -1.0, 2.0});
+ data.put(4, new double[] {-1.0, -2.0, 2.0});
+ data.put(5, new double[] {-2.0, -1.0, 2.0});
- LabeledDataset training = new LabeledDataset(mtx, lbs);
+ KNNClassificationTrainer trainer = new KNNClassificationTrainer();
+
+ KNNClassificationModel knnMdl = trainer.fit(
+ new LocalDatasetBuilder<>(data, 2),
+ (k, v) -> Arrays.copyOfRange(v, 0, v.length - 1),
+ (k, v) -> v[2]
+ ).withK(3)
+ .withDistanceMeasure(new EuclideanDistance())
+ .withStrategy(KNNStrategy.SIMPLE);
- KNNModel knnMdl = new KNNModel(3, new EuclideanDistance(), KNNStrategy.SIMPLE, training);
Vector firstVector = new DenseLocalOnHeapVector(new double[] {2.0, 2.0});
assertEquals(knnMdl.apply(firstVector), 1.0);
Vector secondVector = new DenseLocalOnHeapVector(new double[] {-2.0, -2.0});
@@ -55,19 +63,24 @@ public class KNNClassificationTest extends BaseKNNTest {
public void testBinaryClassificationWithSmallestKTest() {
IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
- double[][] mtx =
- new double[][] {
- {1.0, 1.0},
- {1.0, 2.0},
- {2.0, 1.0},
- {-1.0, -1.0},
- {-1.0, -2.0},
- {-2.0, -1.0}};
- double[] lbs = new double[] {1.0, 1.0, 1.0, 2.0, 2.0, 2.0};
+ Map<Integer, double[]> data = new HashMap<>();
+ data.put(0, new double[] {1.0, 1.0, 1.0});
+ data.put(1, new double[] {1.0, 2.0, 1.0});
+ data.put(2, new double[] {2.0, 1.0, 1.0});
+ data.put(3, new double[] {-1.0, -1.0, 2.0});
+ data.put(4, new double[] {-1.0, -2.0, 2.0});
+ data.put(5, new double[] {-2.0, -1.0, 2.0});
+
+ KNNClassificationTrainer trainer = new KNNClassificationTrainer();
- LabeledDataset training = new LabeledDataset(mtx, lbs);
+ KNNClassificationModel knnMdl = trainer.fit(
+ new LocalDatasetBuilder<>(data, 2),
+ (k, v) -> Arrays.copyOfRange(v, 0, v.length - 1),
+ (k, v) -> v[2]
+ ).withK(1)
+ .withDistanceMeasure(new EuclideanDistance())
+ .withStrategy(KNNStrategy.SIMPLE);
- KNNModel knnMdl = new KNNModel(1, new EuclideanDistance(), KNNStrategy.SIMPLE, training);
Vector firstVector = new DenseLocalOnHeapVector(new double[] {2.0, 2.0});
assertEquals(knnMdl.apply(firstVector), 1.0);
Vector secondVector = new DenseLocalOnHeapVector(new double[] {-2.0, -2.0});
@@ -78,18 +91,24 @@ public class KNNClassificationTest extends BaseKNNTest {
public void testBinaryClassificationFarPointsWithSimpleStrategy() {
IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
- double[][] mtx =
- new double[][] {
- {10.0, 10.0},
- {10.0, 20.0},
- {-1, -1},
- {-2, -2},
- {-1.0, -2.0},
- {-2.0, -1.0}};
- double[] lbs = new double[] {1.0, 1.0, 1.0, 2.0, 2.0, 2.0};
- LabeledDataset training = new LabeledDataset(mtx, lbs);
-
- KNNModel knnMdl = new KNNModel(3, new EuclideanDistance(), KNNStrategy.SIMPLE, training);
+ Map<Integer, double[]> data = new HashMap<>();
+ data.put(0, new double[] {10.0, 10.0, 1.0});
+ data.put(1, new double[] {10.0, 20.0, 1.0});
+ data.put(2, new double[] {-1, -1, 1.0});
+ data.put(3, new double[] {-2, -2, 2.0});
+ data.put(4, new double[] {-1.0, -2.0, 2.0});
+ data.put(5, new double[] {-2.0, -1.0, 2.0});
+
+ KNNClassificationTrainer trainer = new KNNClassificationTrainer();
+
+ KNNClassificationModel knnMdl = trainer.fit(
+ new LocalDatasetBuilder<>(data, 2),
+ (k, v) -> Arrays.copyOfRange(v, 0, v.length - 1),
+ (k, v) -> v[2]
+ ).withK(3)
+ .withDistanceMeasure(new EuclideanDistance())
+ .withStrategy(KNNStrategy.SIMPLE);
+
Vector vector = new DenseLocalOnHeapVector(new double[] {-1.01, -1.01});
assertEquals(knnMdl.apply(vector), 2.0);
}
@@ -98,56 +117,25 @@ public class KNNClassificationTest extends BaseKNNTest {
public void testBinaryClassificationFarPointsWithWeightedStrategy() {
IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
- double[][] mtx =
- new double[][] {
- {10.0, 10.0},
- {10.0, 20.0},
- {-1, -1},
- {-2, -2},
- {-1.0, -2.0},
- {-2.0, -1.0}
- };
- double[] lbs = new double[] {1.0, 1.0, 1.0, 2.0, 2.0, 2.0};
- LabeledDataset training = new LabeledDataset(mtx, lbs);
-
- KNNModel knnMdl = new KNNModel(3, new EuclideanDistance(), KNNStrategy.WEIGHTED, training);
- Vector vector = new DenseLocalOnHeapVector(new double[] {-1.01, -1.01});
- assertEquals(knnMdl.apply(vector), 1.0);
- }
+ Map<Integer, double[]> data = new HashMap<>();
+ data.put(0, new double[] {10.0, 10.0, 1.0});
+ data.put(1, new double[] {10.0, 20.0, 1.0});
+ data.put(2, new double[] {-1, -1, 1.0});
+ data.put(3, new double[] {-2, -2, 2.0});
+ data.put(4, new double[] {-1.0, -2.0, 2.0});
+ data.put(5, new double[] {-2.0, -1.0, 2.0});
- /** */
- public void testPredictOnIrisDataset() {
- IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
- LabeledDataset training = loadDatasetFromTxt(KNN_IRIS_TXT, false);
-
- KNNModel knnMdl = new KNNModel(7, new EuclideanDistance(), KNNStrategy.SIMPLE, training);
- Vector vector = new DenseLocalOnHeapVector(new double[] {5.15, 3.55, 1.45, 0.25});
- assertEquals(knnMdl.apply(vector), 1.0);
- }
+ KNNClassificationTrainer trainer = new KNNClassificationTrainer();
- /** */
- public void testLargeKValue() {
- IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+ KNNClassificationModel knnMdl = trainer.fit(
+ new LocalDatasetBuilder<>(data, 2),
+ (k, v) -> Arrays.copyOfRange(v, 0, v.length - 1),
+ (k, v) -> v[2]
+ ).withK(3)
+ .withDistanceMeasure(new EuclideanDistance())
+ .withStrategy(KNNStrategy.WEIGHTED);
- double[][] mtx =
- new double[][] {
- {10.0, 10.0},
- {10.0, 20.0},
- {-1, -1},
- {-2, -2},
- {-1.0, -2.0},
- {-2.0, -1.0}
- };
- double[] lbs = new double[] {1.0, 1.0, 1.0, 2.0, 2.0, 2.0};
- LabeledDataset training = new LabeledDataset(mtx, lbs);
-
- try {
- new KNNModel(7, new EuclideanDistance(), KNNStrategy.SIMPLE, training);
- fail("SmallTrainingDatasetSizeException");
- }
- catch (SmallTrainingDatasetSizeException e) {
- return;
- }
- fail("SmallTrainingDatasetSizeException");
+ Vector vector = new DenseLocalOnHeapVector(new double[] {-1.01, -1.01});
+ assertEquals(knnMdl.apply(vector), 1.0);
}
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/43d05576/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNMultipleLinearRegressionTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNMultipleLinearRegressionTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNMultipleLinearRegressionTest.java
deleted file mode 100644
index e5d9b13..0000000
--- a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNMultipleLinearRegressionTest.java
+++ /dev/null
@@ -1,157 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.knn;
-
-import org.apache.ignite.internal.util.IgniteUtils;
-import org.apache.ignite.ml.knn.models.KNNStrategy;
-import org.apache.ignite.ml.knn.regression.KNNMultipleLinearRegression;
-import org.apache.ignite.ml.math.Vector;
-import org.apache.ignite.ml.math.distances.EuclideanDistance;
-import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
-import org.apache.ignite.ml.math.impls.vector.SparseBlockDistributedVector;
-import org.apache.ignite.ml.structures.LabeledDataset;
-import org.apache.ignite.ml.structures.preprocessing.Normalizer;
-import org.junit.Assert;
-
-/**
- * Tests for {@link KNNMultipleLinearRegression}.
- */
-public class KNNMultipleLinearRegressionTest extends BaseKNNTest {
- /** */
- private double[] y;
-
- /** */
- private double[][] x;
-
- /** */
- public void testSimpleRegressionWithOneNeighbour() {
-
- y = new double[] {11.0, 12.0, 13.0, 14.0, 15.0, 16.0};
- x = new double[6][];
- x[0] = new double[] {0, 0, 0, 0, 0};
- x[1] = new double[] {2.0, 0, 0, 0, 0};
- x[2] = new double[] {0, 3.0, 0, 0, 0};
- x[3] = new double[] {0, 0, 4.0, 0, 0};
- x[4] = new double[] {0, 0, 0, 5.0, 0};
- x[5] = new double[] {0, 0, 0, 0, 6.0};
-
- IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
-
- LabeledDataset training = new LabeledDataset(x, y);
-
- KNNMultipleLinearRegression knnMdl = new KNNMultipleLinearRegression(1, new EuclideanDistance(), KNNStrategy.SIMPLE, training);
- Vector vector = new SparseBlockDistributedVector(new double[] {0, 0, 0, 5.0, 0.0});
- System.out.println(knnMdl.apply(vector));
- Assert.assertEquals(15, knnMdl.apply(vector), 1E-12);
- }
-
- /** */
- public void testLongly() {
-
- y = new double[] {60323, 61122, 60171, 61187, 63221, 63639, 64989, 63761, 66019, 68169, 66513, 68655, 69564, 69331, 70551};
- x = new double[15][];
- x[0] = new double[] {83.0, 234289, 2356, 1590, 107608, 1947};
- x[1] = new double[] {88.5, 259426, 2325, 1456, 108632, 1948};
- x[2] = new double[] {88.2, 258054, 3682, 1616, 109773, 1949};
- x[3] = new double[] {89.5, 284599, 3351, 1650, 110929, 1950};
- x[4] = new double[] {96.2, 328975, 2099, 3099, 112075, 1951};
- x[5] = new double[] {98.1, 346999, 1932, 3594, 113270, 1952};
- x[6] = new double[] {99.0, 365385, 1870, 3547, 115094, 1953};
- x[7] = new double[] {100.0, 363112, 3578, 3350, 116219, 1954};
- x[8] = new double[] {101.2, 397469, 2904, 3048, 117388, 1955};
- x[9] = new double[] {108.4, 442769, 2936, 2798, 120445, 1957};
- x[10] = new double[] {110.8, 444546, 4681, 2637, 121950, 1958};
- x[11] = new double[] {112.6, 482704, 3813, 2552, 123366, 1959};
- x[12] = new double[] {114.2, 502601, 3931, 2514, 125368, 1960};
- x[13] = new double[] {115.7, 518173, 4806, 2572, 127852, 1961};
- x[14] = new double[] {116.9, 554894, 4007, 2827, 130081, 1962};
-
- IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
-
- LabeledDataset training = new LabeledDataset(x, y);
-
- KNNMultipleLinearRegression knnMdl = new KNNMultipleLinearRegression(3, new EuclideanDistance(), KNNStrategy.SIMPLE, training);
- Vector vector = new DenseLocalOnHeapVector(new double[] {104.6, 419180, 2822, 2857, 118734, 1956});
- System.out.println(knnMdl.apply(vector));
- Assert.assertEquals(67857, knnMdl.apply(vector), 2000);
- }
-
- /** */
- public void testLonglyWithNormalization() {
- y = new double[] {60323, 61122, 60171, 61187, 63221, 63639, 64989, 63761, 66019, 68169, 66513, 68655, 69564, 69331, 70551};
- x = new double[15][];
- x[0] = new double[] {83.0, 234289, 2356, 1590, 107608, 1947};
- x[1] = new double[] {88.5, 259426, 2325, 1456, 108632, 1948};
- x[2] = new double[] {88.2, 258054, 3682, 1616, 109773, 1949};
- x[3] = new double[] {89.5, 284599, 3351, 1650, 110929, 1950};
- x[4] = new double[] {96.2, 328975, 2099, 3099, 112075, 1951};
- x[5] = new double[] {98.1, 346999, 1932, 3594, 113270, 1952};
- x[6] = new double[] {99.0, 365385, 1870, 3547, 115094, 1953};
- x[7] = new double[] {100.0, 363112, 3578, 3350, 116219, 1954};
- x[8] = new double[] {101.2, 397469, 2904, 3048, 117388, 1955};
- x[9] = new double[] {108.4, 442769, 2936, 2798, 120445, 1957};
- x[10] = new double[] {110.8, 444546, 4681, 2637, 121950, 1958};
- x[11] = new double[] {112.6, 482704, 3813, 2552, 123366, 1959};
- x[12] = new double[] {114.2, 502601, 3931, 2514, 125368, 1960};
- x[13] = new double[] {115.7, 518173, 4806, 2572, 127852, 1961};
- x[14] = new double[] {116.9, 554894, 4007, 2827, 130081, 1962};
-
- IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
-
- LabeledDataset training = new LabeledDataset(x, y);
-
- final LabeledDataset normalizedTrainingDataset = (LabeledDataset)Normalizer.normalizeWithMiniMax(training);
-
- KNNMultipleLinearRegression knnMdl = new KNNMultipleLinearRegression(5, new EuclideanDistance(), KNNStrategy.SIMPLE, normalizedTrainingDataset);
- Vector vector = new DenseLocalOnHeapVector(new double[] {104.6, 419180, 2822, 2857, 118734, 1956});
- System.out.println(knnMdl.apply(vector));
- Assert.assertEquals(67857, knnMdl.apply(vector), 2000);
- }
-
- /** */
- public void testLonglyWithWeightedStrategyAndNormalization() {
- y = new double[] {60323, 61122, 60171, 61187, 63221, 63639, 64989, 63761, 66019, 68169, 66513, 68655, 69564, 69331, 70551};
- x = new double[15][];
- x[0] = new double[] {83.0, 234289, 2356, 1590, 107608, 1947};
- x[1] = new double[] {88.5, 259426, 2325, 1456, 108632, 1948};
- x[2] = new double[] {88.2, 258054, 3682, 1616, 109773, 1949};
- x[3] = new double[] {89.5, 284599, 3351, 1650, 110929, 1950};
- x[4] = new double[] {96.2, 328975, 2099, 3099, 112075, 1951};
- x[5] = new double[] {98.1, 346999, 1932, 3594, 113270, 1952};
- x[6] = new double[] {99.0, 365385, 1870, 3547, 115094, 1953};
- x[7] = new double[] {100.0, 363112, 3578, 3350, 116219, 1954};
- x[8] = new double[] {101.2, 397469, 2904, 3048, 117388, 1955};
- x[9] = new double[] {108.4, 442769, 2936, 2798, 120445, 1957};
- x[10] = new double[] {110.8, 444546, 4681, 2637, 121950, 1958};
- x[11] = new double[] {112.6, 482704, 3813, 2552, 123366, 1959};
- x[12] = new double[] {114.2, 502601, 3931, 2514, 125368, 1960};
- x[13] = new double[] {115.7, 518173, 4806, 2572, 127852, 1961};
- x[14] = new double[] {116.9, 554894, 4007, 2827, 130081, 1962};
-
- IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
-
- LabeledDataset training = new LabeledDataset(x, y);
-
- final LabeledDataset normalizedTrainingDataset = (LabeledDataset)Normalizer.normalizeWithMiniMax(training);
-
- KNNMultipleLinearRegression knnMdl = new KNNMultipleLinearRegression(5, new EuclideanDistance(), KNNStrategy.WEIGHTED, normalizedTrainingDataset);
- Vector vector = new DenseLocalOnHeapVector(new double[] {104.6, 419180, 2822, 2857, 118734, 1956});
- System.out.println(knnMdl.apply(vector));
- Assert.assertEquals(67857, knnMdl.apply(vector), 2000);
- }
-}
http://git-wip-us.apache.org/repos/asf/ignite/blob/43d05576/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNTestSuite.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNTestSuite.java
index 8b47e0a..95ebec5 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNTestSuite.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNTestSuite.java
@@ -26,7 +26,6 @@ import org.junit.runners.Suite;
@RunWith(Suite.class)
@Suite.SuiteClasses({
KNNClassificationTest.class,
- KNNMultipleLinearRegressionTest.class,
LabeledDatasetTest.class
})
public class KNNTestSuite {
http://git-wip-us.apache.org/repos/asf/ignite/blob/43d05576/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetTest.java
index 079ae55..cdd5dc4 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetTest.java
@@ -50,7 +50,6 @@ public class LabeledDatasetTest extends BaseKNNTest implements ExternalizableTes
/** */
private static final String IRIS_MISSED_DATA = "datasets/knn/missed_data.txt";
-
/** */
public void testFeatureNames() {
IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
http://git-wip-us.apache.org/repos/asf/ignite/blob/43d05576/modules/yardstick/src/main/java/org/apache/ignite/yardstick/ml/knn/Datasets.java
----------------------------------------------------------------------
diff --git a/modules/yardstick/src/main/java/org/apache/ignite/yardstick/ml/knn/Datasets.java b/modules/yardstick/src/main/java/org/apache/ignite/yardstick/ml/knn/Datasets.java
deleted file mode 100644
index c0191e0..0000000
--- a/modules/yardstick/src/main/java/org/apache/ignite/yardstick/ml/knn/Datasets.java
+++ /dev/null
@@ -1,453 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.yardstick.ml.knn;
-
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.List;
-import org.apache.ignite.ml.math.Vector;
-import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
-import org.apache.ignite.ml.structures.LabeledDataset;
-import org.apache.ignite.ml.structures.LabeledVector;
-
-/**
- * Datasets used in KMeansDistributedClustererExample and in KMeansLocalClustererExample.
- */
-class Datasets {
- /**
- * Generate dataset shuffled as defined in parameter.
- *
- * @param off Parameter to use for shuffling raw data.
- * @return Generated dataset.
- */
- LabeledDataset shuffleIris(int off) {
- return shuffle(off, vectorsIris, labelsIris);
- }
-
- /**
- * Generate dataset shuffled as defined in parameter.
- *
- * @param off Parameter to use for shuffling raw data.
- * @return Generated dataset.
- */
- LabeledDataset shuffleClearedMachines(int off) {
- return shuffle(off, vectorsClearedMachines, labelsClearedMachines);
- }
-
- /** */
- private LabeledDataset shuffle(int off, List<Vector> vectors, List<Double> labels) {
- int size = vectors.size();
-
- LabeledVector[] data = new LabeledVector[size];
- for (int i = 0; i < vectors.size(); i++)
- data[(i + off) % (size - 1)] = new LabeledVector<>(vectors.get(i), labels.get(i));
-
- return new LabeledDataset(data, vectors.get(0).size());
- }
-
- /** */
- private final static double[][] dataIris = {
- new double[] {1.0, 5.1, 3.5, 1.4, 0.2},
- new double[] {1.0, 4.9, 3.0, 1.4, 0.2},
- new double[] {1.0, 4.7, 3.2, 1.3, 0.2},
- new double[] {1.0, 4.6, 3.1, 1.5, 0.2},
- new double[] {1.0, 5.0, 3.6, 1.4, 0.2},
- new double[] {1.0, 5.4, 3.9, 1.7, 0.4},
- new double[] {1.0, 4.6, 3.4, 1.4, 0.3},
- new double[] {1.0, 5.0, 3.4, 1.5, 0.2},
- new double[] {1.0, 4.4, 2.9, 1.4, 0.2},
- new double[] {1.0, 4.9, 3.1, 1.5, 0.1},
- new double[] {1.0, 5.4, 3.7, 1.5, 0.2},
- new double[] {1.0, 4.8, 3.4, 1.6, 0.2},
- new double[] {1.0, 4.8, 3.0, 1.4, 0.1},
- new double[] {1.0, 4.3, 3.0, 1.1, 0.1},
- new double[] {1.0, 5.8, 4.0, 1.2, 0.2},
- new double[] {1.0, 5.7, 4.4, 1.5, 0.4},
- new double[] {1.0, 5.4, 3.9, 1.3, 0.4},
- new double[] {1.0, 5.1, 3.5, 1.4, 0.3},
- new double[] {1.0, 5.7, 3.8, 1.7, 0.3},
- new double[] {1.0, 5.1, 3.8, 1.5, 0.3},
- new double[] {1.0, 5.4, 3.4, 1.7, 0.2},
- new double[] {1.0, 5.1, 3.7, 1.5, 0.4},
- new double[] {1.0, 4.6, 3.6, 1.0, 0.2},
- new double[] {1.0, 5.1, 3.3, 1.7, 0.5},
- new double[] {1.0, 4.8, 3.4, 1.9, 0.2},
- new double[] {1.0, 5.0, 3.0, 1.6, 0.2},
- new double[] {1.0, 5.0, 3.4, 1.6, 0.4},
- new double[] {1.0, 5.2, 3.5, 1.5, 0.2},
- new double[] {1.0, 5.2, 3.4, 1.4, 0.2},
- new double[] {1.0, 4.7, 3.2, 1.6, 0.2},
- new double[] {1.0, 4.8, 3.1, 1.6, 0.2},
- new double[] {1.0, 5.4, 3.4, 1.5, 0.4},
- new double[] {1.0, 5.2, 4.1, 1.5, 0.1},
- new double[] {1.0, 5.5, 4.2, 1.4, 0.2},
- new double[] {1.0, 4.9, 3.1, 1.5, 0.1},
- new double[] {1.0, 5.0, 3.2, 1.2, 0.2},
- new double[] {1.0, 5.5, 3.5, 1.3, 0.2},
- new double[] {1.0, 4.9, 3.1, 1.5, 0.1},
- new double[] {1.0, 4.4, 3.0, 1.3, 0.2},
- new double[] {1.0, 5.1, 3.4, 1.5, 0.2},
- new double[] {1.0, 5.0, 3.5, 1.3, 0.3},
- new double[] {1.0, 4.5, 2.3, 1.3, 0.3},
- new double[] {1.0, 4.4, 3.2, 1.3, 0.2},
- new double[] {1.0, 5.0, 3.5, 1.6, 0.6},
- new double[] {1.0, 5.1, 3.8, 1.9, 0.4},
- new double[] {1.0, 4.8, 3.0, 1.4, 0.3},
- new double[] {1.0, 5.1, 3.8, 1.6, 0.2},
- new double[] {1.0, 4.6, 3.2, 1.4, 0.2},
- new double[] {1.0, 5.3, 3.7, 1.5, 0.2},
- new double[] {1.0, 5.0, 3.3, 1.4, 0.2},
- new double[] {2.0, 7.0, 3.2, 4.7, 1.4},
- new double[] {2.0, 6.4, 3.2, 4.5, 1.5},
- new double[] {2.0, 6.9, 3.1, 4.9, 1.5},
- new double[] {2.0, 5.5, 2.3, 4.0, 1.3},
- new double[] {2.0, 6.5, 2.8, 4.6, 1.5},
- new double[] {2.0, 5.7, 2.8, 4.5, 1.3},
- new double[] {2.0, 6.3, 3.3, 4.7, 1.6},
- new double[] {2.0, 4.9, 2.4, 3.3, 1.0},
- new double[] {2.0, 6.6, 2.9, 4.6, 1.3},
- new double[] {2.0, 5.2, 2.7, 3.9, 1.4},
- new double[] {2.0, 5.0, 2.0, 3.5, 1.0},
- new double[] {2.0, 5.9, 3.0, 4.2, 1.5},
- new double[] {2.0, 6.0, 2.2, 4.0, 1.0},
- new double[] {2.0, 6.1, 2.9, 4.7, 1.4},
- new double[] {2.0, 5.6, 2.9, 3.6, 1.3},
- new double[] {2.0, 6.7, 3.1, 4.4, 1.4},
- new double[] {2.0, 5.6, 3.0, 4.5, 1.5},
- new double[] {2.0, 5.8, 2.7, 4.1, 1.0},
- new double[] {2.0, 6.2, 2.2, 4.5, 1.5},
- new double[] {2.0, 5.6, 2.5, 3.9, 1.1},
- new double[] {2.0, 5.9, 3.2, 4.8, 1.8},
- new double[] {2.0, 6.1, 2.8, 4.0, 1.3},
- new double[] {2.0, 6.3, 2.5, 4.9, 1.5},
- new double[] {2.0, 6.1, 2.8, 4.7, 1.2},
- new double[] {2.0, 6.4, 2.9, 4.3, 1.3},
- new double[] {2.0, 6.6, 3.0, 4.4, 1.4},
- new double[] {2.0, 6.8, 2.8, 4.8, 1.4},
- new double[] {2.0, 6.7, 3.0, 5.0, 1.7},
- new double[] {2.0, 6.0, 2.9, 4.5, 1.5},
- new double[] {2.0, 5.7, 2.6, 3.5, 1.0},
- new double[] {2.0, 5.5, 2.4, 3.8, 1.1},
- new double[] {2.0, 5.5, 2.4, 3.7, 1.0},
- new double[] {2.0, 5.8, 2.7, 3.9, 1.2},
- new double[] {2.0, 6.0, 2.7, 5.1, 1.6},
- new double[] {2.0, 5.4, 3.0, 4.5, 1.5},
- new double[] {2.0, 6.0, 3.4, 4.5, 1.6},
- new double[] {2.0, 6.7, 3.1, 4.7, 1.5},
- new double[] {2.0, 6.3, 2.3, 4.4, 1.3},
- new double[] {2.0, 5.6, 3.0, 4.1, 1.3},
- new double[] {2.0, 5.5, 2.5, 4.0, 1.3},
- new double[] {2.0, 5.5, 2.6, 4.4, 1.2},
- new double[] {2.0, 6.1, 3.0, 4.6, 1.4},
- new double[] {2.0, 5.8, 2.6, 4.0, 1.2},
- new double[] {2.0, 5.0, 2.3, 3.3, 1.0},
- new double[] {2.0, 5.6, 2.7, 4.2, 1.3},
- new double[] {2.0, 5.7, 3.0, 4.2, 1.2},
- new double[] {2.0, 5.7, 2.9, 4.2, 1.3},
- new double[] {2.0, 6.2, 2.9, 4.3, 1.3},
- new double[] {2.0, 5.1, 2.5, 3.0, 1.1},
- new double[] {2.0, 5.7, 2.8, 4.1, 1.3},
- new double[] {3.0, 6.3, 3.3, 6.0, 2.5},
- new double[] {3.0, 5.8, 2.7, 5.1, 1.9},
- new double[] {3.0, 7.1, 3.0, 5.9, 2.1},
- new double[] {3.0, 6.3, 2.9, 5.6, 1.8},
- new double[] {3.0, 6.5, 3.0, 5.8, 2.2},
- new double[] {3.0, 7.6, 3.0, 6.6, 2.1},
- new double[] {3.0, 4.9, 2.5, 4.5, 1.7},
- new double[] {3.0, 7.3, 2.9, 6.3, 1.8},
- new double[] {3.0, 6.7, 2.5, 5.8, 1.8},
- new double[] {3.0, 7.2, 3.6, 6.1, 2.5},
- new double[] {3.0, 6.5, 3.2, 5.1, 2.0},
- new double[] {3.0, 6.4, 2.7, 5.3, 1.9},
- new double[] {3.0, 6.8, 3.0, 5.5, 2.1},
- new double[] {3.0, 5.7, 2.5, 5.0, 2.0},
- new double[] {3.0, 5.8, 2.8, 5.1, 2.4},
- new double[] {3.0, 6.4, 3.2, 5.3, 2.3},
- new double[] {3.0, 6.5, 3.0, 5.5, 1.8},
- new double[] {3.0, 7.7, 3.8, 6.7, 2.2},
- new double[] {3.0, 7.7, 2.6, 6.9, 2.3},
- new double[] {3.0, 6.0, 2.2, 5.0, 1.5},
- new double[] {3.0, 6.9, 3.2, 5.7, 2.3},
- new double[] {3.0, 5.6, 2.8, 4.9, 2.0},
- new double[] {3.0, 7.7, 2.8, 6.7, 2.0},
- new double[] {3.0, 6.3, 2.7, 4.9, 1.8},
- new double[] {3.0, 6.7, 3.3, 5.7, 2.1},
- new double[] {3.0, 7.2, 3.2, 6.0, 1.8},
- new double[] {3.0, 6.2, 2.8, 4.8, 1.8},
- new double[] {3.0, 6.1, 3.0, 4.9, 1.8},
- new double[] {3.0, 6.4, 2.8, 5.6, 2.1},
- new double[] {3.0, 7.2, 3.0, 5.8, 1.6},
- new double[] {3.0, 7.4, 2.8, 6.1, 1.9},
- new double[] {3.0, 7.9, 3.8, 6.4, 2.0},
- new double[] {3.0, 6.4, 2.8, 5.6, 2.2},
- new double[] {3.0, 6.3, 2.8, 5.1, 1.5},
- new double[] {3.0, 6.1, 2.6, 5.6, 1.4},
- new double[] {3.0, 7.7, 3.0, 6.1, 2.3},
- new double[] {3.0, 6.3, 3.4, 5.6, 2.4},
- new double[] {3.0, 6.4, 3.1, 5.5, 1.8},
- new double[] {3.0, 6.0, 3.0, 4.8, 1.8},
- new double[] {3.0, 6.9, 3.1, 5.4, 2.1},
- new double[] {3.0, 6.7, 3.1, 5.6, 2.4},
- new double[] {3.0, 6.9, 3.1, 5.1, 2.3},
- new double[] {3.0, 5.8, 2.7, 5.1, 1.9},
- new double[] {3.0, 6.8, 3.2, 5.9, 2.3},
- new double[] {3.0, 6.7, 3.3, 5.7, 2.5},
- new double[] {3.0, 6.7, 3.0, 5.2, 2.3},
- new double[] {3.0, 6.3, 2.5, 5.0, 1.9},
- new double[] {3.0, 6.5, 3.0, 5.2, 2.0},
- new double[] {3.0, 6.2, 3.4, 5.4, 2.3},
- new double[] {3.0, 5.9, 3.0, 5.1, 1.8},
- };
-
- /** */
- private static final List<Double> labelsIris = new ArrayList<>();
-
- /** */
- private static final List<Vector> vectorsIris = new ArrayList<>();
-
- /** */
- private final static double[][] dataClearedMachines = {
- new double[] {199,125,256,6000,256,16,128},
- new double[] {253,29,8000,32000,32,8,32},
- new double[] {253,29,8000,32000,32,8,32},
- new double[] {253,29,8000,32000,32,8,32},
- new double[] {132,29,8000,16000,32,8,16},
- new double[] {290,26,8000,32000,64,8,32},
- new double[] {381,23,16000,32000,64,16,32},
- new double[] {381,23,16000,32000,64,16,32},
- new double[] {749,23,16000,64000,64,16,32},
- new double[] {1238,23,32000,64000,128,32,64},
- new double[] {23,400,1000,3000,0,1,2},
- new double[] {24,400,512,3500,4,1,6},
- new double[] {70,60,2000,8000,65,1,8},
- new double[] {117,50,4000,16000,65,1,8},
- new double[] {15,350,64,64,0,1,4},
- new double[] {64,200,512,16000,0,4,32},
- new double[] {23,167,524,2000,8,4,15},
- new double[] {29,143,512,5000,0,7,32},
- new double[] {22,143,1000,2000,0,5,16},
- new double[] {124,110,5000,5000,142,8,64},
- new double[] {35,143,1500,6300,0,5,32},
- new double[] {39,143,3100,6200,0,5,20},
- new double[] {40,143,2300,6200,0,6,64},
- new double[] {45,110,3100,6200,0,6,64},
- new double[] {28,320,128,6000,0,1,12},
- new double[] {21,320,512,2000,4,1,3},
- new double[] {28,320,256,6000,0,1,6},
- new double[] {22,320,256,3000,4,1,3},
- new double[] {28,320,512,5000,4,1,5},
- new double[] {27,320,256,5000,4,1,6},
- new double[] {102,25,1310,2620,131,12,24},
- new double[] {102,25,1310,2620,131,12,24},
- new double[] {74,50,2620,10480,30,12,24},
- new double[] {74,50,2620,10480,30,12,24},
- new double[] {138,56,5240,20970,30,12,24},
- new double[] {136,64,5240,20970,30,12,24},
- new double[] {23,50,500,2000,8,1,4},
- new double[] {29,50,1000,4000,8,1,5},
- new double[] {44,50,2000,8000,8,1,5},
- new double[] {30,50,1000,4000,8,3,5},
- new double[] {41,50,1000,8000,8,3,5},
- new double[] {74,50,2000,16000,8,3,5},
- new double[] {74,50,2000,16000,8,3,6},
- new double[] {74,50,2000,16000,8,3,6},
- new double[] {54,133,1000,12000,9,3,12},
- new double[] {41,133,1000,8000,9,3,12},
- new double[] {18,810,512,512,8,1,1},
- new double[] {28,810,1000,5000,0,1,1},
- new double[] {36,320,512,8000,4,1,5},
- new double[] {38,200,512,8000,8,1,8},
- new double[] {34,700,384,8000,0,1,1},
- new double[] {19,700,256,2000,0,1,1},
- new double[] {72,140,1000,16000,16,1,3},
- new double[] {36,200,1000,8000,0,1,2},
- new double[] {30,110,1000,4000,16,1,2},
- new double[] {56,110,1000,12000,16,1,2},
- new double[] {42,220,1000,8000,16,1,2},
- new double[] {34,800,256,8000,0,1,4},
- new double[] {34,800,256,8000,0,1,4},
- new double[] {34,800,256,8000,0,1,4},
- new double[] {34,800,256,8000,0,1,4},
- new double[] {34,800,256,8000,0,1,4},
- new double[] {19,125,512,1000,0,8,20},
- new double[] {75,75,2000,8000,64,1,38},
- new double[] {113,75,2000,16000,64,1,38},
- new double[] {157,75,2000,16000,128,1,38},
- new double[] {18,90,256,1000,0,3,10},
- new double[] {20,105,256,2000,0,3,10},
- new double[] {28,105,1000,4000,0,3,24},
- new double[] {33,105,2000,4000,8,3,19},
- new double[] {47,75,2000,8000,8,3,24},
- new double[] {54,75,3000,8000,8,3,48},
- new double[] {20,175,256,2000,0,3,24},
- new double[] {23,300,768,3000,0,6,24},
- new double[] {25,300,768,3000,6,6,24},
- new double[] {52,300,768,12000,6,6,24},
- new double[] {27,300,768,4500,0,1,24},
- new double[] {50,300,384,12000,6,1,24},
- new double[] {18,300,192,768,6,6,24},
- new double[] {53,180,768,12000,6,1,31},
- new double[] {23,330,1000,3000,0,2,4},
- new double[] {30,300,1000,4000,8,3,64},
- new double[] {73,300,1000,16000,8,2,112},
- new double[] {20,330,1000,2000,0,1,2},
- new double[] {25,330,1000,4000,0,3,6},
- new double[] {28,140,2000,4000,0,3,6},
- new double[] {29,140,2000,4000,0,4,8},
- new double[] {32,140,2000,4000,8,1,20},
- new double[] {175,140,2000,32000,32,1,20},
- new double[] {57,140,2000,8000,32,1,54},
- new double[] {181,140,2000,32000,32,1,54},
- new double[] {181,140,2000,32000,32,1,54},
- new double[] {32,140,2000,4000,8,1,20},
- new double[] {82,57,4000,16000,1,6,12},
- new double[] {171,57,4000,24000,64,12,16},
- new double[] {361,26,16000,32000,64,16,24},
- new double[] {350,26,16000,32000,64,8,24},
- new double[] {220,26,8000,32000,0,8,24},
- new double[] {113,26,8000,16000,0,8,16},
- new double[] {15,480,96,512,0,1,1},
- new double[] {21,203,1000,2000,0,1,5},
- new double[] {35,115,512,6000,16,1,6},
- new double[] {18,1100,512,1500,0,1,1},
- new double[] {20,1100,768,2000,0,1,1},
- new double[] {20,600,768,2000,0,1,1},
- new double[] {28,400,2000,4000,0,1,1},
- new double[] {45,400,4000,8000,0,1,1},
- new double[] {18,900,1000,1000,0,1,2},
- new double[] {17,900,512,1000,0,1,2},
- new double[] {26,900,1000,4000,4,1,2},
- new double[] {28,900,1000,4000,8,1,2},
- new double[] {28,900,2000,4000,0,3,6},
- new double[] {31,225,2000,4000,8,3,6},
- new double[] {31,225,2000,4000,8,3,6},
- new double[] {42,180,2000,8000,8,1,6},
- new double[] {76,185,2000,16000,16,1,6},
- new double[] {76,180,2000,16000,16,1,6},
- new double[] {26,225,1000,4000,2,3,6},
- new double[] {59,25,2000,12000,8,1,4},
- new double[] {65,25,2000,12000,16,3,5},
- new double[] {101,17,4000,16000,8,6,12},
- new double[] {116,17,4000,16000,32,6,12},
- new double[] {18,1500,768,1000,0,0,0},
- new double[] {20,1500,768,2000,0,0,0},
- new double[] {20,800,768,2000,0,0,0},
- new double[] {30,50,2000,4000,0,3,6},
- new double[] {44,50,2000,8000,8,3,6},
- new double[] {44,50,2000,8000,8,1,6},
- new double[] {82,50,2000,16000,24,1,6},
- new double[] {82,50,2000,16000,24,1,6},
- new double[] {128,50,8000,16000,48,1,10},
- new double[] {37,100,1000,8000,0,2,6},
- new double[] {46,100,1000,8000,24,2,6},
- new double[] {46,100,1000,8000,24,3,6},
- new double[] {80,50,2000,16000,12,3,16},
- new double[] {88,50,2000,16000,24,6,16},
- new double[] {88,50,2000,16000,24,6,16},
- new double[] {33,150,512,4000,0,8,128},
- new double[] {46,115,2000,8000,16,1,3},
- new double[] {29,115,2000,4000,2,1,5},
- new double[] {53,92,2000,8000,32,1,6},
- new double[] {53,92,2000,8000,32,1,6},
- new double[] {41,92,2000,8000,4,1,6},
- new double[] {86,75,4000,16000,16,1,6},
- new double[] {95,60,4000,16000,32,1,6},
- new double[] {107,60,2000,16000,64,5,8},
- new double[] {117,60,4000,16000,64,5,8},
- new double[] {119,50,4000,16000,64,5,10},
- new double[] {120,72,4000,16000,64,8,16},
- new double[] {48,72,2000,8000,16,6,8},
- new double[] {126,40,8000,16000,32,8,16},
- new double[] {266,40,8000,32000,64,8,24},
- new double[] {270,35,8000,32000,64,8,24},
- new double[] {426,38,16000,32000,128,16,32},
- new double[] {151,48,4000,24000,32,8,24},
- new double[] {267,38,8000,32000,64,8,24},
- new double[] {603,30,16000,32000,256,16,24},
- new double[] {19,112,1000,1000,0,1,4},
- new double[] {21,84,1000,2000,0,1,6},
- new double[] {26,56,1000,4000,0,1,6},
- new double[] {35,56,2000,6000,0,1,8},
- new double[] {41,56,2000,8000,0,1,8},
- new double[] {47,56,4000,8000,0,1,8},
- new double[] {62,56,4000,12000,0,1,8},
- new double[] {78,56,4000,16000,0,1,8},
- new double[] {80,38,4000,8000,32,16,32},
- new double[] {80,38,4000,8000,32,16,32},
- new double[] {142,38,8000,16000,64,4,8},
- new double[] {281,38,8000,24000,160,4,8},
- new double[] {190,38,4000,16000,128,16,32},
- new double[] {21,200,1000,2000,0,1,2},
- new double[] {25,200,1000,4000,0,1,4},
- new double[] {67,200,2000,8000,64,1,5},
- new double[] {24,250,512,4000,0,1,7},
- new double[] {24,250,512,4000,0,4,7},
- new double[] {64,250,1000,16000,1,1,8},
- new double[] {25,160,512,4000,2,1,5},
- new double[] {20,160,512,2000,2,3,8},
- new double[] {29,160,1000,4000,8,1,14},
- new double[] {43,160,1000,8000,16,1,14},
- new double[] {53,160,2000,8000,32,1,13},
- new double[] {19,240,512,1000,8,1,3},
- new double[] {22,240,512,2000,8,1,5},
- new double[] {31,105,2000,4000,8,3,8},
- new double[] {41,105,2000,6000,16,6,16},
- new double[] {47,105,2000,8000,16,4,14},
- new double[] {99,52,4000,16000,32,4,12},
- new double[] {67,70,4000,12000,8,6,8},
- new double[] {81,59,4000,12000,32,6,12},
- new double[] {149,59,8000,16000,64,12,24},
- new double[] {183,26,8000,24000,32,8,16},
- new double[] {275,26,8000,32000,64,12,16},
- new double[] {382,26,8000,32000,128,24,32},
- new double[] {56,116,2000,8000,32,5,28},
- new double[] {182,50,2000,32000,24,6,26},
- new double[] {227,50,2000,32000,48,26,52},
- new double[] {341,50,2000,32000,112,52,104},
- new double[] {360,50,4000,32000,112,52,104},
- new double[] {919,30,8000,64000,96,12,176},
- new double[] {978,30,8000,64000,128,12,176},
- new double[] {24,180,262,4000,0,1,3},
- new double[] {24,180,512,4000,0,1,3},
- new double[] {24,180,262,4000,0,1,3},
- new double[] {24,180,512,4000,0,1,3},
- new double[] {37,124,1000,8000,0,1,8},
- new double[] {50,98,1000,8000,32,2,8},
- new double[] {41,125,2000,8000,0,2,14},
- new double[] {47,480,512,8000,32,0,0},
- new double[] {25,480,1000,4000,0,0,0},
- };
-
- /** */
- private static final List<Double> labelsClearedMachines = new ArrayList<>();
-
- /** */
- private static final List<Vector> vectorsClearedMachines = new ArrayList<>();
-
- static {
- Arrays.stream(dataIris).forEach(e -> {
- labelsIris.add(e[0]);
- vectorsIris.add(new DenseLocalOnHeapVector(new double[] {e[1], e[2], e[3], e[4]}));
- });
-
- Arrays.stream(dataClearedMachines).forEach(e -> {
- labelsClearedMachines.add(e[0]);
- vectorsClearedMachines.add(new DenseLocalOnHeapVector(new double[] {e[1], e[2], e[3], e[4]}));
- });
- }
-}
http://git-wip-us.apache.org/repos/asf/ignite/blob/43d05576/modules/yardstick/src/main/java/org/apache/ignite/yardstick/ml/knn/IgniteKNNClassificationBenchmark.java
----------------------------------------------------------------------
diff --git a/modules/yardstick/src/main/java/org/apache/ignite/yardstick/ml/knn/IgniteKNNClassificationBenchmark.java b/modules/yardstick/src/main/java/org/apache/ignite/yardstick/ml/knn/IgniteKNNClassificationBenchmark.java
deleted file mode 100644
index 53c73cf..0000000
--- a/modules/yardstick/src/main/java/org/apache/ignite/yardstick/ml/knn/IgniteKNNClassificationBenchmark.java
+++ /dev/null
@@ -1,73 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.yardstick.ml.knn;
-
-import java.util.Map;
-import org.apache.ignite.Ignite;
-import org.apache.ignite.ml.knn.models.KNNModel;
-import org.apache.ignite.ml.knn.models.KNNStrategy;
-import org.apache.ignite.ml.math.distances.EuclideanDistance;
-import org.apache.ignite.ml.structures.LabeledDataset;
-import org.apache.ignite.ml.structures.LabeledDatasetTestTrainPair;
-import org.apache.ignite.resources.IgniteInstanceResource;
-import org.apache.ignite.thread.IgniteThread;
-import org.apache.ignite.yardstick.IgniteAbstractBenchmark;
-import org.apache.ignite.yardstick.ml.DataChanger;
-
-/**
- * Ignite benchmark that performs ML Grid operations.
- */
-@SuppressWarnings("unused")
-public class IgniteKNNClassificationBenchmark extends IgniteAbstractBenchmark {
- /** */
- @IgniteInstanceResource
- private Ignite ignite;
-
- /** {@inheritDoc} */
- @Override public boolean test(Map<Object, Object> ctx) throws Exception {
- // Create IgniteThread, we must work with SparseDistributedMatrix inside IgniteThread
- // because we create ignite cache internally.
- IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
- this.getClass().getSimpleName(), new Runnable() {
- /** {@inheritDoc} */
- @Override public void run() {
- // IMPL NOTE originally taken from KNNClassificationExample.
- // Obtain shuffled dataset.
- LabeledDataset dataset = new Datasets().shuffleIris((int)(DataChanger.next()));
-
- // Random splitting of iris data as 70% train and 30% test datasets.
- LabeledDatasetTestTrainPair split = new LabeledDatasetTestTrainPair(dataset, 0.3);
-
- LabeledDataset test = split.test();
- LabeledDataset train = split.train();
-
- KNNModel knnMdl = new KNNModel(5, new EuclideanDistance(), KNNStrategy.SIMPLE, train);
-
- // Calculate predicted classes.
- for (int i = 0; i < test.rowSize() - 1; i++)
- knnMdl.apply(test.getRow(i).features());
- }
- });
-
- igniteThread.start();
-
- igniteThread.join();
-
- return true;
- }
-}
http://git-wip-us.apache.org/repos/asf/ignite/blob/43d05576/modules/yardstick/src/main/java/org/apache/ignite/yardstick/ml/knn/IgniteKNNRegressionBenchmark.java
----------------------------------------------------------------------
diff --git a/modules/yardstick/src/main/java/org/apache/ignite/yardstick/ml/knn/IgniteKNNRegressionBenchmark.java b/modules/yardstick/src/main/java/org/apache/ignite/yardstick/ml/knn/IgniteKNNRegressionBenchmark.java
deleted file mode 100644
index 75242e6..0000000
--- a/modules/yardstick/src/main/java/org/apache/ignite/yardstick/ml/knn/IgniteKNNRegressionBenchmark.java
+++ /dev/null
@@ -1,82 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.yardstick.ml.knn;
-
-import java.util.Map;
-import org.apache.ignite.Ignite;
-import org.apache.ignite.ml.knn.models.KNNModel;
-import org.apache.ignite.ml.knn.models.KNNStrategy;
-import org.apache.ignite.ml.knn.regression.KNNMultipleLinearRegression;
-import org.apache.ignite.ml.math.distances.ManhattanDistance;
-import org.apache.ignite.ml.structures.LabeledDataset;
-import org.apache.ignite.ml.structures.LabeledDatasetTestTrainPair;
-import org.apache.ignite.ml.structures.preprocessing.Normalizer;
-import org.apache.ignite.resources.IgniteInstanceResource;
-import org.apache.ignite.thread.IgniteThread;
-import org.apache.ignite.yardstick.IgniteAbstractBenchmark;
-import org.apache.ignite.yardstick.ml.DataChanger;
-
-/**
- * Ignite benchmark that performs ML Grid operations.
- */
-@SuppressWarnings("unused")
-public class IgniteKNNRegressionBenchmark extends IgniteAbstractBenchmark {
- /** */
- @IgniteInstanceResource
- private Ignite ignite;
-
- /** {@inheritDoc} */
- @Override public boolean test(Map<Object, Object> ctx) throws Exception {
- // Create IgniteThread, we must work with SparseDistributedMatrix inside IgniteThread
- // because we create ignite cache internally.
- IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
- this.getClass().getSimpleName(), new Runnable() {
- /** {@inheritDoc} */
- @Override public void run() {
- // IMPL NOTE originally taken from KNNRegressionExample.
- // Obtain shuffled dataset.
- LabeledDataset dataset = new Datasets().shuffleClearedMachines((int)(DataChanger.next()));
-
- // Normalize dataset
- Normalizer.normalizeWithMiniMax(dataset);
-
- // Random splitting of iris data as 80% train and 20% test datasets.
- LabeledDatasetTestTrainPair split = new LabeledDatasetTestTrainPair(dataset, 0.2);
-
- LabeledDataset test = split.test();
- LabeledDataset train = split.train();
-
- // Builds weighted kNN-regression with Manhattan Distance.
- KNNModel knnMdl = new KNNMultipleLinearRegression(7, new ManhattanDistance(), KNNStrategy.WEIGHTED, train);
-
- // Clone labels
- final double[] labels = test.labels();
-
- // Calculate predicted classes.
- for (int i = 0; i < test.rowSize() - 1; i++)
- knnMdl.apply(test.getRow(i).features());
- }
- });
-
- igniteThread.start();
-
- igniteThread.join();
-
- return true;
- }
-}
http://git-wip-us.apache.org/repos/asf/ignite/blob/43d05576/modules/yardstick/src/main/java/org/apache/ignite/yardstick/ml/knn/package-info.java
----------------------------------------------------------------------
diff --git a/modules/yardstick/src/main/java/org/apache/ignite/yardstick/ml/knn/package-info.java b/modules/yardstick/src/main/java/org/apache/ignite/yardstick/ml/knn/package-info.java
deleted file mode 100644
index a5ff26a..0000000
--- a/modules/yardstick/src/main/java/org/apache/ignite/yardstick/ml/knn/package-info.java
+++ /dev/null
@@ -1,22 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-/**
- * <!-- Package description. -->
- * ML Grid kNN benchmarks.
- */
-package org.apache.ignite.yardstick.ml.knn;
\ No newline at end of file
[2/2] ignite git commit: IGNITE-7702: Adopt kNN classification to the
new datasets
Posted by ch...@apache.org.
IGNITE-7702: Adopt kNN classification to the new datasets
this closes #3565
Project: http://git-wip-us.apache.org/repos/asf/ignite/repo
Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/43d05576
Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/43d05576
Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/43d05576
Branch: refs/heads/master
Commit: 43d055767099b89ca09b0bc63beb20728e93735a
Parents: d8fc15b
Author: zaleslaw <za...@gmail.com>
Authored: Mon Apr 2 12:26:38 2018 +0300
Committer: Yury Babak <yb...@gridgain.com>
Committed: Mon Apr 2 12:26:38 2018 +0300
----------------------------------------------------------------------
.../KNNClassificationExample.java | 154 -------
.../ml/knn/classification/package-info.java | 22 -
.../ignite/examples/ml/knn/package-info.java | 22 -
.../ml/knn/regression/KNNRegressionExample.java | 155 -------
.../ml/knn/regression/package-info.java | 22 -
.../classification/KNNClassificationModel.java | 260 +++++++++++
.../KNNClassificationTrainer.java | 60 +++
.../ml/knn/classification/KNNModelFormat.java | 88 ++++
.../ml/knn/classification/KNNStrategy.java | 29 ++
.../ml/knn/classification/package-info.java | 22 +
.../apache/ignite/ml/knn/models/KNNModel.java | 233 ----------
.../ignite/ml/knn/models/KNNModelFormat.java | 96 ----
.../ignite/ml/knn/models/KNNStrategy.java | 27 --
.../ignite/ml/knn/models/package-info.java | 22 -
.../ml/knn/partitions/KNNPartitionContext.java | 28 ++
.../ignite/ml/knn/partitions/package-info.java | 22 +
.../regression/KNNMultipleLinearRegression.java | 83 ----
.../ignite/ml/knn/regression/package-info.java | 22 -
.../ml/math/distances/DistanceMeasure.java | 12 +
.../ml/math/distances/EuclideanDistance.java | 10 +
.../ml/math/distances/HammingDistance.java | 5 +
.../ml/math/distances/ManhattanDistance.java | 5 +
.../ignite/ml/math/isolve/lsqr/LSQROnHeap.java | 2 +-
.../partition/LabelPartitionContext.java | 28 ++
.../LabelPartitionDataBuilderOnHeap.java | 66 +++
.../partition/LabelPartitionDataOnHeap.java | 45 ++
...abeledDatasetPartitionDataBuilderOnHeap.java | 86 ++++
.../SVMLinearBinaryClassificationTrainer.java | 3 +-
...VMLinearMultiClassClassificationTrainer.java | 6 +-
.../ml/svm/SVMPartitionDataBuilderOnHeap.java | 88 ----
.../ml/svm/multi/LabelPartitionContext.java | 28 --
.../multi/LabelPartitionDataBuilderOnHeap.java | 66 ---
.../ml/svm/multi/LabelPartitionDataOnHeap.java | 45 --
.../ignite/ml/svm/multi/package-info.java | 22 -
.../org/apache/ignite/ml/LocalModelsTest.java | 30 +-
.../org/apache/ignite/ml/knn/BaseKNNTest.java | 3 -
.../ignite/ml/knn/KNNClassificationTest.java | 160 +++----
.../ml/knn/KNNMultipleLinearRegressionTest.java | 157 -------
.../org/apache/ignite/ml/knn/KNNTestSuite.java | 1 -
.../ignite/ml/knn/LabeledDatasetTest.java | 1 -
.../ignite/yardstick/ml/knn/Datasets.java | 453 -------------------
.../knn/IgniteKNNClassificationBenchmark.java | 73 ---
.../ml/knn/IgniteKNNRegressionBenchmark.java | 82 ----
.../ignite/yardstick/ml/knn/package-info.java | 22 -
44 files changed, 857 insertions(+), 2009 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/ignite/blob/43d05576/examples/src/main/java/org/apache/ignite/examples/ml/knn/classification/KNNClassificationExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/knn/classification/KNNClassificationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/knn/classification/KNNClassificationExample.java
deleted file mode 100644
index 6532ac5..0000000
--- a/examples/src/main/java/org/apache/ignite/examples/ml/knn/classification/KNNClassificationExample.java
+++ /dev/null
@@ -1,154 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.examples.ml.knn.classification;
-
-import java.io.File;
-import java.io.IOException;
-import java.nio.file.Path;
-import java.util.Arrays;
-import org.apache.ignite.Ignite;
-import org.apache.ignite.Ignition;
-import org.apache.ignite.examples.ExampleNodeStartup;
-import org.apache.ignite.internal.util.IgniteUtils;
-import org.apache.ignite.ml.knn.models.KNNModel;
-import org.apache.ignite.ml.knn.models.KNNStrategy;
-import org.apache.ignite.ml.math.distances.EuclideanDistance;
-import org.apache.ignite.ml.structures.LabeledDataset;
-import org.apache.ignite.ml.structures.LabeledDatasetTestTrainPair;
-import org.apache.ignite.ml.structures.preprocessing.LabeledDatasetLoader;
-import org.apache.ignite.ml.structures.preprocessing.LabellingMachine;
-import org.apache.ignite.thread.IgniteThread;
-
-/**
- * <p>
- * Example of using {@link KNNModel} with iris dataset.</p>
- * <p>
- * Note that in this example we cannot guarantee order in which nodes return results of intermediate
- * computations and therefore algorithm can return different results.</p>
- * <p>
- * Remote nodes should always be started with special configuration file which
- * enables P2P class loading: {@code 'ignite.{sh|bat} examples/config/example-ignite.xml'}.</p>
- * <p>
- * Alternatively you can run {@link ExampleNodeStartup} in another JVM which will start node
- * with {@code examples/config/example-ignite.xml} configuration.</p>
- */
-public class KNNClassificationExample {
- /** Separator. */
- private static final String SEPARATOR = "\t";
-
- /** Path to the Iris dataset. */
- private static final String KNN_IRIS_TXT = "examples/src/main/resources/datasets/iris.txt";
-
- /**
- * Executes example.
- *
- * @param args Command line arguments, none required.
- */
- public static void main(String[] args) throws InterruptedException {
- System.out.println(">>> kNN classification example started.");
- // Start ignite grid.
- try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
- System.out.println(">>> Ignite grid started.");
-
- IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
- KNNClassificationExample.class.getSimpleName(), () -> {
-
- try {
- // Prepare path to read
- File file = IgniteUtils.resolveIgnitePath(KNN_IRIS_TXT);
- if (file == null)
- throw new RuntimeException("Can't find file: " + KNN_IRIS_TXT);
-
- Path path = file.toPath();
-
- // Read dataset from file
- LabeledDataset dataset = LabeledDatasetLoader.loadFromTxtFile(path, SEPARATOR, true, false);
-
- // Random splitting of iris data as 70% train and 30% test datasets
- LabeledDatasetTestTrainPair split = new LabeledDatasetTestTrainPair(dataset, 0.3);
-
- System.out.println("\n>>> Amount of observations in train dataset " + split.train().rowSize());
- System.out.println("\n>>> Amount of observations in test dataset " + split.test().rowSize());
-
- LabeledDataset test = split.test();
- LabeledDataset train = split.train();
-
- KNNModel knnMdl = new KNNModel(5, new EuclideanDistance(), KNNStrategy.SIMPLE, train);
-
- // Clone labels
- final double[] labels = test.labels();
-
- // Save predicted classes to test dataset
- LabellingMachine.assignLabels(test, knnMdl);
-
- // Calculate amount of errors on test dataset
- int amountOfErrors = 0;
- for (int i = 0; i < test.rowSize(); i++) {
- if (test.label(i) != labels[i])
- amountOfErrors++;
- }
-
- System.out.println("\n>>> Absolute amount of errors " + amountOfErrors);
- System.out.println("\n>>> Accuracy " + amountOfErrors / (double)test.rowSize());
-
- // Build confusion matrix. See https://en.wikipedia.org/wiki/Confusion_matrix
- int[][] confusionMtx = {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}};
- for (int i = 0; i < test.rowSize(); i++) {
- int idx1 = (int)test.label(i);
- int idx2 = (int)labels[i];
- confusionMtx[idx1 - 1][idx2 - 1]++;
- }
- System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtx));
-
- // Calculate precision, recall and F-metric for each class
- for (int i = 0; i < 3; i++) {
- double precision = 0.0;
- for (int j = 0; j < 3; j++)
- precision += confusionMtx[i][j];
- precision = confusionMtx[i][i] / precision;
-
- double clsLb = (double)(i + 1);
-
- System.out.println("\n>>> Precision for class " + clsLb + " is " + precision);
-
- double recall = 0.0;
- for (int j = 0; j < 3; j++)
- recall += confusionMtx[j][i];
- recall = confusionMtx[i][i] / recall;
- System.out.println("\n>>> Recall for class " + clsLb + " is " + recall);
-
- double fScore = 2 * precision * recall / (precision + recall);
- System.out.println("\n>>> F-score for class " + clsLb + " is " + fScore);
- }
-
- }
- catch (IOException e) {
- e.printStackTrace();
- System.out.println("\n>>> Unexpected exception, check resources: " + e);
- }
- finally {
- System.out.println("\n>>> kNN classification example completed.");
- }
-
- });
-
- igniteThread.start();
- igniteThread.join();
- }
- }
-}
http://git-wip-us.apache.org/repos/asf/ignite/blob/43d05576/examples/src/main/java/org/apache/ignite/examples/ml/knn/classification/package-info.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/knn/classification/package-info.java b/examples/src/main/java/org/apache/ignite/examples/ml/knn/classification/package-info.java
deleted file mode 100644
index d853f0d..0000000
--- a/examples/src/main/java/org/apache/ignite/examples/ml/knn/classification/package-info.java
+++ /dev/null
@@ -1,22 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-/**
- * <!-- Package description. -->
- * kNN classification examples.
- */
-package org.apache.ignite.examples.ml.knn.classification;
http://git-wip-us.apache.org/repos/asf/ignite/blob/43d05576/examples/src/main/java/org/apache/ignite/examples/ml/knn/package-info.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/knn/package-info.java b/examples/src/main/java/org/apache/ignite/examples/ml/knn/package-info.java
deleted file mode 100644
index 8de4656..0000000
--- a/examples/src/main/java/org/apache/ignite/examples/ml/knn/package-info.java
+++ /dev/null
@@ -1,22 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-/**
- * <!-- Package description. -->
- * kNN examples.
- */
-package org.apache.ignite.examples.ml.knn;
http://git-wip-us.apache.org/repos/asf/ignite/blob/43d05576/examples/src/main/java/org/apache/ignite/examples/ml/knn/regression/KNNRegressionExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/knn/regression/KNNRegressionExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/knn/regression/KNNRegressionExample.java
deleted file mode 100644
index ba079cc..0000000
--- a/examples/src/main/java/org/apache/ignite/examples/ml/knn/regression/KNNRegressionExample.java
+++ /dev/null
@@ -1,155 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.examples.ml.knn.regression;
-
-import java.io.File;
-import java.io.IOException;
-import java.nio.file.Path;
-import org.apache.ignite.Ignite;
-import org.apache.ignite.Ignition;
-import org.apache.ignite.examples.ExampleNodeStartup;
-import org.apache.ignite.internal.util.IgniteUtils;
-import org.apache.ignite.ml.knn.models.KNNStrategy;
-import org.apache.ignite.ml.knn.regression.KNNMultipleLinearRegression;
-import org.apache.ignite.ml.math.distances.ManhattanDistance;
-import org.apache.ignite.ml.structures.LabeledDataset;
-import org.apache.ignite.ml.structures.LabeledDatasetTestTrainPair;
-import org.apache.ignite.ml.structures.preprocessing.LabeledDatasetLoader;
-import org.apache.ignite.ml.structures.preprocessing.LabellingMachine;
-import org.apache.ignite.ml.structures.preprocessing.Normalizer;
-import org.apache.ignite.thread.IgniteThread;
-
-/**
- * <p>
- * Example of using {@link KNNMultipleLinearRegression} with iris dataset.</p>
- * <p>
- * Note that in this example we cannot guarantee order in which nodes return results of intermediate
- * computations and therefore algorithm can return different results.</p>
- * <p>
- * Remote nodes should always be started with special configuration file which
- * enables P2P class loading: {@code 'ignite.{sh|bat} examples/config/example-ignite.xml'}.</p>
- * <p>
- * Alternatively you can run {@link ExampleNodeStartup} in another JVM which will start node
- * with {@code examples/config/example-ignite.xml} configuration.</p>
- */
-public class KNNRegressionExample {
- /** Separator. */
- private static final String SEPARATOR = ",";
-
- /** */
- private static final String KNN_CLEARED_MACHINES_TXT = "examples/src/main/resources/datasets/cleared_machines.txt";
-
- /**
- * Executes example.
- *
- * @param args Command line arguments, none required.
- */
- public static void main(String[] args) throws InterruptedException {
- System.out.println(">>> kNN regression example started.");
- // Start ignite grid.
- try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
- System.out.println(">>> Ignite grid started.");
-
- IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
- KNNRegressionExample.class.getSimpleName(), () -> {
-
- try {
- // Prepare path to read
- File file = IgniteUtils.resolveIgnitePath(KNN_CLEARED_MACHINES_TXT);
- if (file == null)
- throw new RuntimeException("Can't find file: " + KNN_CLEARED_MACHINES_TXT);
-
- Path path = file.toPath();
-
- // Read dataset from file
- LabeledDataset dataset = LabeledDatasetLoader.loadFromTxtFile(path, SEPARATOR, false, false);
-
- // Normalize dataset
- Normalizer.normalizeWithMiniMax(dataset);
-
- // Random splitting of iris data as 80% train and 20% test datasets
- LabeledDatasetTestTrainPair split = new LabeledDatasetTestTrainPair(dataset, 0.2);
-
- System.out.println("\n>>> Amount of observations in train dataset: " + split.train().rowSize());
- System.out.println("\n>>> Amount of observations in test dataset: " + split.test().rowSize());
-
- LabeledDataset test = split.test();
- LabeledDataset train = split.train();
-
- // Builds weighted kNN-regression with Manhattan Distance
- KNNMultipleLinearRegression knnMdl = new KNNMultipleLinearRegression(7, new ManhattanDistance(),
- KNNStrategy.WEIGHTED, train);
-
- // Clone labels
- final double[] labels = test.labels();
-
- // Save predicted classes to test dataset
- LabellingMachine.assignLabels(test, knnMdl);
-
- // Calculate mean squared error (MSE)
- double mse = 0.0;
-
- for (int i = 0; i < test.rowSize(); i++)
- mse += Math.pow(test.label(i) - labels[i], 2.0);
- mse = mse / test.rowSize();
-
- System.out.println("\n>>> Mean squared error (MSE) " + mse);
-
- // Calculate mean absolute error (MAE)
- double mae = 0.0;
-
- for (int i = 0; i < test.rowSize(); i++)
- mae += Math.abs(test.label(i) - labels[i]);
- mae = mae / test.rowSize();
-
- System.out.println("\n>>> Mean absolute error (MAE) " + mae);
-
- // Calculate R^2 as 1 - RSS/TSS
- double avg = 0.0;
-
- for (int i = 0; i < test.rowSize(); i++)
- avg += test.label(i);
-
- avg = avg / test.rowSize();
-
- double detCf = 0.0;
- double tss = 0.0;
-
- for (int i = 0; i < test.rowSize(); i++) {
- detCf += Math.pow(test.label(i) - labels[i], 2.0);
- tss += Math.pow(test.label(i) - avg, 2.0);
- }
-
- detCf = 1 - detCf / tss;
-
- System.out.println("\n>>> R^2 " + detCf);
- }
- catch (IOException e) {
- e.printStackTrace();
- System.out.println("\n>>> Unexpected exception, check resources: " + e);
- }
- finally {
- System.out.println("\n>>> kNN regression example completed.");
- }
- });
-
- igniteThread.start();
- igniteThread.join();
- }
- }
-}
http://git-wip-us.apache.org/repos/asf/ignite/blob/43d05576/examples/src/main/java/org/apache/ignite/examples/ml/knn/regression/package-info.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/knn/regression/package-info.java b/examples/src/main/java/org/apache/ignite/examples/ml/knn/regression/package-info.java
deleted file mode 100644
index e7ac336..0000000
--- a/examples/src/main/java/org/apache/ignite/examples/ml/knn/regression/package-info.java
+++ /dev/null
@@ -1,22 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-/**
- * <!-- Package description. -->
- * kNN regression examples.
- */
-package org.apache.ignite.examples.ml.knn.regression;
http://git-wip-us.apache.org/repos/asf/ignite/blob/43d05576/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationModel.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationModel.java
new file mode 100644
index 0000000..373f822
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationModel.java
@@ -0,0 +1,260 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.knn.classification;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.TreeMap;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+import org.apache.ignite.ml.Exportable;
+import org.apache.ignite.ml.Exporter;
+import org.apache.ignite.ml.Model;
+import org.apache.ignite.ml.dataset.Dataset;
+import org.apache.ignite.ml.knn.partitions.KNNPartitionContext;
+import org.apache.ignite.ml.math.Vector;
+import org.apache.ignite.ml.math.distances.DistanceMeasure;
+import org.apache.ignite.ml.math.distances.EuclideanDistance;
+import org.apache.ignite.ml.structures.LabeledDataset;
+import org.apache.ignite.ml.structures.LabeledVector;
+import org.jetbrains.annotations.NotNull;
+
+/**
+ * kNN algorithm model to solve multi-class classification task.
+ */
+public class KNNClassificationModel<K, V> implements Model<Vector, Double>, Exportable<KNNModelFormat> {
+ /** Amount of nearest neighbors. */
+ protected int k = 5;
+
+ /** Distance measure. */
+ protected DistanceMeasure distanceMeasure = new EuclideanDistance();
+
+ /** kNN strategy. */
+ protected KNNStrategy stgy = KNNStrategy.SIMPLE;
+
+ /** Dataset. */
+ private Dataset<KNNPartitionContext, LabeledDataset<Double, LabeledVector>> dataset;
+
+ /**
+ * Builds the model via prepared dataset.
+ * @param dataset Specially prepared object to run algorithm over it.
+ */
+ public KNNClassificationModel(Dataset<KNNPartitionContext, LabeledDataset<Double, LabeledVector>> dataset) {
+ this.dataset = dataset;
+ }
+
+ /** {@inheritDoc} */
+ @Override public Double apply(Vector v) {
+ if(dataset != null) {
+ List<LabeledVector> neighbors = findKNearestNeighbors(v);
+
+ return classify(neighbors, v, stgy);
+ } else
+ throw new IllegalStateException("The train kNN dataset is null");
+ }
+
+ /** Exports k, the distance measure and the strategy as a {@link KNNModelFormat}; the dataset is not exported. */
+ @Override public <P> void saveModel(Exporter<KNNModelFormat, P> exporter, P path) {
+ KNNModelFormat mdlData = new KNNModelFormat(k, distanceMeasure, stgy);
+ exporter.save(mdlData, path);
+ }
+
+ /**
+ * Set up parameter of the kNN model.
+ * @param k Amount of nearest neighbors.
+ * @return Model.
+ */
+ public KNNClassificationModel<K, V> withK(int k) {
+ this.k = k;
+ return this;
+ }
+
+ /**
+ * Set up parameter of the kNN model.
+ * @param stgy Strategy of calculations.
+ * @return Model.
+ */
+ public KNNClassificationModel<K, V> withStrategy(KNNStrategy stgy) {
+ this.stgy = stgy;
+ return this;
+ }
+
+ /**
+ * Set up parameter of the kNN model.
+ * @param distanceMeasure Distance measure.
+ * @return Model.
+ */
+ public KNNClassificationModel<K, V> withDistanceMeasure(DistanceMeasure distanceMeasure) {
+ this.distanceMeasure = distanceMeasure;
+ return this;
+ }
+
+ /**
+ * Selects the k nearest candidates on every partition, merges the partial lists and picks the k vectors closest
+ * to the given vector among them. A {@code null} partial result on either side of the merge is treated as empty.
+ *
+ * @param v The given vector.
+ * @return K-nearest neighbors.
+ */
+ protected List<LabeledVector> findKNearestNeighbors(Vector v) {
+ List<LabeledVector> neighborsFromPartitions = dataset.compute(data -> {
+ TreeMap<Double, Set<Integer>> distanceIdxPairs = getDistances(v, data);
+ return Arrays.asList(getKClosestVectors(data, distanceIdxPairs));
+ }, (a, b) -> a == null ? b : b == null ? a : Stream.concat(a.stream(), b.stream()).collect(Collectors.toList()));
+
+ LabeledDataset<Double, LabeledVector> neighborsToFilter = buildLabeledDatasetOnListOfVectors(neighborsFromPartitions);
+
+ return Arrays.asList(getKClosestVectors(neighborsToFilter, getDistances(v, neighborsToFilter)));
+ }
+
+
+ /** Copies the given list of vectors into an array and wraps that array into a {@link LabeledDataset}. */
+ private LabeledDataset<Double, LabeledVector> buildLabeledDatasetOnListOfVectors(
+ List<LabeledVector> neighborsFromPartitions) {
+ LabeledVector[] arr = new LabeledVector[neighborsFromPartitions.size()];
+ for (int i = 0; i < arr.length; i++)
+ arr[i] = neighborsFromPartitions.get(i);
+
+ return new LabeledDataset<Double, LabeledVector>(arr);
+ }
+
+ /**
+ * Walks the distance map in ascending order of distance and collects at most k of the referenced vectors.
+ *
+ * @param trainingData The training data.
+ * @param distanceIdxPairs The distance map.
+ * @return K-nearest neighbors, or fewer (without null slots) when the distance map references less than k vectors.
+ */
+ @NotNull private LabeledVector[] getKClosestVectors(LabeledDataset<Double, LabeledVector> trainingData,
+ TreeMap<Double, Set<Integer>> distanceIdxPairs) {
+ LabeledVector[] res = new LabeledVector[k];
+ int i = 0;
+ final Iterator<Double> iter = distanceIdxPairs.keySet().iterator();
+ while (i < k && iter.hasNext()) { // hasNext() guards against maps that reference fewer than k vectors
+ double key = iter.next();
+ Set<Integer> idxs = distanceIdxPairs.get(key);
+ for (Integer idx : idxs) {
+ res[i] = trainingData.getRow(idx);
+ i++;
+ if (i >= k)
+ break; // k vectors collected; the while-condition then ends the outer loop
+ }
+ }
+ return Arrays.copyOf(res, i); // trim unused slots so callers never iterate over null neighbors
+ }
+
+ /**
+ * Computes distances between given vector and each vector in training dataset; null rows are skipped.
+ *
+ * @param v The given vector.
+ * @param trainingData The training dataset.
+ * @return Sorted map: the key is a distance from the given vector, the value holds the row indices of the training
+ * vectors located at that distance. The value is a Set because several vectors can have the same distance.
+ */
+ @NotNull private TreeMap<Double, Set<Integer>> getDistances(Vector v, LabeledDataset<Double, LabeledVector> trainingData) {
+ TreeMap<Double, Set<Integer>> distanceIdxPairs = new TreeMap<>();
+
+ for (int i = 0; i < trainingData.rowSize(); i++) {
+
+ LabeledVector labeledVector = trainingData.getRow(i);
+ if (labeledVector != null) {
+ double distance = distanceMeasure.compute(v, labeledVector.features());
+ putDistanceIdxPair(distanceIdxPairs, i, distance);
+ }
+ }
+ return distanceIdxPairs;
+ }
+
+ /** Adds row index {@code i} to the set stored under {@code distance}, creating that set on first use. */
+ private void putDistanceIdxPair(Map<Double, Set<Integer>> distanceIdxPairs, int i, double distance) {
+ if (distanceIdxPairs.containsKey(distance)) {
+ Set<Integer> idxs = distanceIdxPairs.get(distance);
+ idxs.add(i);
+ }
+ else {
+ Set<Integer> idxs = new HashSet<>();
+ idxs.add(i);
+ distanceIdxPairs.put(distance, idxs);
+ }
+ }
+
+ /** Accumulates the votes of the neighbors per class label and returns the label with the largest total vote. */
+ private double classify(List<LabeledVector> neighbors, Vector v, KNNStrategy stgy) {
+ Map<Double, Double> clsVotes = new HashMap<>();
+
+ for (LabeledVector neighbor : neighbors) {
+ double clsLb = (double)neighbor.label();
+
+ double distance = distanceMeasure.compute(v, neighbor.features());
+
+ if (clsVotes.containsKey(clsLb)) {
+ double clsVote = clsVotes.get(clsLb);
+ clsVote += getClassVoteForVector(stgy, distance);
+ clsVotes.put(clsLb, clsVote);
+ }
+ else {
+ final double val = getClassVoteForVector(stgy, distance);
+ clsVotes.put(clsLb, val);
+ }
+ }
+ return getClassWithMaxVotes(clsVotes);
+ }
+
+ /** Returns the key (class label) of the map entry that has the maximal value (accumulated vote). */
+ private double getClassWithMaxVotes(Map<Double, Double> clsVotes) {
+ return Collections.max(clsVotes.entrySet(), Map.Entry.comparingByValue()).getKey();
+ }
+
+ /** Vote of a single neighbor: {@code 1 / distance} for {@link KNNStrategy#WEIGHTED}, {@code 1.0} otherwise. */
+ private double getClassVoteForVector(KNNStrategy stgy, double distance) {
+ if (stgy.equals(KNNStrategy.WEIGHTED))
+ return 1 / distance; // strategy.WEIGHTED; double division, so a zero distance yields Infinity
+ else
+ return 1.0; // strategy.SIMPLE
+ }
+
+ /** {@inheritDoc} */
+ @Override public int hashCode() {
+ int res = 1;
+
+ res = res * 37 + k;
+ res = res * 37 + distanceMeasure.hashCode();
+ res = res * 37 + stgy.hashCode();
+
+ return res;
+ }
+
+ /** {@inheritDoc} */
+ @Override public boolean equals(Object obj) {
+ if (this == obj)
+ return true;
+
+ if (obj == null || getClass() != obj.getClass())
+ return false;
+
+ KNNClassificationModel that = (KNNClassificationModel)obj;
+
+ return k == that.k && distanceMeasure.equals(that.distanceMeasure) && stgy.equals(that.stgy);
+ }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/43d05576/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationTrainer.java
new file mode 100644
index 0000000..357047f
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationTrainer.java
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.knn.classification;
+
+import org.apache.ignite.ml.dataset.Dataset;
+import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.dataset.PartitionDataBuilder;
+import org.apache.ignite.ml.knn.partitions.KNNPartitionContext;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.structures.LabeledDataset;
+import org.apache.ignite.ml.structures.partition.LabeledDatasetPartitionDataBuilderOnHeap;
+import org.apache.ignite.ml.structures.LabeledVector;
+import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;
+
+/**
+ * kNN algorithm trainer to solve multi-class classification task.
+ */
+public class KNNClassificationTrainer implements SingleLabelDatasetTrainer<KNNClassificationModel> {
+ /**
+ * Builds the partitioned labeled dataset from the given builder and wraps it into a kNN classification model.
+ *
+ * @param datasetBuilder Dataset builder.
+ * @param featureExtractor Feature extractor.
+ * @param lbExtractor Label extractor.
+ * @return Model; it is created over a {@code null} dataset when {@code datasetBuilder} is {@code null}.
+ */
+ @Override public <K, V> KNNClassificationModel fit(DatasetBuilder<K, V> datasetBuilder,
+ IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
+ PartitionDataBuilder<K, V, KNNPartitionContext, LabeledDataset<Double, LabeledVector>> partDataBuilder
+ = new LabeledDatasetPartitionDataBuilderOnHeap<>(
+ featureExtractor,
+ lbExtractor
+ );
+
+ Dataset<KNNPartitionContext, LabeledDataset<Double, LabeledVector>> dataset = null;
+
+ if (datasetBuilder != null) {
+ dataset = datasetBuilder.build(
+ (upstream, upstreamSize) -> new KNNPartitionContext(),
+ partDataBuilder
+ );
+ }
+ return new KNNClassificationModel<>(dataset);
+ }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/43d05576/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNModelFormat.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNModelFormat.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNModelFormat.java
new file mode 100644
index 0000000..a2efe7f
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNModelFormat.java
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.knn.classification;
+
+import java.io.Serializable;
+import org.apache.ignite.ml.math.distances.DistanceMeasure;
+
+/**
+ * kNN model representation.
+ *
+ * @see KNNClassificationModel
+ */
+public class KNNModelFormat implements Serializable {
+ /** Amount of nearest neighbors. */
+ private int k;
+
+ /** Distance measure. */
+ private DistanceMeasure distanceMeasure;
+
+ /** kNN strategy. */
+ private KNNStrategy stgy;
+
+ /** Gets amount of nearest neighbors.*/
+ public int getK() {
+ return k;
+ }
+
+ /** Gets distance measure. */
+ public DistanceMeasure getDistanceMeasure() {
+ return distanceMeasure;
+ }
+
+ /** Gets kNN strategy.*/
+ public KNNStrategy getStgy() {
+ return stgy;
+ }
+
+ /**
+ * Creates an instance.
+ * @param k Amount of nearest neighbors.
+ * @param measure Distance measure.
+ * @param stgy kNN strategy.
+ */
+ public KNNModelFormat(int k, DistanceMeasure measure, KNNStrategy stgy) {
+ this.k = k;
+ this.distanceMeasure = measure;
+ this.stgy = stgy;
+ }
+
+ /** {@inheritDoc} */
+ @Override public int hashCode() {
+ int res = 1;
+
+ res = res * 37 + k;
+ res = res * 37 + distanceMeasure.hashCode();
+ res = res * 37 + stgy.hashCode();
+
+ return res;
+ }
+
+ /** {@inheritDoc} */
+ @Override public boolean equals(Object obj) {
+ if (this == obj)
+ return true;
+
+ if (obj == null || getClass() != obj.getClass())
+ return false;
+
+ KNNModelFormat that = (KNNModelFormat)obj;
+
+ return k == that.k && distanceMeasure.equals(that.distanceMeasure) && stgy.equals(that.stgy);
+ }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/43d05576/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNStrategy.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNStrategy.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNStrategy.java
new file mode 100644
index 0000000..9a117de
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNStrategy.java
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.knn.classification;
+
+/** This enum contains settings for kNN algorithm. */
+public enum KNNStrategy {
+ /** The default strategy. All k neighbors have the same weight, which is independent
+ * of their distance to the query point. */
+ SIMPLE,
+
+ /** A refinement of the k-NN classification algorithm is to weigh the contribution of each of the k neighbors
+ * according to their distance to the query point, giving greater weight to closer neighbors. */
+ WEIGHTED
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/43d05576/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/package-info.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/package-info.java
new file mode 100644
index 0000000..81b36e4
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * Contains main APIs for kNN classification algorithms.
+ */
+package org.apache.ignite.ml.knn.classification;
http://git-wip-us.apache.org/repos/asf/ignite/blob/43d05576/modules/ml/src/main/java/org/apache/ignite/ml/knn/models/KNNModel.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/models/KNNModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/models/KNNModel.java
deleted file mode 100644
index 3951be4..0000000
--- a/modules/ml/src/main/java/org/apache/ignite/ml/knn/models/KNNModel.java
+++ /dev/null
@@ -1,233 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.knn.models;
-
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.Iterator;
-import java.util.Map;
-import java.util.Set;
-import java.util.TreeMap;
-import org.apache.ignite.ml.Exportable;
-import org.apache.ignite.ml.Exporter;
-import org.apache.ignite.ml.Model;
-import org.apache.ignite.ml.math.Vector;
-import org.apache.ignite.ml.math.distances.DistanceMeasure;
-import org.apache.ignite.ml.math.exceptions.knn.SmallTrainingDatasetSizeException;
-import org.apache.ignite.ml.structures.LabeledDataset;
-import org.apache.ignite.ml.structures.LabeledVector;
-import org.jetbrains.annotations.NotNull;
-
-/**
- * kNN algorithm is a classification algorithm.
- */
-public class KNNModel implements Model<Vector, Double>, Exportable<KNNModelFormat> {
- /** Amount of nearest neighbors. */
- protected final int k;
-
- /** Distance measure. */
- protected final DistanceMeasure distanceMeasure;
-
- /** Training dataset. */
- protected final LabeledDataset training;
-
- /** kNN strategy. */
- protected final KNNStrategy stgy;
-
- /** Cached distances for k-nearest neighbors. */
- protected double[] cachedDistances;
-
- /**
- * Creates the kNN model with the given parameters.
- *
- * @param k Amount of nearest neighbors.
- * @param distanceMeasure Distance measure.
- * @param stgy Strategy of calculations.
- * @param training Training dataset.
- */
- public KNNModel(int k, DistanceMeasure distanceMeasure, KNNStrategy stgy, LabeledDataset training) {
- assert training != null;
-
- if (training.rowSize() < k)
- throw new SmallTrainingDatasetSizeException(k, training.rowSize());
-
- this.k = k;
- this.distanceMeasure = distanceMeasure;
- this.training = training;
- this.stgy = stgy;
- }
-
- /** {@inheritDoc} */
- @Override public Double apply(Vector v) {
- LabeledVector[] neighbors = findKNearestNeighbors(v, true);
-
- return classify(neighbors, v, stgy);
- }
-
- /** */
- @Override public <P> void saveModel(Exporter<KNNModelFormat, P> exporter, P path) {
- KNNModelFormat mdlData = new KNNModelFormat(k, distanceMeasure, training, stgy);
-
- exporter.save(mdlData, path);
- }
-
- /**
- * The main idea is calculation all distance pairs between given vector and all vectors in training set, sorting
- * them and finding k vectors with min distance with the given vector.
- *
- * @param v The given vector.
- * @return K-nearest neighbors.
- */
- protected LabeledVector[] findKNearestNeighbors(Vector v, boolean isCashedDistance) {
- LabeledVector[] trainingData = (LabeledVector[])training.data();
-
- TreeMap<Double, Set<Integer>> distanceIdxPairs = getDistances(v, trainingData);
-
- return getKClosestVectors(trainingData, distanceIdxPairs, isCashedDistance);
- }
-
- /**
- * Iterates along entries in distance map and fill the resulting k-element array.
- *
- * @param trainingData The training data.
- * @param distanceIdxPairs The distance map.
- * @param isCashedDistances Cache distances if true.
- * @return K-nearest neighbors.
- */
- @NotNull private LabeledVector[] getKClosestVectors(LabeledVector[] trainingData,
- TreeMap<Double, Set<Integer>> distanceIdxPairs, boolean isCashedDistances) {
- LabeledVector[] res = new LabeledVector[k];
- int i = 0;
- final Iterator<Double> iter = distanceIdxPairs.keySet().iterator();
- while (i < k) {
- double key = iter.next();
- Set<Integer> idxs = distanceIdxPairs.get(key);
- for (Integer idx : idxs) {
- res[i] = trainingData[idx];
- if (isCashedDistances) {
- if (cachedDistances == null)
- cachedDistances = new double[k];
- cachedDistances[i] = key;
- }
- i++;
- if (i >= k)
- break; // go to next while-loop iteration
- }
- }
- return res;
- }
-
- /**
- * Computes distances between given vector and each vector in training dataset.
- *
- * @param v The given vector.
- * @param trainingData The training dataset.
- * @return Key - distanceMeasure from given features before features with idx stored in value. Value is presented
- * with Set because there can be a few vectors with the same distance.
- */
- @NotNull private TreeMap<Double, Set<Integer>> getDistances(Vector v, LabeledVector[] trainingData) {
- TreeMap<Double, Set<Integer>> distanceIdxPairs = new TreeMap<>();
-
- for (int i = 0; i < trainingData.length; i++) {
-
- LabeledVector labeledVector = trainingData[i];
- if (labeledVector != null) {
- double distance = distanceMeasure.compute(v, labeledVector.features());
- putDistanceIdxPair(distanceIdxPairs, i, distance);
- }
- }
- return distanceIdxPairs;
- }
-
- /** */
- private void putDistanceIdxPair(Map<Double, Set<Integer>> distanceIdxPairs, int i, double distance) {
- if (distanceIdxPairs.containsKey(distance)) {
- Set<Integer> idxs = distanceIdxPairs.get(distance);
- idxs.add(i);
- }
- else {
- Set<Integer> idxs = new HashSet<>();
- idxs.add(i);
- distanceIdxPairs.put(distance, idxs);
- }
- }
-
- /** */
- private double classify(LabeledVector[] neighbors, Vector v, KNNStrategy stgy) {
- Map<Double, Double> clsVotes = new HashMap<>();
-
- for (int i = 0; i < neighbors.length; i++) {
- LabeledVector neighbor = neighbors[i];
- double clsLb = (double)neighbor.label();
-
- double distance = cachedDistances != null ? cachedDistances[i] : distanceMeasure.compute(v, neighbor.features());
-
- if (clsVotes.containsKey(clsLb)) {
- double clsVote = clsVotes.get(clsLb);
- clsVote += getClassVoteForVector(stgy, distance);
- clsVotes.put(clsLb, clsVote);
- }
- else {
- final double val = getClassVoteForVector(stgy, distance);
- clsVotes.put(clsLb, val);
- }
- }
- return getClassWithMaxVotes(clsVotes);
- }
-
- /** */
- private double getClassWithMaxVotes(Map<Double, Double> clsVotes) {
- return Collections.max(clsVotes.entrySet(), Map.Entry.comparingByValue()).getKey();
- }
-
- /** */
- private double getClassVoteForVector(KNNStrategy stgy, double distance) {
- if (stgy.equals(KNNStrategy.WEIGHTED))
- return 1 / distance; // strategy.WEIGHTED
- else
- return 1.0; // strategy.SIMPLE
- }
-
- /** {@inheritDoc} */
- @Override public int hashCode() {
- int res = 1;
-
- res = res * 37 + k;
- res = res * 37 + distanceMeasure.hashCode();
- res = res * 37 + stgy.hashCode();
- res = res * 37 + Arrays.hashCode(training.data());
-
- return res;
- }
-
- /** {@inheritDoc} */
- @Override public boolean equals(Object obj) {
- if (this == obj)
- return true;
-
- if (obj == null || getClass() != obj.getClass())
- return false;
-
- KNNModel that = (KNNModel)obj;
-
- return k == that.k && distanceMeasure.equals(that.distanceMeasure) && stgy.equals(that.stgy)
- && Arrays.deepEquals(training.data(), that.training.data());
- }
-}
http://git-wip-us.apache.org/repos/asf/ignite/blob/43d05576/modules/ml/src/main/java/org/apache/ignite/ml/knn/models/KNNModelFormat.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/models/KNNModelFormat.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/models/KNNModelFormat.java
deleted file mode 100644
index 11a23f5..0000000
--- a/modules/ml/src/main/java/org/apache/ignite/ml/knn/models/KNNModelFormat.java
+++ /dev/null
@@ -1,96 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.knn.models;
-
-import java.io.Serializable;
-import java.util.Arrays;
-import org.apache.ignite.ml.math.distances.DistanceMeasure;
-import org.apache.ignite.ml.structures.LabeledDataset;
-
-/**
- * kNN model representation.
- *
- * @see KNNModel
- */
-public class KNNModelFormat implements Serializable {
- /** Amount of nearest neighbors. */
- private int k;
-
- /** Distance measure. */
- private DistanceMeasure distanceMeasure;
-
- /** Training dataset */
- private LabeledDataset training;
-
- /** kNN strategy. */
- private KNNStrategy stgy;
-
- /** */
- public int getK() {
- return k;
- }
-
- /** */
- public DistanceMeasure getDistanceMeasure() {
- return distanceMeasure;
- }
-
- /** */
- public LabeledDataset getTraining() {
- return training;
- }
-
- /** */
- public KNNStrategy getStgy() {
- return stgy;
- }
-
- /** */
- public KNNModelFormat(int k, DistanceMeasure measure, LabeledDataset training, KNNStrategy stgy) {
- this.k = k;
- this.distanceMeasure = measure;
- this.training = training;
- this.stgy = stgy;
- }
-
- /** {@inheritDoc} */
- @Override public int hashCode() {
- int res = 1;
-
- res = res * 37 + k;
- res = res * 37 + distanceMeasure.hashCode();
- res = res * 37 + stgy.hashCode();
- res = res * 37 + Arrays.hashCode(training.data());
-
- return res;
- }
-
- /** {@inheritDoc} */
- @Override public boolean equals(Object obj) {
- if (this == obj)
- return true;
-
- if (obj == null || getClass() != obj.getClass())
- return false;
-
- KNNModelFormat that = (KNNModelFormat)obj;
-
- return k == that.k && distanceMeasure.equals(that.distanceMeasure) && stgy.equals(that.stgy)
- && Arrays.deepEquals(training.data(), that.training.data());
- }
-}
http://git-wip-us.apache.org/repos/asf/ignite/blob/43d05576/modules/ml/src/main/java/org/apache/ignite/ml/knn/models/KNNStrategy.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/models/KNNStrategy.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/models/KNNStrategy.java
deleted file mode 100644
index d524773..0000000
--- a/modules/ml/src/main/java/org/apache/ignite/ml/knn/models/KNNStrategy.java
+++ /dev/null
@@ -1,27 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.knn.models;
-
-/** This enum contains settings for kNN algorithm. */
-public enum KNNStrategy {
- /** */
- SIMPLE,
-
- /** */
- WEIGHTED
-}
http://git-wip-us.apache.org/repos/asf/ignite/blob/43d05576/modules/ml/src/main/java/org/apache/ignite/ml/knn/models/package-info.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/models/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/models/package-info.java
deleted file mode 100644
index 7b6e678..0000000
--- a/modules/ml/src/main/java/org/apache/ignite/ml/knn/models/package-info.java
+++ /dev/null
@@ -1,22 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-/**
- * <!-- Package description. -->
- * Contains main APIs for kNN classification algorithms.
- */
-package org.apache.ignite.ml.knn.models;
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/ignite/blob/43d05576/modules/ml/src/main/java/org/apache/ignite/ml/knn/partitions/KNNPartitionContext.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/partitions/KNNPartitionContext.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/partitions/KNNPartitionContext.java
new file mode 100644
index 0000000..0081612
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/knn/partitions/KNNPartitionContext.java
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.knn.partitions;
+
+import java.io.Serializable;
+
+/**
+ * Partition context of the kNN classification algorithm.
+ */
+public class KNNPartitionContext implements Serializable {
+ /** Serial version UID; the class declares no other fields. */
+ private static final long serialVersionUID = -7212307112344430126L;
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/43d05576/modules/ml/src/main/java/org/apache/ignite/ml/knn/partitions/package-info.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/partitions/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/partitions/package-info.java
new file mode 100644
index 0000000..951a849
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/knn/partitions/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * Contains helper classes for kNN classification algorithms.
+ */
+package org.apache.ignite.ml.knn.partitions;
http://git-wip-us.apache.org/repos/asf/ignite/blob/43d05576/modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/KNNMultipleLinearRegression.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/KNNMultipleLinearRegression.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/KNNMultipleLinearRegression.java
deleted file mode 100644
index 1796eeb..0000000
--- a/modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/KNNMultipleLinearRegression.java
+++ /dev/null
@@ -1,83 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.ignite.ml.knn.regression;
-
-import org.apache.ignite.ml.knn.models.KNNModel;
-import org.apache.ignite.ml.knn.models.KNNStrategy;
-import org.apache.ignite.ml.math.Vector;
-import org.apache.ignite.ml.math.distances.DistanceMeasure;
-import org.apache.ignite.ml.math.exceptions.UnsupportedOperationException;
-import org.apache.ignite.ml.structures.LabeledDataset;
-import org.apache.ignite.ml.structures.LabeledVector;
-
-/**
- * This class provides kNN Multiple Linear Regression or Locally [weighted] regression (Simple and Weighted versions).
- *
- * <p> This is an instance-based learning method. </p>
- *
- * <ul>
- * <li>Local means using nearby points (i.e. a nearest neighbors approach).</li>
- * <li>Weighted means we value points based upon how far away they are.</li>
- * <li>Regression means approximating a function.</li>
- * </ul>
- */
-public class KNNMultipleLinearRegression extends KNNModel {
- /** */
- public KNNMultipleLinearRegression(int k, DistanceMeasure distanceMeasure, KNNStrategy stgy,
- LabeledDataset training) {
- super(k, distanceMeasure, stgy, training);
- }
-
- /** {@inheritDoc} */
- @Override public Double apply(Vector v) {
- LabeledVector[] neighbors = findKNearestNeighbors(v, true);
-
- return predictYBasedOn(neighbors, v);
- }
-
- /** */
- private double predictYBasedOn(LabeledVector[] neighbors, Vector v) {
- switch (stgy) {
- case SIMPLE:
- return simpleRegression(neighbors);
- case WEIGHTED:
- return weightedRegression(neighbors, v);
- default:
- throw new UnsupportedOperationException("Strategy " + stgy.name() + " is not supported");
- }
- }
-
- /** */
- private double weightedRegression(LabeledVector<Vector, Double>[] neighbors, Vector v) {
- double sum = 0.0;
- double div = 0.0;
- for (int i = 0; i < neighbors.length; i++) {
- double distance = cachedDistances != null ? cachedDistances[i] : distanceMeasure.compute(v, neighbors[i].features());
- sum += neighbors[i].label() * distance;
- div += distance;
- }
- return sum / div;
- }
-
- /** */
- private double simpleRegression(LabeledVector<Vector, Double>[] neighbors) {
- double sum = 0.0;
- for (LabeledVector<Vector, Double> neighbor : neighbors)
- sum += neighbor.label();
- return sum / (double)k;
- }
-}
http://git-wip-us.apache.org/repos/asf/ignite/blob/43d05576/modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/package-info.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/package-info.java
deleted file mode 100644
index 30023a1..0000000
--- a/modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/package-info.java
+++ /dev/null
@@ -1,22 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-/**
- * <!-- Package description. -->
- * Contains main APIs for kNN regression algorithms.
- */
-package org.apache.ignite.ml.knn.regression;
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/ignite/blob/43d05576/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/DistanceMeasure.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/DistanceMeasure.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/DistanceMeasure.java
index 3fa2ec7..7ea8fb3 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/DistanceMeasure.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/DistanceMeasure.java
@@ -36,4 +36,16 @@ public interface DistanceMeasure extends Externalizable {
* @throws CardinalityException if the array lengths differ.
*/
public double compute(Vector a, Vector b) throws CardinalityException;
+
+ /**
+ * Compute the distance between n-dimensional vector and n-dimensional array.
+ * <p>
+ * The two data structures are required to have the same dimension.
+ *
+ * @param a The vector.
+ * @param b The array.
+ * @return The distance between vector and array.
+ * @throws CardinalityException if the lengths of the vector and the array differ.
+ */
+ public double compute(Vector a, double[] b) throws CardinalityException;
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/43d05576/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/EuclideanDistance.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/EuclideanDistance.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/EuclideanDistance.java
index a0c95d2..64ea285 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/EuclideanDistance.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/EuclideanDistance.java
@@ -37,6 +37,16 @@ public class EuclideanDistance implements DistanceMeasure {
}
/** {@inheritDoc} */
+ @Override public double compute(Vector a, double[] b) throws CardinalityException {
+ double res = 0.0;
+ // Euclidean (L2) distance: accumulate squared coordinate differences, then take the root.
+ for (int i = 0; i < b.length; i++)
+ res += Math.pow(b[i] - a.get(i), 2.0);
+
+ return Math.sqrt(res);
+ }
+
+ /** {@inheritDoc} */
@Override public void writeExternal(ObjectOutput out) throws IOException {
// No-op
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/43d05576/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/HammingDistance.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/HammingDistance.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/HammingDistance.java
index dec2d73..cb99074 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/HammingDistance.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/HammingDistance.java
@@ -43,6 +43,11 @@ public class HammingDistance implements DistanceMeasure {
}
/** {@inheritDoc} */
+ @Override public double compute(Vector a, double[] b) throws CardinalityException {
+ throw new UnsupportedOperationException("It's not supported yet"); // Vector/array overload is not implemented.
+ }
+
+ /** {@inheritDoc} */
@Override public void writeExternal(ObjectOutput out) throws IOException {
// No-op
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/43d05576/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/ManhattanDistance.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/ManhattanDistance.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/ManhattanDistance.java
index 66394f1..9ea36b3 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/ManhattanDistance.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/ManhattanDistance.java
@@ -37,6 +37,11 @@ public class ManhattanDistance implements DistanceMeasure {
}
/** {@inheritDoc} */
+ @Override public double compute(Vector a, double[] b) throws CardinalityException {
+ throw new UnsupportedOperationException("It's not supported yet"); // Vector/array overload is not implemented.
+ }
+
+ /** {@inheritDoc} */
@Override public void writeExternal(ObjectOutput out) throws IOException {
// No-op
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/43d05576/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeap.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeap.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeap.java
index 1db3e8b..b1cc4c9 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeap.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/isolve/lsqr/LSQROnHeap.java
@@ -108,4 +108,4 @@ public class LSQROnHeap<K, V> extends AbstractLSQR implements AutoCloseable {
@Override public void close() throws Exception {
dataset.close();
}
-}
\ No newline at end of file
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/43d05576/modules/ml/src/main/java/org/apache/ignite/ml/structures/partition/LabelPartitionContext.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/structures/partition/LabelPartitionContext.java b/modules/ml/src/main/java/org/apache/ignite/ml/structures/partition/LabelPartitionContext.java
new file mode 100644
index 0000000..1069ff8
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/structures/partition/LabelPartitionContext.java
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.structures.partition;
+
+import java.io.Serializable;
+
+/**
+ * Base partition context.
+ */
+public class LabelPartitionContext implements Serializable {
+ /** Serial version UID; the class declares no other fields. */
+ private static final long serialVersionUID = -7412302212344430126L;
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/43d05576/modules/ml/src/main/java/org/apache/ignite/ml/structures/partition/LabelPartitionDataBuilderOnHeap.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/structures/partition/LabelPartitionDataBuilderOnHeap.java b/modules/ml/src/main/java/org/apache/ignite/ml/structures/partition/LabelPartitionDataBuilderOnHeap.java
new file mode 100644
index 0000000..14c053e
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/structures/partition/LabelPartitionDataBuilderOnHeap.java
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.structures.partition;
+
+import java.io.Serializable;
+import java.util.Iterator;
+import org.apache.ignite.ml.dataset.PartitionDataBuilder;
+import org.apache.ignite.ml.dataset.UpstreamEntry;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.structures.LabeledDataset;
+
+/**
+ * Partition data builder that builds {@link LabelPartitionDataOnHeap} from one label per upstream entry.
+ *
+ * @param <K> Type of a key in <tt>upstream</tt> data.
+ * @param <V> Type of a value in <tt>upstream</tt> data.
+ * @param <C> Type of a partition <tt>context</tt>.
+ */
+public class LabelPartitionDataBuilderOnHeap<K, V, C extends Serializable>
+ implements PartitionDataBuilder<K, V, C, LabelPartitionDataOnHeap> {
+ /** Serial version UID. */
+ private static final long serialVersionUID = -7820760153954269227L;
+
+ /** Extractor of Y vector value (the label) from an upstream key-value pair. */
+ private final IgniteBiFunction<K, V, Double> yExtractor;
+
+ /**
+ * Constructs a new instance of label partition data builder.
+ *
+ * @param yExtractor Extractor of Y vector value.
+ */
+ public LabelPartitionDataBuilderOnHeap(IgniteBiFunction<K, V, Double> yExtractor) {
+ this.yExtractor = yExtractor;
+ }
+
+ /** {@inheritDoc} The {@code ctx} argument is not used here; labels are stored in iteration order. */
+ @Override public LabelPartitionDataOnHeap build(Iterator<UpstreamEntry<K, V>> upstreamData, long upstreamDataSize,
+ C ctx) {
+ double[] y = new double[Math.toIntExact(upstreamDataSize)]; // toIntExact throws ArithmeticException on int overflow.
+
+ int ptr = 0; // Index of the next free slot in y.
+ while (upstreamData.hasNext()) {
+ UpstreamEntry<K, V> entry = upstreamData.next();
+
+ y[ptr] = yExtractor.apply(entry.getKey(), entry.getValue()); // The extracted Double is auto-unboxed.
+
+ ptr++;
+ }
+ return new LabelPartitionDataOnHeap(y);
+ }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/43d05576/modules/ml/src/main/java/org/apache/ignite/ml/structures/partition/LabelPartitionDataOnHeap.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/structures/partition/LabelPartitionDataOnHeap.java b/modules/ml/src/main/java/org/apache/ignite/ml/structures/partition/LabelPartitionDataOnHeap.java
new file mode 100644
index 0000000..17dc835
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/structures/partition/LabelPartitionDataOnHeap.java
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.structures.partition;
+
+/**
+ * On-heap partition data that keeps a part of the labels (Y vector).
+ */
+public class LabelPartitionDataOnHeap implements AutoCloseable {
+ /** Part of Y vector. */
+ private final double[] y;
+
+ /**
+ * Constructs a new instance of label partition data. The given array is stored as is (no copy is made).
+ *
+ * @param y Part of Y vector.
+ */
+ public LabelPartitionDataOnHeap(double[] y) {
+ this.y = y;
+ }
+
+ /** @return Part of Y vector; this is the stored array itself, not a copy. */
+ public double[] getY() {
+ return y;
+ }
+
+ /** {@inheritDoc} */
+ @Override public void close() {
+ // Do nothing, GC will clean up.
+ }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/43d05576/modules/ml/src/main/java/org/apache/ignite/ml/structures/partition/LabeledDatasetPartitionDataBuilderOnHeap.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/structures/partition/LabeledDatasetPartitionDataBuilderOnHeap.java b/modules/ml/src/main/java/org/apache/ignite/ml/structures/partition/LabeledDatasetPartitionDataBuilderOnHeap.java
new file mode 100644
index 0000000..b7f62ac
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/structures/partition/LabeledDatasetPartitionDataBuilderOnHeap.java
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.structures.partition;
+
+import java.io.Serializable;
+import java.util.Iterator;
+import org.apache.ignite.ml.dataset.PartitionDataBuilder;
+import org.apache.ignite.ml.dataset.UpstreamEntry;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.structures.LabeledDataset;
+import org.apache.ignite.ml.structures.LabeledVector;
+
+/**
+ * Partition data builder that builds {@link LabeledDataset} from feature rows and labels of upstream entries.
+ *
+ * @param <K> Type of a key in <tt>upstream</tt> data.
+ * @param <V> Type of a value in <tt>upstream</tt> data.
+ * @param <C> Type of a partition <tt>context</tt>.
+ */
+public class LabeledDatasetPartitionDataBuilderOnHeap<K, V, C extends Serializable>
+ implements PartitionDataBuilder<K, V, C, LabeledDataset<Double, LabeledVector>> {
+ /** Serial version UID. */
+ private static final long serialVersionUID = -7820760153954269227L;
+
+ /** Extractor of X matrix row. */
+ private final IgniteBiFunction<K, V, double[]> xExtractor;
+
+ /** Extractor of Y vector value. */
+ private final IgniteBiFunction<K, V, Double> yExtractor;
+
+ /**
+ * Constructs a new instance of labeled dataset partition data builder.
+ *
+ * @param xExtractor Extractor of X matrix row.
+ * @param yExtractor Extractor of Y vector value.
+ */
+ public LabeledDatasetPartitionDataBuilderOnHeap(IgniteBiFunction<K, V, double[]> xExtractor,
+ IgniteBiFunction<K, V, Double> yExtractor) {
+ this.xExtractor = xExtractor;
+ this.yExtractor = yExtractor;
+ }
+
+ /** {@inheritDoc} The {@code ctx} argument is not used here; rows and labels are stored in iteration order. */
+ @Override public LabeledDataset<Double, LabeledVector> build(Iterator<UpstreamEntry<K, V>> upstreamData,
+ long upstreamDataSize, C ctx) {
+ int xCols = -1; // Column count; negative until the first row has been seen.
+ double[][] x = null; // Allocated when the first row is seen; stays null if upstream is empty.
+ double[] y = new double[Math.toIntExact(upstreamDataSize)];
+
+ int ptr = 0;
+
+ while (upstreamData.hasNext()) {
+ UpstreamEntry<K, V> entry = upstreamData.next();
+ double[] row = xExtractor.apply(entry.getKey(), entry.getValue());
+
+ if (xCols < 0) {
+ xCols = row.length;
+ x = new double[Math.toIntExact(upstreamDataSize)][xCols];
+ }
+ else
+ assert row.length == xCols : "X extractor must return exactly " + xCols + " columns";
+
+ x[ptr] = row; // Stores the extracted array itself, replacing the row preallocated above.
+
+ y[ptr] = yExtractor.apply(entry.getKey(), entry.getValue());
+
+ ptr++;
+ }
+ return new LabeledDataset<>(x, y);
+ }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/43d05576/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java
index 84f5eba..7f11e20 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java
@@ -18,6 +18,7 @@
package org.apache.ignite.ml.svm;
import java.util.concurrent.ThreadLocalRandom;
+import org.apache.ignite.ml.structures.partition.LabeledDatasetPartitionDataBuilderOnHeap;
import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
@@ -58,7 +59,7 @@ public class SVMLinearBinaryClassificationTrainer implements SingleLabelDatasetT
assert datasetBuilder != null;
- PartitionDataBuilder<K, V, SVMPartitionContext, LabeledDataset<Double, LabeledVector>> partDataBuilder = new SVMPartitionDataBuilderOnHeap<>(
+ PartitionDataBuilder<K, V, SVMPartitionContext, LabeledDataset<Double, LabeledVector>> partDataBuilder = new LabeledDatasetPartitionDataBuilderOnHeap<>(
featureExtractor,
lbExtractor
);
http://git-wip-us.apache.org/repos/asf/ignite/blob/43d05576/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java
index cc1039f..88c342d 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearMultiClassClassificationTrainer.java
@@ -29,9 +29,9 @@ import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.PartitionDataBuilder;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
-import org.apache.ignite.ml.svm.multi.LabelPartitionContext;
-import org.apache.ignite.ml.svm.multi.LabelPartitionDataBuilderOnHeap;
-import org.apache.ignite.ml.svm.multi.LabelPartitionDataOnHeap;
+import org.apache.ignite.ml.structures.partition.LabelPartitionContext;
+import org.apache.ignite.ml.structures.partition.LabelPartitionDataBuilderOnHeap;
+import org.apache.ignite.ml.structures.partition.LabelPartitionDataOnHeap;
/**
* Base class for a soft-margin SVM linear multiclass-classification trainer based on the communication-efficient