You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@ignite.apache.org by av...@apache.org on 2018/08/15 22:04:03 UTC
[16/19] ignite git commit: IGNITE-9261: [ML] Add ANN algorithm based
on ACD concept
http://git-wip-us.apache.org/repos/asf/ignite/blob/82131d23/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledDataset.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledDataset.java b/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledDataset.java
deleted file mode 100644
index 2440587..0000000
--- a/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledDataset.java
+++ /dev/null
@@ -1,220 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.structures;
-
-import org.apache.ignite.ml.math.exceptions.CardinalityException;
-import org.apache.ignite.ml.math.exceptions.NoDataException;
-import org.apache.ignite.ml.math.exceptions.knn.NoLabelVectorException;
-import org.apache.ignite.ml.math.primitives.vector.Vector;
-import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
-
-/**
- * Class for set of labeled vectors.
- */
-public class LabeledDataset<L, Row extends LabeledVector> extends Dataset<Row> implements AutoCloseable {
- /**
- * Default constructor (required by Externalizable).
- */
- public LabeledDataset() {
- super();
- }
-
- /**
- * Creates new Labeled Dataset and initialized with empty data structure.
- *
- * @param rowSize Amount of instances. Should be > 0.
- * @param colSize Amount of attributes. Should be > 0.
- * @param isDistributed Use distributed data structures to keep data.
- */
- public LabeledDataset(int rowSize, int colSize, boolean isDistributed){
- this(rowSize, colSize, null, isDistributed);
- }
-
- /**
- * Creates new local Labeled Dataset and initialized with empty data structure.
- *
- * @param rowSize Amount of instances. Should be > 0.
- * @param colSize Amount of attributes. Should be > 0.
- */
- public LabeledDataset(int rowSize, int colSize){
- this(rowSize, colSize, null, false);
- }
-
- /**
- * Creates new Labeled Dataset and initialized with empty data structure.
- *
- * @param rowSize Amount of instances. Should be > 0.
- * @param colSize Amount of attributes. Should be > 0
- * @param featureNames Column names.
- * @param isDistributed Use distributed data structures to keep data.
- */
- public LabeledDataset(int rowSize, int colSize, String[] featureNames, boolean isDistributed){
- super(rowSize, colSize, featureNames, isDistributed);
-
- initializeDataWithLabeledVectors();
- }
-
- /**
- * Creates new Labeled Dataset by given data.
- *
- * @param data Should be initialized with one vector at least.
- */
- public LabeledDataset(Row[] data) {
- super(data);
- }
-
- /** */
- private void initializeDataWithLabeledVectors() {
- data = (Row[])new LabeledVector[rowSize];
- for (int i = 0; i < rowSize; i++)
- data[i] = (Row)new LabeledVector(emptyVector(colSize, isDistributed), null);
- }
-
- /**
- * Creates new Labeled Dataset by given data.
- *
- * @param data Should be initialized with one vector at least.
- * @param colSize Amount of observed attributes in each vector.
- */
- public LabeledDataset(Row[] data, int colSize) {
- super(data, colSize);
- }
-
-
- /**
- * Creates new local Labeled Dataset by matrix and vector of labels.
- *
- * @param mtx Given matrix with rows as observations.
- * @param lbs Labels of observations.
- */
- public LabeledDataset(double[][] mtx, double[] lbs) {
- this(mtx, lbs, null, false);
- }
-
- /**
- * Creates new Labeled Dataset by matrix and vector of labels.
- *
- * @param mtx Given matrix with rows as observations.
- * @param lbs Labels of observations.
- * @param featureNames Column names.
- * @param isDistributed Use distributed data structures to keep data.
- */
- public LabeledDataset(double[][] mtx, double[] lbs, String[] featureNames, boolean isDistributed) {
- super();
- assert mtx != null;
- assert lbs != null;
-
- if(mtx.length != lbs.length)
- throw new CardinalityException(lbs.length, mtx.length);
-
- if(mtx[0] == null)
- throw new NoDataException("Pass filled array, the first vector is empty");
-
- this.rowSize = lbs.length;
- this.colSize = mtx[0].length;
-
- if(featureNames == null)
- generateFeatureNames();
- else {
- assert colSize == featureNames.length;
- convertStringNamesToFeatureMetadata(featureNames);
- }
-
- data = (Row[])new LabeledVector[rowSize];
- for (int i = 0; i < rowSize; i++){
-
- data[i] = (Row)new LabeledVector(emptyVector(colSize, isDistributed), lbs[i]);
- for (int j = 0; j < colSize; j++) {
- try {
- data[i].features().set(j, mtx[i][j]);
- } catch (ArrayIndexOutOfBoundsException e) {
- throw new NoDataException("No data in given matrix by coordinates (" + i + "," + j + ")");
- }
- }
- }
- }
-
- /**
- * Returns label if label is attached or null if label is missed.
- *
- * @param idx Index of observation.
- * @return Label.
- */
- public double label(int idx) {
- LabeledVector labeledVector = data[idx];
-
- if(labeledVector!=null)
- return (double)labeledVector.label();
- else
- return Double.NaN;
- }
-
- /**
- * Returns new copy of labels of all labeled vectors NOTE: This method is useful for copying labels from test
- * dataset.
- *
- * @return Copy of labels.
- */
- public double[] labels() {
- assert data != null;
- assert data.length > 0;
-
- double[] labels = new double[data.length];
-
- for (int i = 0; i < data.length; i++)
- labels[i] = (double)data[i].label();
-
- return labels;
- }
-
- /**
- * Fill the label with given value.
- *
- * @param idx Index of observation.
- * @param lb The given label.
- */
- public void setLabel(int idx, double lb) {
- LabeledVector<Vector, Double> labeledVector = data[idx];
-
- if(labeledVector != null)
- labeledVector.setLabel(lb);
- else
- throw new NoLabelVectorException(idx);
- }
-
- /** */
- public static Vector emptyVector(int size, boolean isDistributed) {
- return new DenseVector(size);
- }
-
- /** Makes copy with new Label objects and old features and Metadata objects. */
- public LabeledDataset copy(){
- LabeledDataset res = new LabeledDataset(this.data, this.colSize);
- res.isDistributed = this.isDistributed;
- res.meta = this.meta;
- for (int i = 0; i < rowSize; i++)
- res.setLabel(i, this.label(i));
-
- return res;
- }
-
- /** Closes LabeledDataset. */
- @Override public void close() throws Exception {
-
- }
-}
http://git-wip-us.apache.org/repos/asf/ignite/blob/82131d23/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledDatasetTestTrainPair.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledDatasetTestTrainPair.java b/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledDatasetTestTrainPair.java
deleted file mode 100644
index f362fbc..0000000
--- a/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledDatasetTestTrainPair.java
+++ /dev/null
@@ -1,114 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.structures;
-
-import java.io.Serializable;
-import java.util.Map;
-import java.util.Random;
-import java.util.TreeMap;
-import java.util.TreeSet;
-import org.jetbrains.annotations.NotNull;
-
-/**
- * Class for splitting Labeled Dataset on train and test sets.
- */
-public class LabeledDatasetTestTrainPair implements Serializable {
- /** Data to keep train set. */
- private LabeledDataset train;
-
- /** Data to keep test set. */
- private LabeledDataset test;
-
- /**
- * Creates two subsets of given dataset.
- * <p>
- * NOTE: This method uses next algorithm with O(n log n) by calculations and O(n) by memory.
- * </p>
- * @param dataset The dataset to split on train and test subsets.
- * @param testPercentage The percentage of the test subset.
- */
- public LabeledDatasetTestTrainPair(LabeledDataset dataset, double testPercentage) {
- assert testPercentage > 0.0;
- assert testPercentage < 1.0;
- final int datasetSize = dataset.rowSize();
- assert datasetSize > 2;
-
- final int testSize = (int)Math.floor(datasetSize * testPercentage);
- final int trainSize = datasetSize - testSize;
-
- final TreeSet<Integer> sortedTestIndices = getSortedIndices(datasetSize, testSize);
-
- LabeledVector[] testVectors = new LabeledVector[testSize];
- LabeledVector[] trainVectors = new LabeledVector[trainSize];
-
- int datasetCntr = 0;
- int trainCntr = 0;
- int testCntr = 0;
-
- for (Integer idx: sortedTestIndices){ // guarantee order as iterator
- testVectors[testCntr] = (LabeledVector)dataset.getRow(idx);
- testCntr++;
-
- for (int i = datasetCntr; i < idx; i++) {
- trainVectors[trainCntr] = (LabeledVector)dataset.getRow(i);
- trainCntr++;
- }
-
- datasetCntr = idx + 1;
- }
- if(datasetCntr < datasetSize){
- for (int i = datasetCntr; i < datasetSize; i++) {
- trainVectors[trainCntr] = (LabeledVector)dataset.getRow(i);
- trainCntr++;
- }
- }
-
- test = new LabeledDataset(testVectors, dataset.colSize());
- train = new LabeledDataset(trainVectors, dataset.colSize());
- }
-
- /** This method generates "random double, integer" pairs, sort them, gets first "testSize" elements and returns appropriate indices */
- @NotNull private TreeSet<Integer> getSortedIndices(int datasetSize, int testSize) {
- Random rnd = new Random();
- TreeMap<Double, Integer> randomIdxPairs = new TreeMap<>();
- for (int i = 0; i < datasetSize; i++)
- randomIdxPairs.put(rnd.nextDouble(), i);
-
- final TreeMap<Double, Integer> testIdxPairs = randomIdxPairs.entrySet().stream()
- .limit(testSize)
- .collect(TreeMap::new, (m, e) -> m.put(e.getKey(), e.getValue()), Map::putAll);
-
- return new TreeSet<>(testIdxPairs.values());
- }
-
- /**
- * Train subset of the whole dataset.
- * @return Train subset.
- */
- public LabeledDataset train() {
- return train;
- }
-
- /**
- * Test subset of the whole dataset.
- * @return Test subset.
- */
- public LabeledDataset test() {
- return test;
- }
-}
http://git-wip-us.apache.org/repos/asf/ignite/blob/82131d23/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledVectorSet.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledVectorSet.java b/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledVectorSet.java
new file mode 100644
index 0000000..e98d793
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledVectorSet.java
@@ -0,0 +1,220 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.structures;
+
+import org.apache.ignite.ml.math.exceptions.CardinalityException;
+import org.apache.ignite.ml.math.exceptions.NoDataException;
+import org.apache.ignite.ml.math.exceptions.knn.NoLabelVectorException;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
+
+/**
+ * The set of labeled vectors used in local partition calculations.
+ */
+public class LabeledVectorSet<L, Row extends LabeledVector> extends Dataset<Row> implements AutoCloseable {
+ /**
+ * Default constructor (required by Externalizable).
+ */
+ public LabeledVectorSet() {
+ super();
+ }
+
+ /**
+ * Creates new Labeled Dataset and initialized with empty data structure.
+ *
+ * @param rowSize Amount of instances. Should be > 0.
+ * @param colSize Amount of attributes. Should be > 0.
+ * @param isDistributed Use distributed data structures to keep data.
+ */
+ public LabeledVectorSet(int rowSize, int colSize, boolean isDistributed){
+ this(rowSize, colSize, null, isDistributed);
+ }
+
+ /**
+ * Creates new local Labeled Dataset and initialized with empty data structure.
+ *
+ * @param rowSize Amount of instances. Should be > 0.
+ * @param colSize Amount of attributes. Should be > 0.
+ */
+ public LabeledVectorSet(int rowSize, int colSize){
+ this(rowSize, colSize, null, false);
+ }
+
+ /**
+ * Creates new Labeled Dataset and initialized with empty data structure.
+ *
+ * @param rowSize Amount of instances. Should be > 0.
+ * @param colSize Amount of attributes. Should be > 0
+ * @param featureNames Column names.
+ * @param isDistributed Use distributed data structures to keep data.
+ */
+ public LabeledVectorSet(int rowSize, int colSize, String[] featureNames, boolean isDistributed){
+ super(rowSize, colSize, featureNames, isDistributed);
+
+ initializeDataWithLabeledVectors();
+ }
+
+ /**
+ * Creates new Labeled Dataset by given data.
+ *
+ * @param data Should be initialized with one vector at least.
+ */
+ public LabeledVectorSet(Row[] data) {
+ super(data);
+ }
+
+ /** */
+ private void initializeDataWithLabeledVectors() {
+ data = (Row[])new LabeledVector[rowSize];
+ for (int i = 0; i < rowSize; i++)
+ data[i] = (Row)new LabeledVector(emptyVector(colSize, isDistributed), null);
+ }
+
+ /**
+ * Creates new Labeled Dataset by given data.
+ *
+ * @param data Should be initialized with one vector at least.
+ * @param colSize Amount of observed attributes in each vector.
+ */
+ public LabeledVectorSet(Row[] data, int colSize) {
+ super(data, colSize);
+ }
+
+
+ /**
+ * Creates new local Labeled Dataset by matrix and vector of labels.
+ *
+ * @param mtx Given matrix with rows as observations.
+ * @param lbs Labels of observations.
+ */
+ public LabeledVectorSet(double[][] mtx, double[] lbs) {
+ this(mtx, lbs, null, false);
+ }
+
+ /**
+ * Creates new Labeled Dataset by matrix and vector of labels.
+ *
+ * @param mtx Given matrix with rows as observations.
+ * @param lbs Labels of observations.
+ * @param featureNames Column names.
+ * @param isDistributed Use distributed data structures to keep data.
+ */
+ public LabeledVectorSet(double[][] mtx, double[] lbs, String[] featureNames, boolean isDistributed) {
+ super();
+ assert mtx != null;
+ assert lbs != null;
+
+ if(mtx.length != lbs.length)
+ throw new CardinalityException(lbs.length, mtx.length);
+
+ if(mtx[0] == null)
+ throw new NoDataException("Pass filled array, the first vector is empty");
+
+ this.rowSize = lbs.length;
+ this.colSize = mtx[0].length;
+
+ if(featureNames == null)
+ generateFeatureNames();
+ else {
+ assert colSize == featureNames.length;
+ convertStringNamesToFeatureMetadata(featureNames);
+ }
+
+ data = (Row[])new LabeledVector[rowSize];
+ for (int i = 0; i < rowSize; i++){
+
+ data[i] = (Row)new LabeledVector(emptyVector(colSize, isDistributed), lbs[i]);
+ for (int j = 0; j < colSize; j++) {
+ try {
+ data[i].features().set(j, mtx[i][j]);
+ } catch (ArrayIndexOutOfBoundsException e) {
+ throw new NoDataException("No data in given matrix by coordinates (" + i + "," + j + ")");
+ }
+ }
+ }
+ }
+
+ /**
+ * Returns label if label is attached or null if label is missed.
+ *
+ * @param idx Index of observation.
+ * @return Label.
+ */
+ public double label(int idx) {
+ LabeledVector labeledVector = data[idx];
+
+ if(labeledVector!=null)
+ return (double)labeledVector.label();
+ else
+ return Double.NaN;
+ }
+
+ /**
+ * Returns new copy of labels of all labeled vectors NOTE: This method is useful for copying labels from test
+ * dataset.
+ *
+ * @return Copy of labels.
+ */
+ public double[] labels() {
+ assert data != null;
+ assert data.length > 0;
+
+ double[] labels = new double[data.length];
+
+ for (int i = 0; i < data.length; i++)
+ labels[i] = (double)data[i].label();
+
+ return labels;
+ }
+
+ /**
+ * Fill the label with given value.
+ *
+ * @param idx Index of observation.
+ * @param lb The given label.
+ */
+ public void setLabel(int idx, double lb) {
+ LabeledVector<Vector, Double> labeledVector = data[idx];
+
+ if(labeledVector != null)
+ labeledVector.setLabel(lb);
+ else
+ throw new NoLabelVectorException(idx);
+ }
+
+ /** */
+ public static Vector emptyVector(int size, boolean isDistributed) {
+ return new DenseVector(size);
+ }
+
+ /** Makes copy with new Label objects and old features and Metadata objects. */
+ public LabeledVectorSet copy(){
+ LabeledVectorSet res = new LabeledVectorSet(this.data, this.colSize);
+ res.isDistributed = this.isDistributed;
+ res.meta = this.meta;
+ for (int i = 0; i < rowSize; i++)
+ res.setLabel(i, this.label(i));
+
+ return res;
+ }
+
+ /** Closes LabeledDataset. */
+ @Override public void close() throws Exception {
+
+ }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/82131d23/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledVectorSetTestTrainPair.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledVectorSetTestTrainPair.java b/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledVectorSetTestTrainPair.java
new file mode 100644
index 0000000..d06dfd0
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledVectorSetTestTrainPair.java
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.structures;
+
+import java.io.Serializable;
+import java.util.Map;
+import java.util.Random;
+import java.util.TreeMap;
+import java.util.TreeSet;
+import org.jetbrains.annotations.NotNull;
+
+/**
+ * Splits a {@link LabeledVectorSet} into train and test subsets by drawing test rows uniformly at random.
+ * Both subsets keep the relative row order of the source set and share its row objects.
+ */
+public class LabeledVectorSetTestTrainPair implements Serializable {
+    /** Data to keep train set. */
+    private final LabeledVectorSet train;
+
+    /** Data to keep test set. */
+    private final LabeledVectorSet test;
+
+    /**
+     * Creates two subsets of given dataset using a fresh, unseeded random generator.
+     *
+     * @param dataset The dataset to split on train and test subsets.
+     * @param testPercentage The fraction of rows (in {@code (0, 1)}) placed in the test subset, rounded down.
+     */
+    public LabeledVectorSetTestTrainPair(LabeledVectorSet dataset, double testPercentage) {
+        this(dataset, testPercentage, new Random());
+    }
+
+    /**
+     * Creates two subsets of given dataset, drawing test rows with the given generator so that a split can be
+     * reproduced (e.g. by passing {@code new Random(seed)}).
+     * <p>
+     * NOTE: This method runs in O(n log n) time and O(n) memory.
+     * </p>
+     *
+     * @param dataset The dataset to split on train and test subsets.
+     * @param testPercentage The fraction of rows (in {@code (0, 1)}) placed in the test subset, rounded down.
+     * @param rnd Source of randomness used to pick the test rows.
+     */
+    public LabeledVectorSetTestTrainPair(LabeledVectorSet dataset, double testPercentage, Random rnd) {
+        assert testPercentage > 0.0;
+        assert testPercentage < 1.0;
+        assert rnd != null;
+        final int datasetSize = dataset.rowSize();
+        assert datasetSize > 2;
+
+        final int testSize = (int)Math.floor(datasetSize * testPercentage);
+        final int trainSize = datasetSize - testSize;
+
+        final TreeSet<Integer> testIndices = getSortedIndices(datasetSize, testSize, rnd);
+
+        LabeledVector[] testVectors = new LabeledVector[testSize];
+        LabeledVector[] trainVectors = new LabeledVector[trainSize];
+
+        int trainCntr = 0;
+        int testCntr = 0;
+
+        // Single ordered pass over the source rows so both subsets preserve the original order.
+        for (int i = 0; i < datasetSize; i++) {
+            LabeledVector row = (LabeledVector)dataset.getRow(i);
+
+            if (testIndices.contains(i))
+                testVectors[testCntr++] = row;
+            else
+                trainVectors[trainCntr++] = row;
+        }
+
+        test = new LabeledVectorSet(testVectors, dataset.colSize());
+        train = new LabeledVectorSet(trainVectors, dataset.colSize());
+    }
+
+    /**
+     * Draws {@code testSize} distinct indices uniformly at random from {@code [0, datasetSize)} with a partial
+     * Fisher-Yates shuffle and returns them in ascending order. Unlike keying a map by random doubles, this cannot
+     * lose indices to key collisions.
+     *
+     * @param datasetSize Number of rows to choose from.
+     * @param testSize Number of indices to draw; must not exceed {@code datasetSize}.
+     * @param rnd Source of randomness.
+     * @return Sorted set of exactly {@code testSize} distinct indices.
+     */
+    @NotNull private TreeSet<Integer> getSortedIndices(int datasetSize, int testSize, Random rnd) {
+        int[] idxs = new int[datasetSize];
+        for (int i = 0; i < datasetSize; i++)
+            idxs[i] = i;
+
+        TreeSet<Integer> res = new TreeSet<>();
+
+        for (int i = 0; i < testSize; i++) {
+            // Move a uniformly chosen, not yet picked index into position i.
+            int j = i + rnd.nextInt(datasetSize - i);
+            int tmp = idxs[i];
+            idxs[i] = idxs[j];
+            idxs[j] = tmp;
+
+            res.add(idxs[i]);
+        }
+
+        return res;
+    }
+
+    /**
+     * Train subset of the whole dataset.
+     * @return Train subset.
+     */
+    public LabeledVectorSet train() {
+        return train;
+    }
+
+    /**
+     * Test subset of the whole dataset.
+     * @return Test subset.
+     */
+    public LabeledVectorSet test() {
+        return test;
+    }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/82131d23/modules/ml/src/main/java/org/apache/ignite/ml/structures/partition/LabeledDatasetPartitionDataBuilderOnHeap.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/structures/partition/LabeledDatasetPartitionDataBuilderOnHeap.java b/modules/ml/src/main/java/org/apache/ignite/ml/structures/partition/LabeledDatasetPartitionDataBuilderOnHeap.java
index b4e552b..0351037 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/structures/partition/LabeledDatasetPartitionDataBuilderOnHeap.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/structures/partition/LabeledDatasetPartitionDataBuilderOnHeap.java
@@ -23,18 +23,18 @@ import org.apache.ignite.ml.dataset.PartitionDataBuilder;
import org.apache.ignite.ml.dataset.UpstreamEntry;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
-import org.apache.ignite.ml.structures.LabeledDataset;
import org.apache.ignite.ml.structures.LabeledVector;
+import org.apache.ignite.ml.structures.LabeledVectorSet;
/**
- * Partition data builder that builds {@link LabeledDataset}.
+ * Partition data builder that builds {@link LabeledVectorSet}.
*
* @param <K> Type of a key in <tt>upstream</tt> data.
* @param <V> Type of a value in <tt>upstream</tt> data.
* @param <C> Type of a partition <tt>context</tt>.
*/
public class LabeledDatasetPartitionDataBuilderOnHeap<K, V, C extends Serializable>
- implements PartitionDataBuilder<K, V, C, LabeledDataset<Double, LabeledVector>> {
+ implements PartitionDataBuilder<K, V, C, LabeledVectorSet<Double, LabeledVector>> {
/** */
private static final long serialVersionUID = -7820760153954269227L;
@@ -57,8 +57,8 @@ public class LabeledDatasetPartitionDataBuilderOnHeap<K, V, C extends Serializab
}
/** {@inheritDoc} */
- @Override public LabeledDataset<Double, LabeledVector> build(Iterator<UpstreamEntry<K, V>> upstreamData,
- long upstreamDataSize, C ctx) {
+ @Override public LabeledVectorSet<Double, LabeledVector> build(Iterator<UpstreamEntry<K, V>> upstreamData,
+ long upstreamDataSize, C ctx) {
int xCols = -1;
double[][] x = null;
double[] y = new double[Math.toIntExact(upstreamDataSize)];
@@ -82,6 +82,6 @@ public class LabeledDatasetPartitionDataBuilderOnHeap<K, V, C extends Serializab
ptr++;
}
- return new LabeledDataset<>(x, y);
+ return new LabeledVectorSet<>(x, y);
}
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/82131d23/modules/ml/src/main/java/org/apache/ignite/ml/structures/preprocessing/LabeledDatasetLoader.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/structures/preprocessing/LabeledDatasetLoader.java b/modules/ml/src/main/java/org/apache/ignite/ml/structures/preprocessing/LabeledDatasetLoader.java
index 5c20d9c..f370cbd 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/structures/preprocessing/LabeledDatasetLoader.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/structures/preprocessing/LabeledDatasetLoader.java
@@ -28,8 +28,8 @@ import org.apache.ignite.ml.math.exceptions.NoDataException;
import org.apache.ignite.ml.math.exceptions.knn.EmptyFileException;
import org.apache.ignite.ml.math.exceptions.knn.FileParsingException;
import org.apache.ignite.ml.math.primitives.vector.Vector;
-import org.apache.ignite.ml.structures.LabeledDataset;
import org.apache.ignite.ml.structures.LabeledVector;
+import org.apache.ignite.ml.structures.LabeledVectorSet;
import org.jetbrains.annotations.NotNull;
/** Data pre-processing step which loads data from different file types. */
@@ -43,8 +43,8 @@ public class LabeledDatasetLoader {
* @param isFallOnBadData Fall on incorrect data if true.
* @return Labeled Dataset parsed from file.
*/
- public static LabeledDataset loadFromTxtFile(Path pathToFile, String separator, boolean isDistributed,
- boolean isFallOnBadData) throws IOException {
+ public static LabeledVectorSet loadFromTxtFile(Path pathToFile, String separator, boolean isDistributed,
+ boolean isFallOnBadData) throws IOException {
Stream<String> stream = Files.lines(pathToFile);
List<String> list = new ArrayList<>();
stream.forEach(list::add);
@@ -81,7 +81,7 @@ public class LabeledDatasetLoader {
for (int i = 0; i < vectors.size(); i++)
data[i] = new LabeledVector(vectors.get(i), labels.get(i));
- return new LabeledDataset(data, colSize);
+ return new LabeledVectorSet(data, colSize);
}
else
throw new NoDataException("File should contain first row with data");
@@ -93,7 +93,7 @@ public class LabeledDatasetLoader {
/** */
@NotNull private static Vector parseFeatures(Path pathToFile, boolean isDistributed, boolean isFallOnBadData,
int colSize, int rowIdx, String[] rowData) {
- final Vector vec = LabeledDataset.emptyVector(colSize, isDistributed);
+ final Vector vec = LabeledVectorSet.emptyVector(colSize, isDistributed);
if (isFallOnBadData && rowData.length != colSize + 1)
throw new CardinalityException(colSize + 1, rowData.length);
http://git-wip-us.apache.org/repos/asf/ignite/blob/82131d23/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java
index 1ae896f..4f11318 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearBinaryClassificationTrainer.java
@@ -25,8 +25,8 @@ import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
-import org.apache.ignite.ml.structures.LabeledDataset;
import org.apache.ignite.ml.structures.LabeledVector;
+import org.apache.ignite.ml.structures.LabeledVectorSet;
import org.apache.ignite.ml.structures.partition.LabeledDatasetPartitionDataBuilderOnHeap;
import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;
import org.jetbrains.annotations.NotNull;
@@ -60,14 +60,14 @@ public class SVMLinearBinaryClassificationTrainer extends SingleLabelDatasetTrai
assert datasetBuilder != null;
- PartitionDataBuilder<K, V, EmptyContext, LabeledDataset<Double, LabeledVector>> partDataBuilder = new LabeledDatasetPartitionDataBuilderOnHeap<>(
+ PartitionDataBuilder<K, V, EmptyContext, LabeledVectorSet<Double, LabeledVector>> partDataBuilder = new LabeledDatasetPartitionDataBuilderOnHeap<>(
featureExtractor,
lbExtractor
);
Vector weights;
- try(Dataset<EmptyContext, LabeledDataset<Double, LabeledVector>> dataset = datasetBuilder.build(
+ try(Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> dataset = datasetBuilder.build(
(upstream, upstreamSize) -> new EmptyContext(),
partDataBuilder
)) {
@@ -91,7 +91,7 @@ public class SVMLinearBinaryClassificationTrainer extends SingleLabelDatasetTrai
}
/** */
- private Vector calculateUpdates(Vector weights, Dataset<EmptyContext, LabeledDataset<Double, LabeledVector>> dataset) {
+ private Vector calculateUpdates(Vector weights, Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> dataset) {
return dataset.compute(data -> {
Vector copiedWeights = weights.copy();
Vector deltaWeights = initializeWeightsWithZeros(weights.size());
@@ -116,8 +116,8 @@ public class SVMLinearBinaryClassificationTrainer extends SingleLabelDatasetTrai
}
/** */
- private Deltas getDeltas(LabeledDataset data, Vector copiedWeights, int amountOfObservation, Vector tmpAlphas,
- int randomIdx) {
+ private Deltas getDeltas(LabeledVectorSet data, Vector copiedWeights, int amountOfObservation, Vector tmpAlphas,
+ int randomIdx) {
LabeledVector row = (LabeledVector)data.getRow(randomIdx);
Double lb = (Double)row.label();
Vector v = makeVectorWithInterceptElement(row);
http://git-wip-us.apache.org/repos/asf/ignite/blob/82131d23/modules/ml/src/test/java/org/apache/ignite/ml/LocalModelsTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/LocalModelsTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/LocalModelsTest.java
index 3e3bab5..42f5dec 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/LocalModelsTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/LocalModelsTest.java
@@ -28,13 +28,20 @@ import org.apache.ignite.ml.clustering.kmeans.KMeansModel;
import org.apache.ignite.ml.clustering.kmeans.KMeansModelFormat;
import org.apache.ignite.ml.clustering.kmeans.KMeansTrainer;
import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
+import org.apache.ignite.ml.knn.NNClassificationModel;
+import org.apache.ignite.ml.knn.ann.ANNClassificationModel;
+import org.apache.ignite.ml.knn.ann.ANNModelFormat;
+import org.apache.ignite.ml.knn.ann.ProbableLabel;
import org.apache.ignite.ml.knn.classification.KNNClassificationModel;
import org.apache.ignite.ml.knn.classification.KNNModelFormat;
-import org.apache.ignite.ml.knn.classification.KNNStrategy;
+import org.apache.ignite.ml.knn.classification.NNStrategy;
import org.apache.ignite.ml.math.distances.EuclideanDistance;
+import org.apache.ignite.ml.math.distances.ManhattanDistance;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
import org.apache.ignite.ml.regressions.linear.LinearRegressionModel;
+import org.apache.ignite.ml.structures.LabeledVector;
+import org.apache.ignite.ml.structures.LabeledVectorSet;
import org.apache.ignite.ml.svm.SVMLinearBinaryClassificationModel;
import org.apache.ignite.ml.svm.SVMLinearMultiClassClassificationModel;
import org.junit.Assert;
@@ -165,10 +172,10 @@ public class LocalModelsTest {
@Test
public void importExportKNNModelTest() throws IOException {
executeModelTest(mdlFilePath -> {
- KNNClassificationModel mdl = new KNNClassificationModel(null)
+ NNClassificationModel mdl = new KNNClassificationModel(null)
.withK(3)
.withDistanceMeasure(new EuclideanDistance())
- .withStrategy(KNNStrategy.SIMPLE);
+ .withStrategy(NNStrategy.SIMPLE);
Exporter<KNNModelFormat, String> exporter = new FileExporter<>();
mdl.saveModel(exporter, mdlFilePath);
@@ -177,7 +184,37 @@ public class LocalModelsTest {
Assert.assertNotNull(load);
- KNNClassificationModel importedMdl = new KNNClassificationModel(null)
+ NNClassificationModel importedMdl = new KNNClassificationModel(null)
+ .withK(load.getK())
+ .withDistanceMeasure(load.getDistanceMeasure())
+ .withStrategy(load.getStgy());
+
+ Assert.assertTrue("", mdl.equals(importedMdl));
+
+ return null;
+ });
+ }
+
+ /** */
+ @Test
+ public void importExportANNModelTest() throws IOException {
+ executeModelTest(mdlFilePath -> {
+ final LabeledVectorSet<ProbableLabel, LabeledVector> centers = new LabeledVectorSet<>();
+
+ NNClassificationModel mdl = new ANNClassificationModel(centers)
+ .withK(4)
+ .withDistanceMeasure(new ManhattanDistance())
+ .withStrategy(NNStrategy.WEIGHTED);
+
+ Exporter<KNNModelFormat, String> exporter = new FileExporter<>();
+ mdl.saveModel(exporter, mdlFilePath);
+
+ ANNModelFormat load = (ANNModelFormat) exporter.load(mdlFilePath);
+
+ Assert.assertNotNull(load);
+
+
+ NNClassificationModel importedMdl = new ANNClassificationModel(load.getCandidates())
.withK(load.getK())
.withDistanceMeasure(load.getDistanceMeasure())
.withStrategy(load.getStgy());
http://git-wip-us.apache.org/repos/asf/ignite/blob/82131d23/modules/ml/src/test/java/org/apache/ignite/ml/common/CollectionsTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/common/CollectionsTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/common/CollectionsTest.java
index c4d896c..552c478 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/common/CollectionsTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/common/CollectionsTest.java
@@ -23,7 +23,7 @@ import org.apache.ignite.ml.clustering.kmeans.KMeansModel;
import org.apache.ignite.ml.clustering.kmeans.KMeansModelFormat;
import org.apache.ignite.ml.knn.classification.KNNClassificationModel;
import org.apache.ignite.ml.knn.classification.KNNModelFormat;
-import org.apache.ignite.ml.knn.classification.KNNStrategy;
+import org.apache.ignite.ml.knn.classification.NNStrategy;
import org.apache.ignite.ml.math.distances.EuclideanDistance;
import org.apache.ignite.ml.math.distances.HammingDistance;
import org.apache.ignite.ml.math.distances.ManhattanDistance;
@@ -83,8 +83,8 @@ public class CollectionsTest {
test(new KMeansModel(new Vector[] {}, new ManhattanDistance()),
new KMeansModel(new Vector[] {}, new HammingDistance()));
- test(new KNNModelFormat(1, new ManhattanDistance(), KNNStrategy.SIMPLE),
- new KNNModelFormat(2, new ManhattanDistance(), KNNStrategy.SIMPLE));
+ test(new KNNModelFormat(1, new ManhattanDistance(), NNStrategy.SIMPLE),
+ new KNNModelFormat(2, new ManhattanDistance(), NNStrategy.SIMPLE));
test(new KNNClassificationModel(null).withK(1), new KNNClassificationModel(null).withK(2));
http://git-wip-us.apache.org/repos/asf/ignite/blob/82131d23/modules/ml/src/test/java/org/apache/ignite/ml/knn/ANNClassificationTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/ANNClassificationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/ANNClassificationTest.java
new file mode 100644
index 0000000..ea602cd
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/knn/ANNClassificationTest.java
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.knn;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ThreadLocalRandom;
+import org.apache.ignite.ml.TestUtils;
+import org.apache.ignite.ml.knn.ann.ANNClassificationTrainer;
+import org.apache.ignite.ml.knn.classification.NNStrategy;
+import org.apache.ignite.ml.math.distances.EuclideanDistance;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+/** Tests behaviour of ANNClassificationTest. */
+@RunWith(Parameterized.class)
+public class ANNClassificationTest {
+ /** Number of parts to be tested. */
+ private static final int[] partsToBeTested = new int[]{1, 2, 3, 4, 5, 7, 100};
+
+ /** Fixed size of Dataset. */
+ private static final int AMOUNT_OF_OBSERVATIONS = 1000;
+
+ /** Fixed size of columns in Dataset. */
+ private static final int AMOUNT_OF_FEATURES = 2;
+
+ /** Precision in test checks. */
+ private static final double PRECISION = 1e-2;
+
+ /** Number of partitions. */
+ @Parameterized.Parameter
+ public int parts;
+
+ /** Parameters. */
+ @Parameterized.Parameters(name = "Data divided on {0} partitions, training with batch size {1}")
+ public static Iterable<Integer[]> data() {
+ List<Integer[]> res = new ArrayList<>();
+
+ for (int part : partsToBeTested)
+ res.add(new Integer[]{part});
+
+ return res;
+ }
+
+ /** */
+ @Test
+ public void testBinaryClassificationTest() {
+ Map<Integer, double[]> data = new HashMap<>();
+
+ ThreadLocalRandom rndX = ThreadLocalRandom.current();
+ ThreadLocalRandom rndY = ThreadLocalRandom.current();
+
+ for (int i = 0; i < AMOUNT_OF_OBSERVATIONS; i++) {
+ double x = rndX.nextDouble(500, 600);
+ double y = rndY.nextDouble(500, 600);
+ double[] vec = new double[AMOUNT_OF_FEATURES + 1];
+ vec[0] = 0; // assign label.
+ vec[1] = x;
+ vec[2] = y;
+ data.put(i, vec);
+ }
+
+ for (int i = AMOUNT_OF_OBSERVATIONS; i < AMOUNT_OF_OBSERVATIONS * 2; i++) {
+ double x = rndX.nextDouble(-600, -500);
+ double y = rndY.nextDouble(-600, -500);
+ double[] vec = new double[AMOUNT_OF_FEATURES + 1];
+ vec[0] = 1; // assign label.
+ vec[1] = x;
+ vec[2] = y;
+ data.put(i, vec);
+ }
+
+ ANNClassificationTrainer trainer = new ANNClassificationTrainer()
+ .withK(10);
+
+ NNClassificationModel mdl = trainer.fit(
+ data,
+ parts,
+ (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
+ (k, v) -> v[0]
+ ).withK(3)
+ .withDistanceMeasure(new EuclideanDistance())
+ .withStrategy(NNStrategy.SIMPLE);
+
+ TestUtils.assertEquals(0, mdl.apply(VectorUtils.of(550, 550)), PRECISION);
+ TestUtils.assertEquals(1, mdl.apply(VectorUtils.of(-550, -550)), PRECISION);
+ }
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/ignite/blob/82131d23/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java
index aeb2414..c176682 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNClassificationTest.java
@@ -22,9 +22,8 @@ import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
-import org.apache.ignite.ml.knn.classification.KNNClassificationModel;
import org.apache.ignite.ml.knn.classification.KNNClassificationTrainer;
-import org.apache.ignite.ml.knn.classification.KNNStrategy;
+import org.apache.ignite.ml.knn.classification.NNStrategy;
import org.apache.ignite.ml.math.distances.EuclideanDistance;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
@@ -70,14 +69,14 @@ public class KNNClassificationTest {
KNNClassificationTrainer trainer = new KNNClassificationTrainer();
- KNNClassificationModel knnMdl = trainer.fit(
+ NNClassificationModel knnMdl = trainer.fit(
data,
parts,
(k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
(k, v) -> v[2]
).withK(3)
.withDistanceMeasure(new EuclideanDistance())
- .withStrategy(KNNStrategy.SIMPLE);
+ .withStrategy(NNStrategy.SIMPLE);
assertTrue(knnMdl.toString().length() > 0);
assertTrue(knnMdl.toString(true).length() > 0);
@@ -102,14 +101,14 @@ public class KNNClassificationTest {
KNNClassificationTrainer trainer = new KNNClassificationTrainer();
- KNNClassificationModel knnMdl = trainer.fit(
+ NNClassificationModel knnMdl = trainer.fit(
data,
parts,
(k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
(k, v) -> v[2]
).withK(1)
.withDistanceMeasure(new EuclideanDistance())
- .withStrategy(KNNStrategy.SIMPLE);
+ .withStrategy(NNStrategy.SIMPLE);
Vector firstVector = new DenseVector(new double[] {2.0, 2.0});
assertEquals(knnMdl.apply(firstVector), 1.0);
@@ -130,14 +129,14 @@ public class KNNClassificationTest {
KNNClassificationTrainer trainer = new KNNClassificationTrainer();
- KNNClassificationModel knnMdl = trainer.fit(
+ NNClassificationModel knnMdl = trainer.fit(
data,
parts,
(k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
(k, v) -> v[2]
).withK(3)
.withDistanceMeasure(new EuclideanDistance())
- .withStrategy(KNNStrategy.SIMPLE);
+ .withStrategy(NNStrategy.SIMPLE);
Vector vector = new DenseVector(new double[] {-1.01, -1.01});
assertEquals(knnMdl.apply(vector), 2.0);
@@ -156,14 +155,14 @@ public class KNNClassificationTest {
KNNClassificationTrainer trainer = new KNNClassificationTrainer();
- KNNClassificationModel knnMdl = trainer.fit(
+ NNClassificationModel knnMdl = trainer.fit(
data,
parts,
(k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
(k, v) -> v[2]
).withK(3)
.withDistanceMeasure(new EuclideanDistance())
- .withStrategy(KNNStrategy.WEIGHTED);
+ .withStrategy(NNStrategy.WEIGHTED);
Vector vector = new DenseVector(new double[] {-1.01, -1.01});
assertEquals(knnMdl.apply(vector), 1.0);
http://git-wip-us.apache.org/repos/asf/ignite/blob/82131d23/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java
index 7d57ec9..e05903e 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNRegressionTest.java
@@ -23,7 +23,7 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
-import org.apache.ignite.ml.knn.classification.KNNStrategy;
+import org.apache.ignite.ml.knn.classification.NNStrategy;
import org.apache.ignite.ml.knn.regression.KNNRegressionModel;
import org.apache.ignite.ml.knn.regression.KNNRegressionTrainer;
import org.apache.ignite.ml.math.distances.EuclideanDistance;
@@ -77,7 +77,7 @@ public class KNNRegressionTest {
(k, v) -> v[0]
).withK(1)
.withDistanceMeasure(new EuclideanDistance())
- .withStrategy(KNNStrategy.SIMPLE);
+ .withStrategy(NNStrategy.SIMPLE);
Vector vector = new DenseVector(new double[] {0, 0, 0, 5.0, 0.0});
System.out.println(knnMdl.apply(vector));
@@ -87,17 +87,17 @@ public class KNNRegressionTest {
/** */
@Test
public void testLongly() {
- testLongly(KNNStrategy.SIMPLE);
+ testLongly(NNStrategy.SIMPLE);
}
/** */
@Test
public void testLonglyWithWeightedStrategy() {
- testLongly(KNNStrategy.WEIGHTED);
+ testLongly(NNStrategy.WEIGHTED);
}
/** */
- private void testLongly(KNNStrategy stgy) {
+ private void testLongly(NNStrategy stgy) {
Map<Integer, double[]> data = new HashMap<>();
data.put(0, new double[] {60323, 83.0, 234289, 2356, 1590, 107608, 1947});
data.put(1, new double[] {61122, 88.5, 259426, 2325, 1456, 108632, 1948});
@@ -123,16 +123,12 @@ public class KNNRegressionTest {
(k, v) -> v[0]
).withK(3)
.withDistanceMeasure(new EuclideanDistance())
- .withStrategy(stgy);
+ .withStrategy(NNStrategy.SIMPLE);
Vector vector = new DenseVector(new double[] {104.6, 419180, 2822, 2857, 118734, 1956});
Assert.assertNotNull(knnMdl.apply(vector));
Assert.assertEquals(67857, knnMdl.apply(vector), 2000);
-
- Assert.assertTrue(knnMdl.toString().contains(stgy.name()));
- Assert.assertTrue(knnMdl.toString(true).contains(stgy.name()));
- Assert.assertTrue(knnMdl.toString(false).contains(stgy.name()));
}
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/82131d23/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNTestSuite.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNTestSuite.java
index 55ef24e..0303d26 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNTestSuite.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/knn/KNNTestSuite.java
@@ -25,9 +25,10 @@ import org.junit.runners.Suite;
*/
@RunWith(Suite.class)
@Suite.SuiteClasses({
+ ANNClassificationTest.class,
KNNClassificationTest.class,
KNNRegressionTest.class,
- LabeledDatasetTest.class
+ LabeledVectorSetTest.class
})
public class KNNTestSuite {
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/82131d23/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetHelper.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetHelper.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetHelper.java
index dbcdb99..f3b8b3a 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetHelper.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetHelper.java
@@ -21,7 +21,7 @@ import java.io.IOException;
import java.net.URISyntaxException;
import java.nio.file.Path;
import java.nio.file.Paths;
-import org.apache.ignite.ml.structures.LabeledDataset;
+import org.apache.ignite.ml.structures.LabeledVectorSet;
import org.apache.ignite.ml.structures.preprocessing.LabeledDatasetLoader;
/**
@@ -37,7 +37,7 @@ public class LabeledDatasetHelper {
* @param rsrcPath path to dataset.
* @return null if path is incorrect.
*/
- public static LabeledDataset loadDatasetFromTxt(String rsrcPath, boolean isFallOnBadData) {
+ public static LabeledVectorSet loadDatasetFromTxt(String rsrcPath, boolean isFallOnBadData) {
try {
Path path = Paths.get(LabeledDatasetHelper.class.getClassLoader().getResource(rsrcPath).toURI());
try {
http://git-wip-us.apache.org/repos/asf/ignite/blob/82131d23/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetTest.java
deleted file mode 100644
index 9867fbe..0000000
--- a/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledDatasetTest.java
+++ /dev/null
@@ -1,294 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.knn;
-
-import java.io.IOException;
-import java.net.URISyntaxException;
-import java.nio.file.Path;
-import java.nio.file.Paths;
-import java.util.Objects;
-import org.apache.ignite.ml.math.ExternalizableTest;
-import org.apache.ignite.ml.math.exceptions.CardinalityException;
-import org.apache.ignite.ml.math.exceptions.NoDataException;
-import org.apache.ignite.ml.math.exceptions.knn.EmptyFileException;
-import org.apache.ignite.ml.math.exceptions.knn.FileParsingException;
-import org.apache.ignite.ml.math.exceptions.knn.NoLabelVectorException;
-import org.apache.ignite.ml.math.primitives.vector.Vector;
-import org.apache.ignite.ml.structures.LabeledDataset;
-import org.apache.ignite.ml.structures.LabeledDatasetTestTrainPair;
-import org.apache.ignite.ml.structures.LabeledVector;
-import org.apache.ignite.ml.structures.preprocessing.LabeledDatasetLoader;
-import org.junit.Test;
-
-import static junit.framework.TestCase.assertEquals;
-import static junit.framework.TestCase.fail;
-
-/** Tests behaviour of LabeledDataset. */
-public class LabeledDatasetTest implements ExternalizableTest<LabeledDataset> {
- /** */
- private static final String KNN_IRIS_TXT = "datasets/knn/iris.txt";
-
- /** */
- private static final String NO_DATA_TXT = "datasets/knn/no_data.txt";
-
- /** */
- private static final String EMPTY_TXT = "datasets/knn/empty.txt";
-
- /** */
- private static final String IRIS_INCORRECT_TXT = "datasets/knn/iris_incorrect.txt";
-
- /** */
- private static final String IRIS_MISSED_DATA = "datasets/knn/missed_data.txt";
-
- /** */
- @Test
- public void testFeatureNames() {
- double[][] mtx =
- new double[][] {
- {1.0, 1.0},
- {1.0, 2.0},
- {2.0, 1.0},
- {-1.0, -1.0},
- {-1.0, -2.0},
- {-2.0, -1.0}};
- double[] lbs = new double[] {1.0, 1.0, 1.0, 2.0, 2.0, 2.0};
-
- String[] featureNames = new String[] {"x", "y"};
- final LabeledDataset dataset = new LabeledDataset(mtx, lbs, featureNames, false);
-
- assertEquals(dataset.getFeatureName(0), "x");
- }
-
- /** */
- @Test
- public void testAccessMethods() {
- double[][] mtx =
- new double[][] {
- {1.0, 1.0},
- {1.0, 2.0},
- {2.0, 1.0},
- {-1.0, -1.0},
- {-1.0, -2.0},
- {-2.0, -1.0}};
- double[] lbs = new double[] {1.0, 1.0, 1.0, 2.0, 2.0, 2.0};
-
- final LabeledDataset dataset = new LabeledDataset(mtx, lbs, null, false);
-
- assertEquals(dataset.colSize(), 2);
- assertEquals(dataset.rowSize(), 6);
-
- assertEquals(dataset.label(0), lbs[0], 0);
-
- assertEquals(dataset.copy().colSize(), 2);
-
- @SuppressWarnings("unchecked")
- final LabeledVector<Vector, Double> row = (LabeledVector<Vector, Double>)dataset.getRow(0);
-
- assertEquals(row.features().get(0), 1.0);
- assertEquals(row.label(), 1.0);
- dataset.setLabel(0, 2.0);
- assertEquals(row.label(), 2.0);
-
- assertEquals(0, new LabeledDataset().rowSize());
- assertEquals(1, new LabeledDataset(1, 2).rowSize());
- assertEquals(1, new LabeledDataset(1, 2, true).rowSize());
- assertEquals(1, new LabeledDataset(1, 2, null, true).rowSize());
- }
-
- /** */
- @Test
- public void testFailOnYNull() {
- double[][] mtx =
- new double[][] {
- {1.0, 1.0},
- {1.0, 2.0},
- {2.0, 1.0},
- {-1.0, -1.0},
- {-1.0, -2.0},
- {-2.0, -1.0}};
- double[] lbs = new double[] {};
-
- try {
- new LabeledDataset(mtx, lbs);
- fail("CardinalityException");
- }
- catch (CardinalityException e) {
- return;
- }
- fail("CardinalityException");
- }
-
- /** */
- @Test
- public void testFailOnXNull() {
- double[][] mtx =
- new double[][] {};
- double[] lbs = new double[] {1.0, 1.0, 1.0, 2.0, 2.0, 2.0};
-
- try {
- new LabeledDataset(mtx, lbs);
- fail("CardinalityException");
- }
- catch (CardinalityException e) {
- return;
- }
- fail("CardinalityException");
- }
-
- /** */
- @Test
- public void testLoadingCorrectTxtFile() {
- LabeledDataset training = LabeledDatasetHelper.loadDatasetFromTxt(KNN_IRIS_TXT, false);
- assertEquals(Objects.requireNonNull(training).rowSize(), 150);
- }
-
- /** */
- @Test
- public void testLoadingEmptyFile() {
- try {
- LabeledDatasetHelper.loadDatasetFromTxt(EMPTY_TXT, false);
- fail("EmptyFileException");
- }
- catch (EmptyFileException e) {
- return;
- }
- fail("EmptyFileException");
- }
-
- /** */
- @Test
- public void testLoadingFileWithFirstEmptyRow() {
- try {
- LabeledDatasetHelper.loadDatasetFromTxt(NO_DATA_TXT, false);
- fail("NoDataException");
- }
- catch (NoDataException e) {
- return;
- }
- fail("NoDataException");
- }
-
- /** */
- @Test
- public void testLoadingFileWithIncorrectData() {
- LabeledDataset training = LabeledDatasetHelper.loadDatasetFromTxt(IRIS_INCORRECT_TXT, false);
- assertEquals(149, Objects.requireNonNull(training).rowSize());
- }
-
- /** */
- @Test
- public void testFailOnLoadingFileWithIncorrectData() {
- try {
- LabeledDatasetHelper.loadDatasetFromTxt(IRIS_INCORRECT_TXT, true);
- fail("FileParsingException");
- }
- catch (FileParsingException e) {
- return;
- }
- fail("FileParsingException");
-
- }
-
- /** */
- @Test
- public void testLoadingFileWithMissedData() throws URISyntaxException, IOException {
- Path path = Paths.get(Objects.requireNonNull(this.getClass().getClassLoader().getResource(IRIS_MISSED_DATA)).toURI());
-
- LabeledDataset training = LabeledDatasetLoader.loadFromTxtFile(path, ",", false, false);
-
- assertEquals(training.features(2).get(1), 0.0);
- }
-
- /** */
- @Test
- public void testSplitting() {
- double[][] mtx =
- new double[][] {
- {1.0, 1.0},
- {1.0, 2.0},
- {2.0, 1.0},
- {-1.0, -1.0},
- {-1.0, -2.0},
- {-2.0, -1.0}};
- double[] lbs = new double[] {1.0, 1.0, 1.0, 2.0, 2.0, 2.0};
-
- LabeledDataset training = new LabeledDataset(mtx, lbs);
-
- LabeledDatasetTestTrainPair split1 = new LabeledDatasetTestTrainPair(training, 0.67);
-
- assertEquals(4, split1.test().rowSize());
- assertEquals(2, split1.train().rowSize());
-
- LabeledDatasetTestTrainPair split2 = new LabeledDatasetTestTrainPair(training, 0.65);
-
- assertEquals(3, split2.test().rowSize());
- assertEquals(3, split2.train().rowSize());
-
- LabeledDatasetTestTrainPair split3 = new LabeledDatasetTestTrainPair(training, 0.4);
-
- assertEquals(2, split3.test().rowSize());
- assertEquals(4, split3.train().rowSize());
-
- LabeledDatasetTestTrainPair split4 = new LabeledDatasetTestTrainPair(training, 0.3);
-
- assertEquals(1, split4.test().rowSize());
- assertEquals(5, split4.train().rowSize());
- }
-
- /** */
- @Test
- public void testLabels() {
- double[][] mtx =
- new double[][] {
- {1.0, 1.0},
- {1.0, 2.0},
- {2.0, 1.0},
- {-1.0, -1.0},
- {-1.0, -2.0},
- {-2.0, -1.0}};
- double[] lbs = new double[] {1.0, 1.0, 1.0, 2.0, 2.0, 2.0};
-
- LabeledDataset dataset = new LabeledDataset(mtx, lbs);
- final double[] labels = dataset.labels();
- for (int i = 0; i < lbs.length; i++)
- assertEquals(lbs[i], labels[i]);
- }
-
- /** */
- @Test(expected = NoLabelVectorException.class)
- @SuppressWarnings("unchecked")
- public void testSetLabelInvalid() {
- new LabeledDataset(new LabeledVector[1]).setLabel(0, 2.0);
- }
-
- /** */
- @Override public void testExternalization() {
- double[][] mtx =
- new double[][] {
- {1.0, 1.0},
- {1.0, 2.0},
- {2.0, 1.0},
- {-1.0, -1.0},
- {-1.0, -2.0},
- {-2.0, -1.0}};
- double[] lbs = new double[] {1.0, 1.0, 1.0, 2.0, 2.0, 2.0};
-
- LabeledDataset dataset = new LabeledDataset(mtx, lbs);
- this.externalizeTest(dataset);
- }
-}
http://git-wip-us.apache.org/repos/asf/ignite/blob/82131d23/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledVectorSetTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledVectorSetTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledVectorSetTest.java
new file mode 100644
index 0000000..2303e96
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/knn/LabeledVectorSetTest.java
@@ -0,0 +1,294 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.knn;
+
+import java.io.IOException;
+import java.net.URISyntaxException;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.Objects;
+import org.apache.ignite.ml.math.ExternalizableTest;
+import org.apache.ignite.ml.math.exceptions.CardinalityException;
+import org.apache.ignite.ml.math.exceptions.NoDataException;
+import org.apache.ignite.ml.math.exceptions.knn.EmptyFileException;
+import org.apache.ignite.ml.math.exceptions.knn.FileParsingException;
+import org.apache.ignite.ml.math.exceptions.knn.NoLabelVectorException;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.structures.LabeledVector;
+import org.apache.ignite.ml.structures.LabeledVectorSet;
+import org.apache.ignite.ml.structures.LabeledVectorSetTestTrainPair;
+import org.apache.ignite.ml.structures.preprocessing.LabeledDatasetLoader;
+import org.junit.Test;
+
+import static junit.framework.TestCase.assertEquals;
+import static junit.framework.TestCase.fail;
+
+/** Tests behaviour of KNNClassificationTest. */
+public class LabeledVectorSetTest implements ExternalizableTest<LabeledVectorSet> {
+ /** */
+ private static final String KNN_IRIS_TXT = "datasets/knn/iris.txt";
+
+ /** */
+ private static final String NO_DATA_TXT = "datasets/knn/no_data.txt";
+
+ /** */
+ private static final String EMPTY_TXT = "datasets/knn/empty.txt";
+
+ /** */
+ private static final String IRIS_INCORRECT_TXT = "datasets/knn/iris_incorrect.txt";
+
+ /** */
+ private static final String IRIS_MISSED_DATA = "datasets/knn/missed_data.txt";
+
+ /** */
+ @Test
+ public void testFeatureNames() {
+ double[][] mtx =
+ new double[][] {
+ {1.0, 1.0},
+ {1.0, 2.0},
+ {2.0, 1.0},
+ {-1.0, -1.0},
+ {-1.0, -2.0},
+ {-2.0, -1.0}};
+ double[] lbs = new double[] {1.0, 1.0, 1.0, 2.0, 2.0, 2.0};
+
+ String[] featureNames = new String[] {"x", "y"};
+ final LabeledVectorSet dataset = new LabeledVectorSet(mtx, lbs, featureNames, false);
+
+ assertEquals(dataset.getFeatureName(0), "x");
+ }
+
+ /** */
+ @Test
+ public void testAccessMethods() {
+ double[][] mtx =
+ new double[][] {
+ {1.0, 1.0},
+ {1.0, 2.0},
+ {2.0, 1.0},
+ {-1.0, -1.0},
+ {-1.0, -2.0},
+ {-2.0, -1.0}};
+ double[] lbs = new double[] {1.0, 1.0, 1.0, 2.0, 2.0, 2.0};
+
+ final LabeledVectorSet dataset = new LabeledVectorSet(mtx, lbs, null, false);
+
+ assertEquals(dataset.colSize(), 2);
+ assertEquals(dataset.rowSize(), 6);
+
+ assertEquals(dataset.label(0), lbs[0], 0);
+
+ assertEquals(dataset.copy().colSize(), 2);
+
+ @SuppressWarnings("unchecked")
+ final LabeledVector<Vector, Double> row = (LabeledVector<Vector, Double>)dataset.getRow(0);
+
+ assertEquals(row.features().get(0), 1.0);
+ assertEquals(row.label(), 1.0);
+ dataset.setLabel(0, 2.0);
+ assertEquals(row.label(), 2.0);
+
+ assertEquals(0, new LabeledVectorSet().rowSize());
+ assertEquals(1, new LabeledVectorSet(1, 2).rowSize());
+ assertEquals(1, new LabeledVectorSet(1, 2, true).rowSize());
+ assertEquals(1, new LabeledVectorSet(1, 2, null, true).rowSize());
+ }
+
+ /** */
+ @Test
+ public void testFailOnYNull() {
+ double[][] mtx =
+ new double[][] {
+ {1.0, 1.0},
+ {1.0, 2.0},
+ {2.0, 1.0},
+ {-1.0, -1.0},
+ {-1.0, -2.0},
+ {-2.0, -1.0}};
+ double[] lbs = new double[] {};
+
+ try {
+ new LabeledVectorSet(mtx, lbs);
+ fail("CardinalityException");
+ }
+ catch (CardinalityException e) {
+ return;
+ }
+ fail("CardinalityException");
+ }
+
+ /** */
+ @Test
+ public void testFailOnXNull() {
+ double[][] mtx =
+ new double[][] {};
+ double[] lbs = new double[] {1.0, 1.0, 1.0, 2.0, 2.0, 2.0};
+
+ try {
+ new LabeledVectorSet(mtx, lbs);
+ fail("CardinalityException");
+ }
+ catch (CardinalityException e) {
+ return;
+ }
+ fail("CardinalityException");
+ }
+
+ /** */
+ @Test
+ public void testLoadingCorrectTxtFile() {
+ LabeledVectorSet training = LabeledDatasetHelper.loadDatasetFromTxt(KNN_IRIS_TXT, false);
+ assertEquals(training.rowSize(), 150);
+ }
+
+ /** */
+ @Test
+ public void testLoadingEmptyFile() {
+ try {
+ LabeledDatasetHelper.loadDatasetFromTxt(EMPTY_TXT, false);
+ fail("EmptyFileException");
+ }
+ catch (EmptyFileException e) {
+ return;
+ }
+ fail("EmptyFileException");
+ }
+
+ /** */
+ @Test
+ public void testLoadingFileWithFirstEmptyRow() {
+ try {
+ LabeledDatasetHelper.loadDatasetFromTxt(NO_DATA_TXT, false);
+ fail("NoDataException");
+ }
+ catch (NoDataException e) {
+ return;
+ }
+ fail("NoDataException");
+ }
+
+ /** */
+ @Test
+ public void testLoadingFileWithIncorrectData() {
+ LabeledVectorSet training = LabeledDatasetHelper.loadDatasetFromTxt(IRIS_INCORRECT_TXT, false);
+ assertEquals(149, training.rowSize());
+ }
+
+ /** */
+ @Test
+ public void testFailOnLoadingFileWithIncorrectData() {
+ try {
+ LabeledDatasetHelper.loadDatasetFromTxt(IRIS_INCORRECT_TXT, true);
+ fail("FileParsingException");
+ }
+ catch (FileParsingException e) {
+ return;
+ }
+ fail("FileParsingException");
+
+ }
+
+ /** */
+ @Test
+ public void testLoadingFileWithMissedData() throws URISyntaxException, IOException {
+ Path path = Paths.get(Objects.requireNonNull(this.getClass().getClassLoader().getResource(IRIS_MISSED_DATA)).toURI());
+
+ LabeledVectorSet training = LabeledDatasetLoader.loadFromTxtFile(path, ",", false, false);
+
+ assertEquals(training.features(2).get(1), 0.0);
+ }
+
+ /** */
+ @Test
+ public void testSplitting() {
+ double[][] mtx =
+ new double[][] {
+ {1.0, 1.0},
+ {1.0, 2.0},
+ {2.0, 1.0},
+ {-1.0, -1.0},
+ {-1.0, -2.0},
+ {-2.0, -1.0}};
+ double[] lbs = new double[] {1.0, 1.0, 1.0, 2.0, 2.0, 2.0};
+
+ LabeledVectorSet training = new LabeledVectorSet(mtx, lbs);
+
+ LabeledVectorSetTestTrainPair split1 = new LabeledVectorSetTestTrainPair(training, 0.67);
+
+ assertEquals(4, split1.test().rowSize());
+ assertEquals(2, split1.train().rowSize());
+
+ LabeledVectorSetTestTrainPair split2 = new LabeledVectorSetTestTrainPair(training, 0.65);
+
+ assertEquals(3, split2.test().rowSize());
+ assertEquals(3, split2.train().rowSize());
+
+ LabeledVectorSetTestTrainPair split3 = new LabeledVectorSetTestTrainPair(training, 0.4);
+
+ assertEquals(2, split3.test().rowSize());
+ assertEquals(4, split3.train().rowSize());
+
+ LabeledVectorSetTestTrainPair split4 = new LabeledVectorSetTestTrainPair(training, 0.3);
+
+ assertEquals(1, split4.test().rowSize());
+ assertEquals(5, split4.train().rowSize());
+ }
+
+ /** */
+ @Test
+ public void testLabels() {
+ double[][] mtx =
+ new double[][] {
+ {1.0, 1.0},
+ {1.0, 2.0},
+ {2.0, 1.0},
+ {-1.0, -1.0},
+ {-1.0, -2.0},
+ {-2.0, -1.0}};
+ double[] lbs = new double[] {1.0, 1.0, 1.0, 2.0, 2.0, 2.0};
+
+ LabeledVectorSet dataset = new LabeledVectorSet(mtx, lbs);
+ final double[] labels = dataset.labels();
+ for (int i = 0; i < lbs.length; i++)
+ assertEquals(lbs[i], labels[i]);
+ }
+
+ /** */
+ @Test(expected = NoLabelVectorException.class)
+ @SuppressWarnings("unchecked")
+ public void testSetLabelInvalid() {
+ new LabeledVectorSet(new LabeledVector[1]).setLabel(0, 2.0);
+ }
+
+ /** */
+ @Override public void testExternalization() {
+ double[][] mtx =
+ new double[][] {
+ {1.0, 1.0},
+ {1.0, 2.0},
+ {2.0, 1.0},
+ {-1.0, -1.0},
+ {-1.0, -2.0},
+ {-2.0, -1.0}};
+ double[] lbs = new double[] {1.0, 1.0, 1.0, 2.0, 2.0, 2.0};
+
+ LabeledVectorSet dataset = new LabeledVectorSet(mtx, lbs);
+ this.externalizeTest(dataset);
+ }
+}