You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@ignite.apache.org by ag...@apache.org on 2018/04/13 09:33:27 UTC
[15/54] [abbrv] ignite git commit: IGNITE-8059: Integrate decision
tree with partition based dataset.
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/contsplitcalcs/VarianceSplitCalculator.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/contsplitcalcs/VarianceSplitCalculator.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/contsplitcalcs/VarianceSplitCalculator.java
deleted file mode 100644
index 66c54f2..0000000
--- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/contsplitcalcs/VarianceSplitCalculator.java
+++ /dev/null
@@ -1,179 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs;
-
-import java.io.IOException;
-import java.io.ObjectInput;
-import java.io.ObjectOutput;
-import java.util.PrimitiveIterator;
-import java.util.stream.DoubleStream;
-import org.apache.ignite.ml.trees.ContinuousRegionInfo;
-import org.apache.ignite.ml.trees.ContinuousSplitCalculator;
-import org.apache.ignite.ml.trees.trainers.columnbased.vectors.ContinuousSplitInfo;
-import org.apache.ignite.ml.trees.trainers.columnbased.vectors.SplitInfo;
-
-/**
- * Continuous split calculator that uses the variance of labels as the impurity of a region.
- */
-public class VarianceSplitCalculator implements ContinuousSplitCalculator<VarianceSplitCalculator.VarianceData> {
- /**
- * Data used in variance calculations: the region mean on top of the variance and size passed to the superclass.
- */
- public static class VarianceData extends ContinuousRegionInfo {
- /** Mean value in a given region. */
- double mean;
-
- /**
- * @param var Variance in this region.
- * @param size Size of data for which variance is calculated.
- * @param mean Mean value in this region.
- */
- public VarianceData(double var, int size, double mean) {
- super(var, size);
- this.mean = mean;
- }
-
- /**
- * No-op constructor. For serialization/deserialization; the mean is restored by {@link #readExternal(ObjectInput)}.
- */
- public VarianceData() {
- // No-op.
- }
-
- /** {@inheritDoc} Writes the superclass state first, then {@code mean} as a double. */
- @Override public void writeExternal(ObjectOutput out) throws IOException {
- super.writeExternal(out);
- out.writeDouble(mean);
- }
-
- /** {@inheritDoc} Reads the superclass state first, then {@code mean}, in the order used by writeExternal. */
- @Override public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
- super.readExternal(in);
- mean = in.readDouble();
- }
-
- /**
- * @return Mean value in this region.
- */
- public double mean() {
- return mean;
- }
- }
-
- /** {@inheritDoc} NOTE(review): for an empty stream {@code i} stays 0, so the impurity becomes {@code 0.0 / 0} = NaN. */
- @Override public VarianceData calculateRegionInfo(DoubleStream s, int size) {
- PrimitiveIterator.OfDouble itr = s.iterator();
- int i = 0;
-
- double mean = 0.0;
- double m2 = 0.0;
-
- // Single-pass incremental (Welford-style) update of the mean and of m2, the sum of squared deviations from it.
- while (itr.hasNext()) {
- i++;
- double x = itr.next();
- double delta = x - mean;
- mean += delta / i;
- double delta2 = x - mean;
- m2 += delta * delta2;
- }
-
- return new VarianceData(m2 / i, size, mean);
- }
-
- /** {@inheritDoc} Returns {@code null} if no candidate threshold lowers the summed squared deviation of the region. */
- @Override public SplitInfo<VarianceData> splitRegion(Integer[] s, double[] values, double[] labels, int regionIdx,
- VarianceData d) {
- int size = d.getSize();
-
- double lm2 = 0.0;
- double rm2 = d.impurity() * size;
- int lSize = size;
-
- double lMean = 0.0;
- double rMean = d.mean;
-
- double minImpurity = d.impurity() * size;
- double curThreshold;
- double curImpurity;
- double threshold = Double.NEGATIVE_INFINITY;
-
- int i = 0;
- int nextIdx = s[0];
- i++;
- double[] lrImps = new double[] {lm2, rm2, lMean, rMean};
-
- do {
- // Move samples to the left one at a time; a split is only evaluated where the next sample's value differs.
- while (i < s.length) {
- moveLeft(labels[nextIdx], lrImps[2], i, lrImps[0], lrImps[3], size - i, lrImps[1], lrImps);
- curImpurity = (lrImps[0] + lrImps[1]);
- curThreshold = values[nextIdx];
-
- if (values[nextIdx] != values[(nextIdx = s[i++])]) {
- if (curImpurity < minImpurity) {
- lSize = i - 1;
-
- lm2 = lrImps[0];
- rm2 = lrImps[1];
-
- lMean = lrImps[2];
- rMean = lrImps[3];
-
- minImpurity = curImpurity;
- threshold = curThreshold;
- }
-
- break;
- }
- }
- }
- while (i < s.length - 1);
-
- if (lSize == size)
- return null;
-
- VarianceData lData = new VarianceData(lm2 / (lSize != 0 ? lSize : 1), lSize, lMean);
- int rSize = size - lSize;
- VarianceData rData = new VarianceData(rm2 / (rSize != 0 ? rSize : 1), rSize, rMean);
-
- return new ContinuousSplitInfo<>(regionIdx, threshold, lData, rData);
- }
-
- /**
- * Moves point {@code x} from the right interval to the left one and stores the updated statistics in {@code data}
- * as {lm2, rm2, lMean, rMean}; {@code lSize} and {@code rSize} are the interval sizes after the move.
- */
- private void moveLeft(double x, double lMean, int lSize, double lm2, double rMean, int rSize, double rm2,
- double[] data) {
- // We add point to the left interval. lSize is the size of left interval after addition.
- double lDelta = x - lMean;
- double lMeanNew = lMean + lDelta / lSize;
- double lm2New = lm2 + lDelta * (x - lMeanNew);
-
- // We remove point from the right interval. rSize + 1 is the size of right interval before removal.
- double rMeanNew = (rMean * (rSize + 1) - x) / rSize;
- double rm2New = rm2 - (x - rMean) * (x - rMeanNew);
-
- data[0] = lm2New;
- data[1] = rm2New;
-
- data[2] = lMeanNew;
- data[3] = rMeanNew;
- }
-}
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/contsplitcalcs/package-info.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/contsplitcalcs/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/contsplitcalcs/package-info.java
deleted file mode 100644
index 08c8a75..0000000
--- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/contsplitcalcs/package-info.java
+++ /dev/null
@@ -1,22 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-/**
- * <!-- Package description. -->
- * Calculators of splits by continuous features.
- */
-package org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs;
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/package-info.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/package-info.java
deleted file mode 100644
index 8523914..0000000
--- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/package-info.java
+++ /dev/null
@@ -1,22 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-/**
- * <!-- Package description. -->
- * Contains column based decision tree algorithms.
- */
-package org.apache.ignite.ml.trees.trainers.columnbased;
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/regcalcs/RegionCalculators.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/regcalcs/RegionCalculators.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/regcalcs/RegionCalculators.java
deleted file mode 100644
index 5c4b354..0000000
--- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/regcalcs/RegionCalculators.java
+++ /dev/null
@@ -1,85 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.trees.trainers.columnbased.regcalcs;
-
-import it.unimi.dsi.fastutil.doubles.Double2IntOpenHashMap;
-import java.util.Comparator;
-import java.util.HashMap;
-import java.util.Map;
-import java.util.PrimitiveIterator;
-import java.util.stream.DoubleStream;
-import org.apache.ignite.ml.math.functions.IgniteFunction;
-import org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainerInput;
-
-/** Some commonly used functions for calculations of regions of space which correspond to decision tree leaf nodes. */
-public class RegionCalculators {
- /** Mean value in the region; {@code 0.0} for an empty stream. */
- public static final IgniteFunction<DoubleStream, Double> MEAN = s -> s.average().orElse(0.0);
-
- /** Most common value in the region; {@code 0.0} if empty. NOTE(review): counts start at 0, not 1 (shifts all counts equally). */
- public static final IgniteFunction<DoubleStream, Double> MOST_COMMON =
- s -> {
- PrimitiveIterator.OfDouble itr = s.iterator();
- Map<Double, Integer> voc = new HashMap<>();
-
- while (itr.hasNext())
- voc.compute(itr.next(), (d, i) -> i != null ? i + 1 : 0);
-
- return voc.entrySet().stream().max(Comparator.comparing(Map.Entry::getValue)).map(Map.Entry::getKey).orElse(0.0);
- };
-
- /** Variance of a region (m2 divided by the element count); {@code 0.0} for an empty stream. {@code input} is not used. */
- public static final IgniteFunction<ColumnDecisionTreeTrainerInput, IgniteFunction<DoubleStream, Double>> VARIANCE = input ->
- s -> {
- PrimitiveIterator.OfDouble itr = s.iterator();
- int i = 0;
-
- double mean = 0.0;
- double m2 = 0.0;
-
- while (itr.hasNext()) {
- i++;
- double x = itr.next();
- double delta = x - mean;
- mean += delta / i;
- double delta2 = x - mean;
- m2 += delta * delta2;
- }
-
- return i > 0 ? m2 / i : 0.0;
- };
-
- /** Gini impurity of a region; {@code 0.0} if empty. {@code input} is not used. NOTE(review): squares use int arithmetic and may overflow. */
- public static final IgniteFunction<ColumnDecisionTreeTrainerInput, IgniteFunction<DoubleStream, Double>> GINI = input ->
- s -> {
- PrimitiveIterator.OfDouble itr = s.iterator();
-
- Double2IntOpenHashMap m = new Double2IntOpenHashMap();
-
- int size = 0;
-
- while (itr.hasNext()) {
- size++;
- m.compute(itr.next(), (a, i) -> i != null ? i + 1 : 1);
- }
-
- double c2 = m.values().stream().mapToDouble(v -> v * v).sum();
-
- return size != 0 ? 1 - c2 / (size * size) : 0.0;
- };
-}
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/regcalcs/package-info.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/regcalcs/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/regcalcs/package-info.java
deleted file mode 100644
index e8edd8f..0000000
--- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/regcalcs/package-info.java
+++ /dev/null
@@ -1,22 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-/**
- * <!-- Package description. -->
- * Region calculators.
- */
-package org.apache.ignite.ml.trees.trainers.columnbased.regcalcs;
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/CategoricalFeatureProcessor.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/CategoricalFeatureProcessor.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/CategoricalFeatureProcessor.java
deleted file mode 100644
index 3232ac2..0000000
--- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/CategoricalFeatureProcessor.java
+++ /dev/null
@@ -1,212 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.trees.trainers.columnbased.vectors;
-
-import com.zaxxer.sparsebits.SparseBitSet;
-import java.util.Arrays;
-import java.util.BitSet;
-import java.util.Comparator;
-import java.util.HashMap;
-import java.util.Iterator;
-import java.util.List;
-import java.util.Map;
-import java.util.stream.Collectors;
-import java.util.stream.DoubleStream;
-import java.util.stream.Stream;
-import java.util.stream.StreamSupport;
-import org.apache.ignite.lang.IgniteBiTuple;
-import org.apache.ignite.ml.math.functions.IgniteFunction;
-import org.apache.ignite.ml.trees.CategoricalRegionInfo;
-import org.apache.ignite.ml.trees.CategoricalSplitInfo;
-import org.apache.ignite.ml.trees.RegionInfo;
-import org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer;
-import org.apache.ignite.ml.trees.trainers.columnbased.RegionProjection;
-
-import static org.apache.ignite.ml.trees.trainers.columnbased.vectors.FeatureVectorProcessorUtils.splitByBitSet;
-
-/**
- * Categorical feature vector processor implementation used by {@link ColumnDecisionTreeTrainer}.
- */
-public class CategoricalFeatureProcessor
- implements FeatureProcessor<CategoricalRegionInfo, CategoricalSplitInfo<CategoricalRegionInfo>> {
- /** Count of categories for this feature. */
- private final int catsCnt;
-
- /** Function for calculating impurity of a given region of points. */
- private final IgniteFunction<DoubleStream, Double> calc;
-
- /**
- * @param calc Function for calculating impurity of a given region of points.
- * @param catsCnt Number of categories.
- */
- public CategoricalFeatureProcessor(IgniteFunction<DoubleStream, Double> calc, int catsCnt) {
- this.calc = calc;
- this.catsCnt = catsCnt;
- }
-
- /** Evaluates the split sending samples whose mapped category is in {@code leftCats} to the left; info gain is set on the result. */
- private SplitInfo<CategoricalRegionInfo> split(BitSet leftCats, int intervalIdx, Map<Integer, Integer> mapping,
- Integer[] sampleIndexes, double[] values, double[] labels, double impurity) {
- Map<Boolean, List<Integer>> leftRight = Arrays.stream(sampleIndexes).
- collect(Collectors.partitioningBy((smpl) -> leftCats.get(mapping.get((int)values[smpl]))));
-
- List<Integer> left = leftRight.get(true);
- int leftSize = left.size();
- double leftImpurity = calc.apply(left.stream().mapToDouble(s -> labels[s]));
-
- List<Integer> right = leftRight.get(false);
- int rightSize = right.size();
- double rightImpurity = calc.apply(right.stream().mapToDouble(s -> labels[s]));
-
- int totalSize = leftSize + rightSize;
-
- // Result of this call will be sent back to trainer node, we do not need vectors inside of sent data.
- CategoricalSplitInfo<CategoricalRegionInfo> res = new CategoricalSplitInfo<>(intervalIdx,
- new CategoricalRegionInfo(leftImpurity, null), // cats can be computed on the last step.
- new CategoricalRegionInfo(rightImpurity, null),
- leftCats);
-
- res.setInfoGain(impurity - (double)leftSize / totalSize * leftImpurity - (double)rightSize / totalSize * rightImpurity);
- return res;
- }
-
- /**
- * Get a lazily generated, sequential stream of subsets given categories count (see {@code PSI} for which subsets).
- *
- * @param catsCnt categories count.
- * @return Stream of subsets given categories count.
- */
- private Stream<BitSet> powerSet(int catsCnt) {
- Iterable<BitSet> iterable = () -> new PSI(catsCnt);
- return StreamSupport.stream(iterable.spliterator(), false);
- }
-
- /** {@inheritDoc} NOTE(review): subsets are sized by {@code cats().length()} while {@code mapping} compacts set bits - confirm they agree. */
- @Override public SplitInfo findBestSplit(RegionProjection<CategoricalRegionInfo> regionPrj, double[] values,
- double[] labels, int regIdx) {
- Map<Integer, Integer> mapping = mapping(regionPrj.data().cats());
-
- return powerSet(regionPrj.data().cats().length()).
- map(s -> split(s, regIdx, mapping, regionPrj.sampleIndexes(), values, labels, regionPrj.data().impurity())).
- max(Comparator.comparingDouble(SplitInfo::infoGain)).
- orElse(null);
- }
-
- /** {@inheritDoc} Marks categories {@code [0, catsCnt)} as present and computes impurity over the whole {@code labels} array. */
- @Override public RegionProjection<CategoricalRegionInfo> createInitialRegion(Integer[] sampleIndexes,
- double[] values, double[] labels) {
- BitSet set = new BitSet();
- set.set(0, catsCnt);
-
- Double impurity = calc.apply(Arrays.stream(labels));
-
- return new RegionProjection<>(sampleIndexes, new CategoricalRegionInfo(impurity, set), 0);
- }
-
- /** {@inheritDoc} NOTE(review): the raw category indexes the split bit set here, whereas {@code split} goes through {@code mapping} - confirm. */
- @Override public SparseBitSet calculateOwnershipBitSet(RegionProjection<CategoricalRegionInfo> regionPrj,
- double[] values,
- CategoricalSplitInfo<CategoricalRegionInfo> s) {
- SparseBitSet res = new SparseBitSet();
- Arrays.stream(regionPrj.sampleIndexes()).forEach(smpl -> res.set(smpl, s.bitSet().get((int)values[smpl])));
- return res;
- }
-
- /** {@inheritDoc} NOTE(review): passes {@code null} as values, which {@code calculateCats} indexes per sample - looks like an NPE; confirm. */
- @Override public IgniteBiTuple<RegionProjection, RegionProjection> performSplit(SparseBitSet bs,
- RegionProjection<CategoricalRegionInfo> reg, CategoricalRegionInfo leftData, CategoricalRegionInfo rightData) {
- return performSplitGeneric(bs, null, reg, leftData, rightData);
- }
-
- /** {@inheritDoc} Keeps the supplied impurities and recomputes the category sets of both subregions from their samples. */
- @Override public IgniteBiTuple<RegionProjection, RegionProjection> performSplitGeneric(
- SparseBitSet bs, double[] values, RegionProjection<CategoricalRegionInfo> reg, RegionInfo leftData,
- RegionInfo rightData) {
- int depth = reg.depth();
-
- int lSize = bs.cardinality();
- int rSize = reg.sampleIndexes().length - lSize;
- IgniteBiTuple<Integer[], Integer[]> lrSamples = splitByBitSet(lSize, rSize, reg.sampleIndexes(), bs);
- BitSet leftCats = calculateCats(lrSamples.get1(), values);
- CategoricalRegionInfo lInfo = new CategoricalRegionInfo(leftData.impurity(), leftCats);
-
- // TODO: IGNITE-5892 Check how it will work with sparse data.
- BitSet rightCats = calculateCats(lrSamples.get2(), values);
- CategoricalRegionInfo rInfo = new CategoricalRegionInfo(rightData.impurity(), rightCats);
-
- RegionProjection<CategoricalRegionInfo> rPrj = new RegionProjection<>(lrSamples.get2(), rInfo, depth + 1);
- RegionProjection<CategoricalRegionInfo> lPrj = new RegionProjection<>(lrSamples.get1(), lInfo, depth + 1);
- return new IgniteBiTuple<>(lPrj, rPrj);
- }
-
- /**
- * Powerset iterator; each subset is the set of bits of an int counter. Iterates not over the whole powerset, but on half of it.
- */
- private static class PSI implements Iterator<BitSet> {
-
- /** Current subset number. */
- private int i = 1; // We are not interested in {emptyset, set} split and therefore start from 1.
-
- /** Exclusive upper bound of subset numbers, {@code 2^(bitCnt - 1)}. */
- final int size;
-
- /**
- * @param bitCnt Size of set, subsets of which we iterate over. NOTE(review): the int shift wraps for bitCnt of 32 or more.
- */
- PSI(int bitCnt) {
- this.size = 1 << (bitCnt - 1);
- }
-
- /** {@inheritDoc} */
- @Override public boolean hasNext() {
- return i < size;
- }
-
- /** {@inheritDoc} Returns the subset encoded by the current counter and advances it; does not check {@link #hasNext()}. */
- @Override public BitSet next() {
- BitSet res = BitSet.valueOf(new long[] {i});
- i++;
- return res;
- }
- }
-
- /** Maps each set bit of {@code bs}, taken in ascending order, to its ordinal position (0, 1, 2, ...). */
- private Map<Integer, Integer> mapping(BitSet bs) {
- int bn = 0;
- Map<Integer, Integer> res = new HashMap<>();
-
- int i = 0;
- while ((bn = bs.nextSetBit(bn)) != -1) {
- res.put(bn, i);
- i++;
- bn++;
- }
-
- return res;
- }
-
- /** Get set of categories (feature values cast to int) of given samples. */
- private BitSet calculateCats(Integer[] sampleIndexes, double[] values) {
- BitSet res = new BitSet();
-
- for (int smpl : sampleIndexes)
- res.set((int)values[smpl]);
-
- return res;
- }
-}
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/ContinuousFeatureProcessor.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/ContinuousFeatureProcessor.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/ContinuousFeatureProcessor.java
deleted file mode 100644
index 4117993..0000000
--- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/ContinuousFeatureProcessor.java
+++ /dev/null
@@ -1,111 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.trees.trainers.columnbased.vectors;
-
-import com.zaxxer.sparsebits.SparseBitSet;
-import java.util.Arrays;
-import java.util.Comparator;
-import org.apache.ignite.lang.IgniteBiTuple;
-import org.apache.ignite.ml.trees.ContinuousRegionInfo;
-import org.apache.ignite.ml.trees.ContinuousSplitCalculator;
-import org.apache.ignite.ml.trees.RegionInfo;
-import org.apache.ignite.ml.trees.trainers.columnbased.RegionProjection;
-
-import static org.apache.ignite.ml.trees.trainers.columnbased.vectors.FeatureVectorProcessorUtils.splitByBitSet;
-
-/**
- * Feature processor for a continuous feature; split search and region statistics are delegated to a split calculator.
- *
- * @param <D> Information about regions. Designed to contain information which will make computations of impurity
- * optimal.
- */
-public class ContinuousFeatureProcessor<D extends ContinuousRegionInfo> implements
- FeatureProcessor<D, ContinuousSplitInfo<D>> {
- /** ContinuousSplitCalculator used for calculating of best split of each region. */
- private final ContinuousSplitCalculator<D> calc;
-
- /**
- * @param splitCalc Calculator used for calculating splits.
- */
- public ContinuousFeatureProcessor(ContinuousSplitCalculator<D> splitCalc) {
- this.calc = splitCalc;
- }
-
- /** {@inheritDoc} Info gain is the region impurity minus the size-weighted child impurities; {@code null} if the calculator finds no split. */
- @Override public SplitInfo<D> findBestSplit(RegionProjection<D> ri, double[] values, double[] labels, int regIdx) {
- SplitInfo<D> res = calc.splitRegion(ri.sampleIndexes(), values, labels, regIdx, ri.data());
-
- if (res == null)
- return null;
-
- double lWeight = (double)res.leftData.getSize() / ri.sampleIndexes().length;
- double rWeight = (double)res.rightData.getSize() / ri.sampleIndexes().length;
-
- double infoGain = ri.data().impurity() - lWeight * res.leftData().impurity() - rWeight * res.rightData().impurity();
- res.setInfoGain(infoGain);
-
- return res;
- }
-
- /** {@inheritDoc} Sorts {@code samples} in place by feature value; region info is computed over the whole {@code labels} array. */
- @Override public RegionProjection<D> createInitialRegion(Integer[] samples, double[] values, double[] labels) {
- Arrays.sort(samples, Comparator.comparingDouble(s -> values[s]));
- return new RegionProjection<>(samples, calc.calculateRegionInfo(Arrays.stream(labels), samples.length), 0);
- }
-
- /** {@inheritDoc} The first {@code s.leftData().getSize()} sample indexes of the region are marked as left; {@code values} is not used. */
- @Override public SparseBitSet calculateOwnershipBitSet(RegionProjection<D> reg, double[] values,
- ContinuousSplitInfo<D> s) {
- SparseBitSet res = new SparseBitSet();
-
- for (int i = 0; i < s.leftData().getSize(); i++)
- res.set(reg.sampleIndexes()[i]);
-
- return res;
- }
-
- /** {@inheritDoc} Reuses the supplied region infos as is; subregion sizes are taken from them. */
- @Override public IgniteBiTuple<RegionProjection, RegionProjection> performSplit(SparseBitSet bs,
- RegionProjection<D> reg, D leftData, D rightData) {
- int lSize = leftData.getSize();
- int rSize = rightData.getSize();
- int depth = reg.depth();
-
- IgniteBiTuple<Integer[], Integer[]> lrSamples = splitByBitSet(lSize, rSize, reg.sampleIndexes(), bs);
-
- RegionProjection<D> left = new RegionProjection<>(lrSamples.get1(), leftData, depth + 1);
- RegionProjection<D> right = new RegionProjection<>(lrSamples.get2(), rightData, depth + 1);
-
- return new IgniteBiTuple<>(left, right);
- }
-
- /** {@inheritDoc} Region infos are recomputed from {@code labels}; {@code leftData} and {@code rightData} are not used. */
- @Override public IgniteBiTuple<RegionProjection, RegionProjection> performSplitGeneric(SparseBitSet bs,
- double[] labels, RegionProjection<D> reg, RegionInfo leftData, RegionInfo rightData) {
- int lSize = bs.cardinality();
- int rSize = reg.sampleIndexes().length - lSize;
- int depth = reg.depth();
-
- IgniteBiTuple<Integer[], Integer[]> lrSamples = splitByBitSet(lSize, rSize, reg.sampleIndexes(), bs);
-
- D ld = calc.calculateRegionInfo(Arrays.stream(lrSamples.get1()).mapToDouble(s -> labels[s]), lSize);
- D rd = calc.calculateRegionInfo(Arrays.stream(lrSamples.get2()).mapToDouble(s -> labels[s]), rSize);
-
- return new IgniteBiTuple<>(new RegionProjection<>(lrSamples.get1(), ld, depth + 1), new RegionProjection<>(lrSamples.get2(), rd, depth + 1));
- }
-}
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/ContinuousSplitInfo.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/ContinuousSplitInfo.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/ContinuousSplitInfo.java
deleted file mode 100644
index 8b45cb5..0000000
--- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/ContinuousSplitInfo.java
+++ /dev/null
@@ -1,71 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.trees.trainers.columnbased.vectors;
-
-import org.apache.ignite.ml.trees.RegionInfo;
-import org.apache.ignite.ml.trees.nodes.ContinuousSplitNode;
-import org.apache.ignite.ml.trees.nodes.SplitNode;
-
-/**
- * Information about split of continuous region.
- *
- * @param <D> Class encapsulating information about the region.
- */
-public class ContinuousSplitInfo<D extends RegionInfo> extends SplitInfo<D> {
- /**
- * Threshold used for split.
- * Samples with values less than or equal to this go to the left region, others go to the right region.
- */
- private final double threshold;
-
- /**
- * @param regionIdx Index of region being split.
- * @param threshold Threshold used for split. Samples with values less than or equal to this go to the left region,
- * others go to the right region.
- * @param leftData Information about left subregion.
- * @param rightData Information about right subregion.
- */
- public ContinuousSplitInfo(int regionIdx, double threshold, D leftData, D rightData) {
- super(regionIdx, leftData, rightData);
- this.threshold = threshold;
- }
-
- /** {@inheritDoc} Creates a {@link ContinuousSplitNode} from this split's threshold and the given feature index. */
- @Override public SplitNode createSplitNode(int featureIdx) {
- return new ContinuousSplitNode(threshold, featureIdx);
- }
-
- /**
- * @return Threshold used for split.
- * Samples with values less than or equal to this go to the left region, others go to the right region.
- */
- public double threshold() {
- return threshold;
- }
-
- /** {@inheritDoc} */
- @Override public String toString() {
- return "ContinuousSplitInfo [" +
- "threshold=" + threshold +
- ", infoGain=" + infoGain +
- ", regionIdx=" + regionIdx +
- ", leftData=" + leftData +
- ", rightData=" + rightData +
- ']';
- }
-}
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/FeatureProcessor.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/FeatureProcessor.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/FeatureProcessor.java
deleted file mode 100644
index 56508e5..0000000
--- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/FeatureProcessor.java
+++ /dev/null
@@ -1,82 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.trees.trainers.columnbased.vectors;
-
-import com.zaxxer.sparsebits.SparseBitSet;
-import org.apache.ignite.lang.IgniteBiTuple;
-import org.apache.ignite.ml.trees.RegionInfo;
-import org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer;
-import org.apache.ignite.ml.trees.trainers.columnbased.RegionProjection;
-
-/**
- * Base interface for feature processors used in {@link ColumnDecisionTreeTrainer}.
- *
- * @param <D> Class representing data of regions resulting from a split.
- * @param <S> Class representing data of a split.
- */
-public interface FeatureProcessor<D extends RegionInfo, S extends SplitInfo<D>> {
- /**
- * Finds the best split by this feature among all splits of all regions.
- *
- * @return Best split by this feature among all splits of all regions.
- */
- SplitInfo findBestSplit(RegionProjection<D> regionPrj, double[] values, double[] labels, int regIdx);
-
- /**
- * Creates the initial region from samples.
- *
- * @param samples Samples.
- * @return Region.
- */
- RegionProjection<D> createInitialRegion(Integer[] samples, double[] values, double[] labels);
-
- /**
- * Calculates the bitset mapping each data point to left (corresponding bit is set) or right subregion.
- *
- * @param s Data used for calculating the split.
- * @return Bitset mapping each data point to left (corresponding bit is set) or right subregion.
- */
- SparseBitSet calculateOwnershipBitSet(RegionProjection<D> regionPrj, double[] values, S s);
-
- /**
- * Splits the given region using a bitset which maps each data point to the left or right subregion.
- * This method is present for the vectors of the same type to be able to pass between them information about regions
- * and is therefore used iff the optimal split is received on a feature of the same type.
- *
- * @param bs Bitset which maps each data point to the left or right subregion.
- * @param leftData Data of the left subregion.
- * @param rightData Data of the right subregion.
- * @return Pair of region projections produced by the split (presumably left, then right; confirm in implementations).
- */
- IgniteBiTuple<RegionProjection, RegionProjection> performSplit(SparseBitSet bs, RegionProjection<D> reg, D leftData,
- D rightData);
-
- /**
- * Splits the given region using a bitset which maps each data point to the left or right subregion. This method is
- * used iff the optimal split is received on a feature of a different type, so information about regions is limited
- * to the {@link RegionInfo} class, which is the base for all classes used to represent region data.
- *
- * @param bs Bitset which maps each data point to the left or right subregion.
- * @param leftData Data of the left subregion.
- * @param rightData Data of the right subregion.
- * @return Pair of region projections produced by the split (presumably left, then right; confirm in implementations).
- */
- IgniteBiTuple<RegionProjection, RegionProjection> performSplitGeneric(SparseBitSet bs, double[] values,
- RegionProjection<D> reg, RegionInfo leftData,
- RegionInfo rightData);
-}
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/FeatureVectorProcessorUtils.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/FeatureVectorProcessorUtils.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/FeatureVectorProcessorUtils.java
deleted file mode 100644
index 69ff019..0000000
--- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/FeatureVectorProcessorUtils.java
+++ /dev/null
@@ -1,57 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.trees.trainers.columnbased.vectors;
-
-import com.zaxxer.sparsebits.SparseBitSet;
-import org.apache.ignite.lang.IgniteBiTuple;
-
-/** Utility class for feature vector processors. */
-public class FeatureVectorProcessorUtils {
- /**
- * Splits the sample array into two (left and right) arrays by bitset.
- *
- * @param lSize Left array size.
- * @param rSize Right array size.
- * @param samples Sample indexes to split; the first {@code lSize + rSize} elements are read.
- * @param bs Bitset specifying the split: a sample whose bit is set goes to the left array, others go to the right.
- * @return BiTuple containing the left and the right arrays, in that order.
- */
- public static IgniteBiTuple<Integer[], Integer[]> splitByBitSet(int lSize, int rSize, Integer[] samples,
- SparseBitSet bs) {
- Integer[] lArr = new Integer[lSize];
- Integer[] rArr = new Integer[rSize];
-
- int lc = 0;
- int rc = 0;
-
- for (int i = 0; i < lSize + rSize; i++) {
- int si = samples[i];
-
- if (bs.get(si)) {
- lArr[lc] = si;
- lc++;
- }
- else {
- rArr[rc] = si;
- rc++;
- }
- }
-
- return new IgniteBiTuple<>(lArr, rArr);
- }
-}
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/SampleInfo.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/SampleInfo.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/SampleInfo.java
deleted file mode 100644
index 8aa4f79..0000000
--- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/SampleInfo.java
+++ /dev/null
@@ -1,80 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.trees.trainers.columnbased.vectors;
-
-import java.io.Externalizable;
-import java.io.IOException;
-import java.io.ObjectInput;
-import java.io.ObjectOutput;
-
-/**
- * Information about a given sample within a given fixed feature.
- */
-public class SampleInfo implements Externalizable {
- /** Value of projection of this sample on given fixed feature. */
- private double val;
-
- /** Sample index. */
- private int sampleIdx;
-
- /**
- * @param val Value of projection of this sample on given fixed feature.
- * @param sampleIdx Sample index.
- */
- public SampleInfo(double val, int sampleIdx) {
- this.val = val;
- this.sampleIdx = sampleIdx;
- }
-
- /**
- * No-argument constructor required for {@link Externalizable} serialization/deserialization.
- */
- public SampleInfo() {
- // No-op.
- }
-
- /**
- * Gets the value of projection of this sample on given fixed feature.
- *
- * @return Value of projection of this sample on given fixed feature.
- */
- public double val() {
- return val;
- }
-
- /**
- * Gets the sample index.
- *
- * @return Sample index.
- */
- public int sampleInd() {
- return sampleIdx;
- }
-
- /** {@inheritDoc} Writes the value first, then the sample index. */
- @Override public void writeExternal(ObjectOutput out) throws IOException {
- out.writeDouble(val);
- out.writeInt(sampleIdx);
- }
-
- /** {@inheritDoc} Reads the value first, then the sample index, matching {@link #writeExternal(ObjectOutput)}. */
- @Override public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
- val = in.readDouble();
- sampleIdx = in.readInt();
- }
-}
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/SplitInfo.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/SplitInfo.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/SplitInfo.java
deleted file mode 100644
index 124e82f..0000000
--- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/SplitInfo.java
+++ /dev/null
@@ -1,106 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.trees.trainers.columnbased.vectors;
-
-import org.apache.ignite.ml.trees.RegionInfo;
-import org.apache.ignite.ml.trees.nodes.SplitNode;
-
-/**
- * Class encapsulating information about the split.
- *
- * @param <D> Class representing information of left and right subregions.
- */
-public abstract class SplitInfo<D extends RegionInfo> {
- /** Information gain of this split; not assigned by the constructor, see {@link #setInfoGain(double)}. */
- protected double infoGain;
-
- /** Index of the region to split. */
- protected final int regionIdx;
-
- /** Data of the left subregion. */
- protected final D leftData;
-
- /** Data of the right subregion. */
- protected final D rightData;
-
- /**
- * Constructs the split info.
- *
- * @param regionIdx Index of the region to split.
- * @param leftData Data of the left subregion.
- * @param rightData Data of the right subregion.
- */
- public SplitInfo(int regionIdx, D leftData, D rightData) {
- this.regionIdx = regionIdx;
- this.leftData = leftData;
- this.rightData = rightData;
- }
-
- /**
- * Gets the index of the region to split.
- *
- * @return Index of the region to split.
- */
- public int regionIndex() {
- return regionIdx;
- }
-
- /**
- * Gets the information gain of the split.
- *
- * @return Information gain of the split.
- */
- public double infoGain() {
- return infoGain;
- }
-
- /**
- * Gets the data of the right subregion.
- *
- * @return Data of the right subregion.
- */
- public D rightData() {
- return rightData;
- }
-
- /**
- * Gets the data of the left subregion.
- *
- * @return Data of the left subregion.
- */
- public D leftData() {
- return leftData;
- }
-
- /**
- * Creates a {@link SplitNode} from this split info.
- *
- * @param featureIdx Index of the feature by which the split is made.
- * @return SplitNode created from this split info.
- */
- public abstract SplitNode createSplitNode(int featureIdx);
-
- /**
- * Sets the information gain.
- *
- * @param infoGain Information gain.
- */
- public void setInfoGain(double infoGain) {
- this.infoGain = infoGain;
- }
-}
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/package-info.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/package-info.java
deleted file mode 100644
index 0dea204..0000000
--- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/vectors/package-info.java
+++ /dev/null
@@ -1,22 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-/**
- * <!-- Package description. -->
- * Contains feature containers needed by column based decision tree trainers.
- */
-package org.apache.ignite.ml.trees.trainers.columnbased.vectors;
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java
index e22a3a5..9900f85 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java
@@ -28,7 +28,7 @@ import org.apache.ignite.ml.preprocessing.PreprocessingTestSuite;
import org.apache.ignite.ml.regressions.RegressionsTestSuite;
import org.apache.ignite.ml.svm.SVMTestSuite;
import org.apache.ignite.ml.trainers.group.TrainersGroupTestSuite;
-import org.apache.ignite.ml.trees.DecisionTreesTestSuite;
+import org.apache.ignite.ml.tree.DecisionTreeTestSuite;
import org.junit.runner.RunWith;
import org.junit.runners.Suite;
@@ -41,7 +41,7 @@ import org.junit.runners.Suite;
RegressionsTestSuite.class,
SVMTestSuite.class,
ClusteringTestSuite.class,
- DecisionTreesTestSuite.class,
+ DecisionTreeTestSuite.class,
KNNTestSuite.class,
LocalModelsTest.class,
MLPTestSuite.class,
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MnistMLPTestUtil.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MnistMLPTestUtil.java b/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MnistMLPTestUtil.java
index e624004..d68b355 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MnistMLPTestUtil.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MnistMLPTestUtil.java
@@ -25,11 +25,10 @@ import java.util.Random;
import java.util.stream.Stream;
import org.apache.ignite.lang.IgniteBiTuple;
import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
-import org.apache.ignite.ml.trees.performance.ColumnDecisionTreeTrainerBenchmark;
import org.apache.ignite.ml.util.MnistUtils;
/** */
-class MnistMLPTestUtil {
+public class MnistMLPTestUtil {
/** Name of the property specifying path to training set images. */
private static final String PROP_TRAINING_IMAGES = "mnist.training.images";
@@ -62,7 +61,7 @@ class MnistMLPTestUtil {
* @return List of MNIST images.
* @throws IOException In case of exception.
*/
- static List<MnistUtils.MnistLabeledImage> loadTrainingSet(int cnt) throws IOException {
+ public static List<MnistUtils.MnistLabeledImage> loadTrainingSet(int cnt) throws IOException {
Properties props = loadMNISTProperties();
return MnistUtils.mnistAsList(props.getProperty(PROP_TRAINING_IMAGES), props.getProperty(PROP_TRAINING_LABELS), new Random(123L), cnt);
}
@@ -74,7 +73,7 @@ class MnistMLPTestUtil {
* @return List of MNIST images.
* @throws IOException In case of exception.
*/
- static List<MnistUtils.MnistLabeledImage> loadTestSet(int cnt) throws IOException {
+ public static List<MnistUtils.MnistLabeledImage> loadTestSet(int cnt) throws IOException {
Properties props = loadMNISTProperties();
return MnistUtils.mnistAsList(props.getProperty(PROP_TEST_IMAGES), props.getProperty(PROP_TEST_LABELS), new Random(123L), cnt);
}
@@ -83,7 +82,7 @@ class MnistMLPTestUtil {
private static Properties loadMNISTProperties() throws IOException {
Properties res = new Properties();
- InputStream is = ColumnDecisionTreeTrainerBenchmark.class.getClassLoader().getResourceAsStream("manualrun/trees/columntrees.manualrun.properties");
+ InputStream is = MnistMLPTestUtil.class.getClassLoader().getResourceAsStream("manualrun/trees/columntrees.manualrun.properties");
res.load(is);
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerIntegrationTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerIntegrationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerIntegrationTest.java
new file mode 100644
index 0000000..94bca3f
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerIntegrationTest.java
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.tree;
+
+import java.util.Arrays;
+import java.util.Random;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
+import org.apache.ignite.configuration.CacheConfiguration;
+import org.apache.ignite.internal.util.IgniteUtils;
+import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
+import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest;
+
+/**
+ * Tests for {@link DecisionTreeClassificationTrainer} that require starting the whole Ignite infrastructure.
+ */
+public class DecisionTreeClassificationTrainerIntegrationTest extends GridCommonAbstractTest {
+ /** Number of nodes in the grid. */
+ private static final int NODE_COUNT = 3;
+
+ /** Ignite instance. */
+ private Ignite ignite;
+
+ /** {@inheritDoc} */
+ @Override protected void beforeTestsStarted() throws Exception {
+ for (int i = 1; i <= NODE_COUNT; i++)
+ startGrid(i);
+ }
+
+ /** {@inheritDoc} */
+ @Override protected void afterTestsStopped() {
+ stopAllGrids();
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override protected void beforeTest() throws Exception {
+ /* Use the last started node (index NODE_COUNT) as the grid instance for the test. */
+ ignite = grid(NODE_COUNT);
+ ignite.configuration().setPeerClassLoadingEnabled(true);
+ IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+ }
+
+ /** Trains on 100 cached one-feature samples labelled {@code x > 0 ? 1 : 0}; expects one split near 0 with leaves 1 and 0. */
+ public void testFit() {
+ int size = 100;
+
+ CacheConfiguration<Integer, double[]> trainingSetCacheCfg = new CacheConfiguration<>();
+ trainingSetCacheCfg.setAffinity(new RendezvousAffinityFunction(false, 10));
+ trainingSetCacheCfg.setName("TRAINING_SET");
+
+ IgniteCache<Integer, double[]> data = ignite.createCache(trainingSetCacheCfg);
+
+ Random rnd = new Random(0);
+ for (int i = 0; i < size; i++) {
+ double x = rnd.nextDouble() - 0.5;
+ data.put(i, new double[]{x, x > 0 ? 1 : 0});
+ }
+
+ DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(1, 0);
+
+ DecisionTreeNode tree = trainer.fit(
+ new CacheBasedDatasetBuilder<>(ignite, data),
+ (k, v) -> Arrays.copyOf(v, v.length - 1),
+ (k, v) -> v[v.length - 1]
+ );
+
+ assertTrue(tree instanceof DecisionTreeConditionalNode);
+
+ DecisionTreeConditionalNode node = (DecisionTreeConditionalNode) tree;
+
+ assertEquals(0, node.getThreshold(), 1e-3);
+
+ assertTrue(node.getThenNode() instanceof DecisionTreeLeafNode);
+ assertTrue(node.getElseNode() instanceof DecisionTreeLeafNode);
+
+ DecisionTreeLeafNode thenNode = (DecisionTreeLeafNode) node.getThenNode();
+ DecisionTreeLeafNode elseNode = (DecisionTreeLeafNode) node.getElseNode();
+
+ assertEquals(1, thenNode.getVal(), 1e-10);
+ assertEquals(0, elseNode.getVal(), 1e-10);
+ }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerTest.java
new file mode 100644
index 0000000..2599bfe
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerTest.java
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.tree;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import static junit.framework.TestCase.assertEquals;
+import static junit.framework.TestCase.assertTrue;
+
+/**
+ * Tests for {@link DecisionTreeClassificationTrainer}.
+ */
+@RunWith(Parameterized.class)
+public class DecisionTreeClassificationTrainerTest {
+ /** Partition counts to be tested; each value becomes one parameterized run. */
+ private static final int[] partsToBeTested = new int[] {1, 2, 3, 4, 5, 7};
+
+ /** Number of partitions. */
+ @Parameterized.Parameter
+ public int parts;
+
+ @Parameterized.Parameters(name = "Data divided on {0} partitions")
+ public static Iterable<Integer[]> data() {
+ List<Integer[]> res = new ArrayList<>();
+ for (int part : partsToBeTested)
+ res.add(new Integer[] {part});
+
+ return res;
+ }
+
+ /** Trains on 100 local one-feature samples labelled {@code x > 0 ? 1 : 0}; expects one split near 0 with leaves 1 and 0. */
+ @Test
+ public void testFit() {
+ int size = 100;
+
+ Map<Integer, double[]> data = new HashMap<>();
+
+ Random rnd = new Random(0);
+ for (int i = 0; i < size; i++) {
+ double x = rnd.nextDouble() - 0.5;
+ data.put(i, new double[]{x, x > 0 ? 1 : 0});
+ }
+
+ DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(1, 0);
+
+ DecisionTreeNode tree = trainer.fit(
+ new LocalDatasetBuilder<>(data, parts),
+ (k, v) -> Arrays.copyOf(v, v.length - 1),
+ (k, v) -> v[v.length - 1]
+ );
+
+ assertTrue(tree instanceof DecisionTreeConditionalNode);
+
+ DecisionTreeConditionalNode node = (DecisionTreeConditionalNode) tree;
+
+ assertEquals(0, node.getThreshold(), 1e-3);
+
+ assertTrue(node.getThenNode() instanceof DecisionTreeLeafNode);
+ assertTrue(node.getElseNode() instanceof DecisionTreeLeafNode);
+
+ DecisionTreeLeafNode thenNode = (DecisionTreeLeafNode) node.getThenNode();
+ DecisionTreeLeafNode elseNode = (DecisionTreeLeafNode) node.getElseNode();
+
+ assertEquals(1, thenNode.getVal(), 1e-10);
+ assertEquals(0, elseNode.getVal(), 1e-10);
+ }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerIntegrationTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerIntegrationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerIntegrationTest.java
new file mode 100644
index 0000000..754ff20
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerIntegrationTest.java
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.tree;
+
+import java.util.Arrays;
+import java.util.Random;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
+import org.apache.ignite.configuration.CacheConfiguration;
+import org.apache.ignite.internal.util.IgniteUtils;
+import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
+import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest;
+
+/**
+ * Tests for {@link DecisionTreeRegressionTrainer} that require starting the whole Ignite infrastructure.
+ */
+public class DecisionTreeRegressionTrainerIntegrationTest extends GridCommonAbstractTest {
+ /** Number of nodes in the grid. */
+ private static final int NODE_COUNT = 3;
+
+ /** Ignite instance. */
+ private Ignite ignite;
+
+ /** {@inheritDoc} */
+ @Override protected void beforeTestsStarted() throws Exception {
+ for (int i = 1; i <= NODE_COUNT; i++)
+ startGrid(i);
+ }
+
+ /** {@inheritDoc} */
+ @Override protected void afterTestsStopped() {
+ stopAllGrids();
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override protected void beforeTest() throws Exception {
+ /* Use the last started node (index NODE_COUNT) as the grid instance for the test. */
+ ignite = grid(NODE_COUNT);
+ ignite.configuration().setPeerClassLoadingEnabled(true);
+ IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+ }
+
+ /** Trains on 100 cached one-feature samples with target {@code x > 0 ? 1 : 0}; expects one split near 0 with leaves 1 and 0. */
+ public void testFit() {
+ int size = 100;
+
+ CacheConfiguration<Integer, double[]> trainingSetCacheCfg = new CacheConfiguration<>();
+ trainingSetCacheCfg.setAffinity(new RendezvousAffinityFunction(false, 10));
+ trainingSetCacheCfg.setName("TRAINING_SET");
+
+ IgniteCache<Integer, double[]> data = ignite.createCache(trainingSetCacheCfg);
+
+ Random rnd = new Random(0);
+ for (int i = 0; i < size; i++) {
+ double x = rnd.nextDouble() - 0.5;
+ data.put(i, new double[]{x, x > 0 ? 1 : 0});
+ }
+
+ DecisionTreeRegressionTrainer trainer = new DecisionTreeRegressionTrainer(1, 0);
+
+ DecisionTreeNode tree = trainer.fit(
+ new CacheBasedDatasetBuilder<>(ignite, data),
+ (k, v) -> Arrays.copyOf(v, v.length - 1),
+ (k, v) -> v[v.length - 1]
+ );
+
+ assertTrue(tree instanceof DecisionTreeConditionalNode);
+
+ DecisionTreeConditionalNode node = (DecisionTreeConditionalNode) tree;
+
+ assertEquals(0, node.getThreshold(), 1e-3);
+
+ assertTrue(node.getThenNode() instanceof DecisionTreeLeafNode);
+ assertTrue(node.getElseNode() instanceof DecisionTreeLeafNode);
+
+ DecisionTreeLeafNode thenNode = (DecisionTreeLeafNode) node.getThenNode();
+ DecisionTreeLeafNode elseNode = (DecisionTreeLeafNode) node.getElseNode();
+
+ assertEquals(1, thenNode.getVal(), 1e-10);
+ assertEquals(0, elseNode.getVal(), 1e-10);
+ }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerTest.java
new file mode 100644
index 0000000..3bdbf60
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerTest.java
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.tree;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import static junit.framework.TestCase.assertEquals;
+import static junit.framework.TestCase.assertTrue;
+
+/**
+ * Tests for {@link DecisionTreeRegressionTrainer}.
+ * <p>
+ * The test is parameterized by the number of partitions the local dataset is split into.
+ */
+@RunWith(Parameterized.class)
+public class DecisionTreeRegressionTrainerTest {
+    /** Partition counts the test is repeated for. */
+    private static final int[] partsToBeTested = new int[] {1, 2, 3, 4, 5, 7};
+
+    /** Number of partitions. */
+    @Parameterized.Parameter
+    public int parts;
+
+    /**
+     * Builds the parameter sets for the runner.
+     *
+     * @return One single-element array per value of {@link #partsToBeTested}, in the same order.
+     */
+    @Parameterized.Parameters(name = "Data divided on {0} partitions")
+    public static Iterable<Integer[]> data() {
+        List<Integer[]> params = new ArrayList<>(partsToBeTested.length);
+
+        for (int idx = 0; idx < partsToBeTested.length; idx++)
+            params.add(new Integer[] {partsToBeTested[idx]});
+
+        return params;
+    }
+
+    /**
+     * Trains a regression tree on a local dataset and checks the structure of the resulting tree.
+     * <p>
+     * Every row is {@code {x, label}} where {@code x} is drawn from {@code [-0.5, 0.5)} and the label is {@code 1}
+     * when {@code x > 0} and {@code 0} otherwise. The test expects a single split with a threshold within
+     * {@code 1e-3} of zero whose "then" leaf holds {@code 1} and whose "else" leaf holds {@code 0}.
+     */
+    @Test
+    public void testFit() {
+        Map<Integer, double[]> trainingSet = new HashMap<>();
+
+        // Fixed seed keeps the generated rows identical between runs.
+        Random rnd = new Random(0);
+        final int rowsCnt = 100;
+
+        for (int key = 0; key < rowsCnt; key++) {
+            double feature = rnd.nextDouble() - 0.5;
+            double lb = feature > 0 ? 1 : 0;
+
+            trainingSet.put(key, new double[] {feature, lb});
+        }
+
+        // Features are all elements of a row except the last one, the label is the last element.
+        DecisionTreeNode tree = new DecisionTreeRegressionTrainer(1, 0).fit(
+            new LocalDatasetBuilder<>(trainingSet, parts),
+            (k, row) -> Arrays.copyOf(row, row.length - 1),
+            (k, row) -> row[row.length - 1]
+        );
+
+        assertTrue(tree instanceof DecisionTreeConditionalNode);
+
+        DecisionTreeConditionalNode root = (DecisionTreeConditionalNode)tree;
+
+        assertEquals(0, root.getThreshold(), 1e-3);
+
+        assertTrue(root.getThenNode() instanceof DecisionTreeLeafNode);
+        assertTrue(root.getElseNode() instanceof DecisionTreeLeafNode);
+
+        DecisionTreeLeafNode thenLeaf = (DecisionTreeLeafNode)root.getThenNode();
+        DecisionTreeLeafNode elseLeaf = (DecisionTreeLeafNode)root.getElseNode();
+
+        assertEquals(1, thenLeaf.getVal(), 1e-10);
+        assertEquals(0, elseLeaf.getVal(), 1e-10);
+    }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeTestSuite.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeTestSuite.java
new file mode 100644
index 0000000..2cbb486
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeTestSuite.java
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.tree;
+
+import org.apache.ignite.ml.tree.data.DecisionTreeDataTest;
+import org.apache.ignite.ml.tree.impurity.gini.GiniImpurityMeasureCalculatorTest;
+import org.apache.ignite.ml.tree.impurity.gini.GiniImpurityMeasureTest;
+import org.apache.ignite.ml.tree.impurity.mse.MSEImpurityMeasureCalculatorTest;
+import org.apache.ignite.ml.tree.impurity.mse.MSEImpurityMeasureTest;
+import org.apache.ignite.ml.tree.impurity.util.SimpleStepFunctionCompressorTest;
+import org.apache.ignite.ml.tree.impurity.util.StepFunctionTest;
+import org.junit.runner.RunWith;
+import org.junit.runners.Suite;
+
+/**
+ * Test suite that aggregates the tests of the {@link org.apache.ignite.ml.tree} package and its subpackages:
+ * the classification and regression trainer tests (unit and integration), the decision tree data test,
+ * the Gini and MSE impurity measure tests and the step function utility tests.
+ * The class body is intentionally empty; the suite is defined entirely by the annotations.
+ */
+@RunWith(Suite.class)
+@Suite.SuiteClasses({
+ DecisionTreeClassificationTrainerTest.class,
+ DecisionTreeRegressionTrainerTest.class,
+ DecisionTreeClassificationTrainerIntegrationTest.class,
+ DecisionTreeRegressionTrainerIntegrationTest.class,
+ DecisionTreeDataTest.class,
+ GiniImpurityMeasureCalculatorTest.class,
+ GiniImpurityMeasureTest.class,
+ MSEImpurityMeasureCalculatorTest.class,
+ MSEImpurityMeasureTest.class,
+ StepFunctionTest.class,
+ SimpleStepFunctionCompressorTest.class
+})
+public class DecisionTreeTestSuite {
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/test/java/org/apache/ignite/ml/tree/data/DecisionTreeDataTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/data/DecisionTreeDataTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/data/DecisionTreeDataTest.java
new file mode 100644
index 0000000..0c89d4e
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/data/DecisionTreeDataTest.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.tree.data;
+
+import org.junit.Test;
+
+import static org.junit.Assert.assertArrayEquals;
+
+/**
+ * Tests for {@link DecisionTreeData}.
+ */
+public class DecisionTreeDataTest {
+    /** Checks that filtering keeps the rows accepted by the predicate together with their labels. */
+    @Test
+    public void testFilter() {
+        double[][] rows = {{0}, {1}, {2}, {3}, {4}, {5}};
+        double[] lbs = {0, 1, 2, 3, 4, 5};
+
+        DecisionTreeData filtered = new DecisionTreeData(rows, lbs).filter(row -> row[0] > 2);
+
+        double[][] expRows = {{3}, {4}, {5}};
+        double[] expLbs = {3, 4, 5};
+
+        assertArrayEquals(expRows, filtered.getFeatures());
+        assertArrayEquals(expLbs, filtered.getLabels(), 1e-10);
+    }
+
+    /**
+     * Checks sorting by the first and then by the second column.
+     * <p>
+     * The assertions are made against the arrays that were passed to the constructor, so the test expects
+     * {@code sort} to reorder those arrays in place, keeping every label aligned with its row.
+     */
+    @Test
+    public void testSort() {
+        double[][] rows = {{4, 1}, {3, 3}, {2, 0}, {1, 4}, {0, 2}};
+        double[] lbs = {0, 1, 2, 3, 4};
+
+        DecisionTreeData data = new DecisionTreeData(rows, lbs);
+
+        // Order by column 0.
+        data.sort(0);
+
+        double[][] rowsByFirstCol = {{0, 2}, {1, 4}, {2, 0}, {3, 3}, {4, 1}};
+        double[] lbsByFirstCol = {4, 3, 2, 1, 0};
+
+        assertArrayEquals(rowsByFirstCol, rows);
+        assertArrayEquals(lbsByFirstCol, lbs, 1e-10);
+
+        // Order by column 1.
+        data.sort(1);
+
+        double[][] rowsBySecondCol = {{2, 0}, {4, 1}, {0, 2}, {3, 3}, {1, 4}};
+        double[] lbsBySecondCol = {2, 0, 4, 1, 3};
+
+        assertArrayEquals(rowsBySecondCol, rows);
+        assertArrayEquals(lbsBySecondCol, lbs, 1e-10);
+    }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/gini/GiniImpurityMeasureCalculatorTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/gini/GiniImpurityMeasureCalculatorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/gini/GiniImpurityMeasureCalculatorTest.java
new file mode 100644
index 0000000..afd81e8
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/gini/GiniImpurityMeasureCalculatorTest.java
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.tree.impurity.gini;
+
+import java.util.HashMap;
+import java.util.Map;
+import org.apache.ignite.ml.tree.data.DecisionTreeData;
+import org.apache.ignite.ml.tree.impurity.util.StepFunction;
+import org.junit.Test;
+
+import static junit.framework.TestCase.assertEquals;
+import static org.junit.Assert.assertArrayEquals;
+
+/**
+ * Tests for {@link GiniImpurityMeasureCalculator}.
+ */
+public class GiniImpurityMeasureCalculatorTest {
+    /** Checks the step functions calculated for a two-column dataset without repeated feature values. */
+    @Test
+    public void testCalculate() {
+        double[][] features = {{0, 1}, {1, 0}, {2, 2}, {3, 3}};
+        double[] lbs = {0, 1, 1, 1};
+
+        GiniImpurityMeasureCalculator calc = new GiniImpurityMeasureCalculator(encoder(2));
+
+        StepFunction<GiniImpurityMeasure>[] res = calc.calculate(new DecisionTreeData(features, lbs));
+
+        assertEquals(2, res.length);
+
+        // Check Gini calculated for the first column.
+        assertArrayEquals(new double[] {Double.NEGATIVE_INFINITY, 0, 1, 2, 3}, res[0].getX(), 1e-10);
+        assertImpurities(res[0], -2.500, -4.000, -3.000, -2.666, -2.500);
+
+        // Check Gini calculated for the second column.
+        assertArrayEquals(new double[] {Double.NEGATIVE_INFINITY, 0, 1, 2, 3}, res[1].getX(), 1e-10);
+        assertImpurities(res[1], -2.500, -2.666, -3.000, -2.666, -2.500);
+    }
+
+    /** Checks the step function calculated for a single column that contains a repeated feature value. */
+    @Test
+    public void testCalculateWithRepeatedData() {
+        double[][] features = {{0}, {1}, {2}, {2}, {3}};
+        double[] lbs = {0, 1, 1, 1, 1};
+
+        GiniImpurityMeasureCalculator calc = new GiniImpurityMeasureCalculator(encoder(2));
+
+        StepFunction<GiniImpurityMeasure>[] res = calc.calculate(new DecisionTreeData(features, lbs));
+
+        assertEquals(1, res.length);
+
+        // Check Gini calculated for the first column.
+        assertArrayEquals(new double[] {Double.NEGATIVE_INFINITY, 0, 1, 2, 3}, res[0].getX(), 1e-10);
+        assertImpurities(res[0], -3.400, -5.000, -4.000, -3.500, -3.400);
+    }
+
+    /** Checks that every label is translated to the code it is mapped to in the encoder. */
+    @Test
+    public void testGetLabelCode() {
+        GiniImpurityMeasureCalculator calc = new GiniImpurityMeasureCalculator(encoder(3));
+
+        for (int code = 0; code < 3; code++)
+            assertEquals(code, calc.getLabelCode((double)code));
+    }
+
+    /**
+     * Builds an encoder that maps the labels {@code 0.0, 1.0, ...} to the codes {@code 0, 1, ...}.
+     *
+     * @param lbsCnt Number of labels to put into the encoder.
+     * @return Label encoder.
+     */
+    private static Map<Double, Integer> encoder(int lbsCnt) {
+        Map<Double, Integer> res = new HashMap<>();
+
+        for (int code = 0; code < lbsCnt; code++)
+            res.put((double)code, code);
+
+        return res;
+    }
+
+    /**
+     * Asserts, in order, that the impurity of every value of the step function matches the expected one
+     * with the tolerance {@code 1e-3}.
+     *
+     * @param fn Step function to check.
+     * @param exp Expected impurities.
+     */
+    private static void assertImpurities(StepFunction<GiniImpurityMeasure> fn, double... exp) {
+        for (int i = 0; i < exp.length; i++)
+            assertEquals(exp[i], fn.getY()[i].impurity(), 1e-3);
+    }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/gini/GiniImpurityMeasureTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/gini/GiniImpurityMeasureTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/gini/GiniImpurityMeasureTest.java
new file mode 100644
index 0000000..35c456a
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/gini/GiniImpurityMeasureTest.java
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.tree.impurity.gini;
+
+import java.util.Random;
+import org.junit.Test;
+
+import static junit.framework.TestCase.assertEquals;
+
+/**
+ * Tests for {@link GiniImpurityMeasure}.
+ */
+public class GiniImpurityMeasureTest {
+    /** Number of per-class counters in every randomly generated measure. */
+    private static final int CLASSES_CNT = 3;
+
+    /** Checks that a measure built from all-zero counters has zero impurity. */
+    @Test
+    public void testImpurityOnEmptyData() {
+        long[] left = new long[] {0, 0, 0};
+        long[] right = new long[] {0, 0, 0};
+
+        GiniImpurityMeasure impurity = new GiniImpurityMeasure(left, right);
+
+        assertEquals(0.0, impurity.impurity(), 1e-10);
+    }
+
+    /** Checks the impurity when only the left part contains objects, all of the same class. */
+    @Test
+    public void testImpurityLeftPart() {
+        long[] left = new long[] {3, 0, 0};
+        long[] right = new long[] {0, 0, 0};
+
+        GiniImpurityMeasure impurity = new GiniImpurityMeasure(left, right);
+
+        assertEquals(-3, impurity.impurity(), 1e-10);
+    }
+
+    /** Checks the impurity when only the right part contains objects, all of the same class. */
+    @Test
+    public void testImpurityRightPart() {
+        long[] left = new long[] {0, 0, 0};
+        long[] right = new long[] {3, 0, 0};
+
+        GiniImpurityMeasure impurity = new GiniImpurityMeasure(left, right);
+
+        assertEquals(-3, impurity.impurity(), 1e-10);
+    }
+
+    /** Checks the impurity when both parts contain objects, each part holding a single class. */
+    @Test
+    public void testImpurityLeftAndRightPart() {
+        long[] left = new long[] {3, 0, 0};
+        long[] right = new long[] {0, 3, 0};
+
+        GiniImpurityMeasure impurity = new GiniImpurityMeasure(left, right);
+
+        assertEquals(-6, impurity.impurity(), 1e-10);
+    }
+
+    /** Checks that {@code add} sums the left and the right counters element by element. */
+    @Test
+    public void testAdd() {
+        Random rnd = new Random(0);
+
+        GiniImpurityMeasure a = randMeasure(rnd);
+        GiniImpurityMeasure b = randMeasure(rnd);
+
+        GiniImpurityMeasure c = a.add(b);
+
+        for (int i = 0; i < CLASSES_CNT; i++)
+            assertEquals(a.getLeft()[i] + b.getLeft()[i], c.getLeft()[i]);
+
+        for (int i = 0; i < CLASSES_CNT; i++)
+            assertEquals(a.getRight()[i] + b.getRight()[i], c.getRight()[i]);
+    }
+
+    /** Checks that {@code subtract} subtracts the left and the right counters element by element. */
+    @Test
+    public void testSubtract() {
+        Random rnd = new Random(0);
+
+        GiniImpurityMeasure a = randMeasure(rnd);
+        GiniImpurityMeasure b = randMeasure(rnd);
+
+        GiniImpurityMeasure c = a.subtract(b);
+
+        for (int i = 0; i < CLASSES_CNT; i++)
+            assertEquals(a.getLeft()[i] - b.getLeft()[i], c.getLeft()[i]);
+
+        for (int i = 0; i < CLASSES_CNT; i++)
+            assertEquals(a.getRight()[i] - b.getRight()[i], c.getRight()[i]);
+    }
+
+    /**
+     * Generates a measure with random counters. The left counters are drawn before the right ones.
+     *
+     * @param rnd Source of randomness.
+     * @return Measure with {@link #CLASSES_CNT} random counters in each part.
+     */
+    private GiniImpurityMeasure randMeasure(Random rnd) {
+        return new GiniImpurityMeasure(randCnts(rnd), randCnts(rnd));
+    }
+
+    /**
+     * Generates an array of random counters.
+     *
+     * @param rnd Source of randomness.
+     * @return Array of {@link #CLASSES_CNT} non-negative counters.
+     */
+    private long[] randCnts(Random rnd) {
+        long[] cnts = new long[CLASSES_CNT];
+
+        for (int i = 0; i < cnts.length; i++)
+            cnts[i] = randCnt(rnd);
+
+        return cnts;
+    }
+
+    /**
+     * Generates random count.
+     *
+     * @param rnd Source of randomness.
+     * @return Non-negative count.
+     */
+    private long randCnt(Random rnd) {
+        // Widen before taking the absolute value: Math.abs(int) returns a negative result for Integer.MIN_VALUE.
+        return Math.abs((long)rnd.nextInt());
+    }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/mse/MSEImpurityMeasureCalculatorTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/mse/MSEImpurityMeasureCalculatorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/mse/MSEImpurityMeasureCalculatorTest.java
new file mode 100644
index 0000000..510c18f
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/impurity/mse/MSEImpurityMeasureCalculatorTest.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.tree.impurity.mse;
+
+import org.apache.ignite.ml.tree.data.DecisionTreeData;
+import org.apache.ignite.ml.tree.impurity.util.StepFunction;
+import org.junit.Test;
+
+import static junit.framework.TestCase.assertEquals;
+import static org.junit.Assert.assertArrayEquals;
+
+/**
+ * Tests for {@link MSEImpurityMeasureCalculator}.
+ */
+public class MSEImpurityMeasureCalculatorTest {
+    /** Checks the step functions calculated for a two-column dataset. */
+    @Test
+    public void testCalculate() {
+        double[][] features = {{0, 2}, {1, 1}, {2, 0}, {3, 3}};
+        double[] lbs = {1, 2, 2, 1};
+
+        StepFunction<MSEImpurityMeasure>[] res =
+            new MSEImpurityMeasureCalculator().calculate(new DecisionTreeData(features, lbs));
+
+        assertEquals(2, res.length);
+
+        // Test MSE calculated for the first column.
+        assertArrayEquals(new double[] {Double.NEGATIVE_INFINITY, 0, 1, 2, 3}, res[0].getX(), 1e-10);
+        assertImpurities(res[0], 1.000, 0.666, 1.000, 0.666, 1.000);
+
+        // Test MSE calculated for the second column.
+        assertArrayEquals(new double[] {Double.NEGATIVE_INFINITY, 0, 1, 2, 3}, res[1].getX(), 1e-10);
+        assertImpurities(res[1], 1.000, 0.666, 0.000, 0.666, 1.000);
+    }
+
+    /**
+     * Asserts, in order, that the impurity of every value of the step function matches the expected one
+     * with the tolerance {@code 1e-3}.
+     *
+     * @param fn Step function to check.
+     * @param exp Expected impurities.
+     */
+    private static void assertImpurities(StepFunction<MSEImpurityMeasure> fn, double... exp) {
+        for (int i = 0; i < exp.length; i++)
+            assertEquals(exp[i], fn.getY()[i].impurity(), 1e-3);
+    }
+}