You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@ignite.apache.org by ag...@apache.org on 2018/04/13 09:33:30 UTC
[18/54] [abbrv] ignite git commit: IGNITE-8059: Integrate decision
tree with partition based dataset.
IGNITE-8059: Integrate decision tree with partition based dataset.
this closes #3760
Project: http://git-wip-us.apache.org/repos/asf/ignite/repo
Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/139c2af6
Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/139c2af6
Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/139c2af6
Branch: refs/heads/ignite-6083
Commit: 139c2af66a9f745f89429842810f5d5fe1addf28
Parents: a64b941
Author: dmitrievanthony <dm...@gmail.com>
Authored: Tue Apr 10 12:46:43 2018 +0300
Committer: YuriBabak <y....@gmail.com>
Committed: Tue Apr 10 12:46:44 2018 +0300
----------------------------------------------------------------------
...ecisionTreeClassificationTrainerExample.java | 147 +++++
.../DecisionTreeRegressionTrainerExample.java | 124 ++++
.../ignite/examples/ml/tree/package-info.java | 22 +
.../examples/ml/trees/DecisionTreesExample.java | 354 ------------
.../ignite/examples/ml/trees/package-info.java | 22 -
.../main/java/org/apache/ignite/ml/Trainer.java | 3 -
.../org/apache/ignite/ml/tree/DecisionTree.java | 252 ++++++++
.../tree/DecisionTreeClassificationTrainer.java | 93 +++
.../ml/tree/DecisionTreeConditionalNode.java | 78 +++
.../ignite/ml/tree/DecisionTreeLeafNode.java | 48 ++
.../apache/ignite/ml/tree/DecisionTreeNode.java | 26 +
.../ml/tree/DecisionTreeRegressionTrainer.java | 60 ++
.../org/apache/ignite/ml/tree/TreeFilter.java | 38 ++
.../ignite/ml/tree/data/DecisionTreeData.java | 128 +++++
.../ml/tree/data/DecisionTreeDataBuilder.java | 73 +++
.../ignite/ml/tree/data/package-info.java | 22 +
.../ml/tree/impurity/ImpurityMeasure.java | 55 ++
.../impurity/ImpurityMeasureCalculator.java | 38 ++
.../tree/impurity/gini/GiniImpurityMeasure.java | 115 ++++
.../gini/GiniImpurityMeasureCalculator.java | 110 ++++
.../ml/tree/impurity/gini/package-info.java | 22 +
.../tree/impurity/mse/MSEImpurityMeasure.java | 133 +++++
.../mse/MSEImpurityMeasureCalculator.java | 80 +++
.../ml/tree/impurity/mse/package-info.java | 22 +
.../ignite/ml/tree/impurity/package-info.java | 22 +
.../util/SimpleStepFunctionCompressor.java | 149 +++++
.../ml/tree/impurity/util/StepFunction.java | 162 ++++++
.../impurity/util/StepFunctionCompressor.java | 55 ++
.../ml/tree/impurity/util/package-info.java | 22 +
.../ml/tree/leaf/DecisionTreeLeafBuilder.java | 38 ++
.../tree/leaf/MeanDecisionTreeLeafBuilder.java | 73 +++
.../leaf/MostCommonDecisionTreeLeafBuilder.java | 86 +++
.../ignite/ml/tree/leaf/package-info.java | 22 +
.../org/apache/ignite/ml/tree/package-info.java | 22 +
.../ignite/ml/trees/CategoricalRegionInfo.java | 72 ---
.../ignite/ml/trees/CategoricalSplitInfo.java | 68 ---
.../ignite/ml/trees/ContinuousRegionInfo.java | 74 ---
.../ml/trees/ContinuousSplitCalculator.java | 51 --
.../org/apache/ignite/ml/trees/RegionInfo.java | 62 --
.../ml/trees/models/DecisionTreeModel.java | 44 --
.../ignite/ml/trees/models/package-info.java | 22 -
.../ml/trees/nodes/CategoricalSplitNode.java | 50 --
.../ml/trees/nodes/ContinuousSplitNode.java | 56 --
.../ignite/ml/trees/nodes/DecisionTreeNode.java | 33 --
.../org/apache/ignite/ml/trees/nodes/Leaf.java | 49 --
.../apache/ignite/ml/trees/nodes/SplitNode.java | 100 ----
.../ignite/ml/trees/nodes/package-info.java | 22 -
.../apache/ignite/ml/trees/package-info.java | 22 -
.../ml/trees/trainers/columnbased/BiIndex.java | 113 ----
...exedCacheColumnDecisionTreeTrainerInput.java | 57 --
.../CacheColumnDecisionTreeTrainerInput.java | 141 -----
.../columnbased/ColumnDecisionTreeTrainer.java | 568 -------------------
.../ColumnDecisionTreeTrainerInput.java | 55 --
.../MatrixColumnDecisionTreeTrainerInput.java | 83 ---
.../trainers/columnbased/RegionProjection.java | 109 ----
.../trainers/columnbased/TrainingContext.java | 166 ------
.../columnbased/caches/ContextCache.java | 68 ---
.../columnbased/caches/FeaturesCache.java | 151 -----
.../columnbased/caches/ProjectionsCache.java | 286 ----------
.../trainers/columnbased/caches/SplitCache.java | 206 -------
.../columnbased/caches/package-info.java | 22 -
.../ContinuousSplitCalculators.java | 34 --
.../contsplitcalcs/GiniSplitCalculator.java | 234 --------
.../contsplitcalcs/VarianceSplitCalculator.java | 179 ------
.../contsplitcalcs/package-info.java | 22 -
.../trainers/columnbased/package-info.java | 22 -
.../columnbased/regcalcs/RegionCalculators.java | 85 ---
.../columnbased/regcalcs/package-info.java | 22 -
.../vectors/CategoricalFeatureProcessor.java | 212 -------
.../vectors/ContinuousFeatureProcessor.java | 111 ----
.../vectors/ContinuousSplitInfo.java | 71 ---
.../columnbased/vectors/FeatureProcessor.java | 82 ---
.../vectors/FeatureVectorProcessorUtils.java | 57 --
.../columnbased/vectors/SampleInfo.java | 80 ---
.../trainers/columnbased/vectors/SplitInfo.java | 106 ----
.../columnbased/vectors/package-info.java | 22 -
.../org/apache/ignite/ml/IgniteMLTestSuite.java | 4 +-
.../ml/nn/performance/MnistMLPTestUtil.java | 9 +-
...reeClassificationTrainerIntegrationTest.java | 100 ++++
.../DecisionTreeClassificationTrainerTest.java | 91 +++
...ionTreeRegressionTrainerIntegrationTest.java | 100 ++++
.../tree/DecisionTreeRegressionTrainerTest.java | 91 +++
.../ignite/ml/tree/DecisionTreeTestSuite.java | 48 ++
.../ml/tree/data/DecisionTreeDataTest.java | 59 ++
.../gini/GiniImpurityMeasureCalculatorTest.java | 103 ++++
.../impurity/gini/GiniImpurityMeasureTest.java | 131 +++++
.../mse/MSEImpurityMeasureCalculatorTest.java | 59 ++
.../impurity/mse/MSEImpurityMeasureTest.java | 109 ++++
.../util/SimpleStepFunctionCompressorTest.java | 75 +++
.../ml/tree/impurity/util/StepFunctionTest.java | 71 +++
.../tree/impurity/util/TestImpurityMeasure.java | 88 +++
.../DecisionTreeMNISTIntegrationTest.java | 105 ++++
.../tree/performance/DecisionTreeMNISTTest.java | 74 +++
.../ignite/ml/trees/BaseDecisionTreeTest.java | 70 ---
.../ml/trees/ColumnDecisionTreeTrainerTest.java | 191 -------
.../ignite/ml/trees/DecisionTreesTestSuite.java | 33 --
.../ml/trees/GiniSplitCalculatorTest.java | 141 -----
.../ignite/ml/trees/SplitDataGenerator.java | 390 -------------
.../ml/trees/VarianceSplitCalculatorTest.java | 84 ---
.../ColumnDecisionTreeTrainerBenchmark.java | 456 ---------------
.../IgniteColumnDecisionTreeGiniBenchmark.java | 70 ---
...niteColumnDecisionTreeVarianceBenchmark.java | 71 ---
.../yardstick/ml/trees/SplitDataGenerator.java | 426 --------------
.../ignite/yardstick/ml/trees/package-info.java | 22 -
104 files changed, 3647 insertions(+), 6429 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeClassificationTrainerExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeClassificationTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeClassificationTrainerExample.java
new file mode 100644
index 0000000..cef6368
--- /dev/null
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeClassificationTrainerExample.java
@@ -0,0 +1,147 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.examples.ml.tree;
+
+import java.util.Random;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.Ignition;
+import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
+import org.apache.ignite.configuration.CacheConfiguration;
+import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
+import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
+import org.apache.ignite.ml.tree.DecisionTreeNode;
+import org.apache.ignite.thread.IgniteThread;
+
+/**
+ * Example of using distributed {@link DecisionTreeClassificationTrainer}.
+ */
+public class DecisionTreeClassificationTrainerExample {
+    /** Number of labeled points generated for the training set and, separately, for the evaluation set. */
+    private static final int SAMPLES_CNT = 1000;
+
+    /**
+     * Executes example.
+     *
+     * @param args Command line arguments, none required.
+     * @throws InterruptedException If interrupted while waiting for the example thread to finish.
+     */
+    public static void main(String... args) throws InterruptedException {
+        System.out.println(">>> Decision tree classification trainer example started.");
+
+        // Start ignite grid.
+        try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
+            System.out.println(">>> Ignite grid started.");
+
+            // The training logic runs inside an IgniteThread created for the started node's instance name.
+            IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
+                DecisionTreeClassificationTrainerExample.class.getSimpleName(), () -> {
+
+                // Create cache with training data.
+                CacheConfiguration<Integer, LabeledPoint> trainingSetCfg = new CacheConfiguration<>();
+                trainingSetCfg.setName("TRAINING_SET");
+                trainingSetCfg.setAffinity(new RendezvousAffinityFunction(false, 10));
+
+                IgniteCache<Integer, LabeledPoint> trainingSet = ignite.createCache(trainingSetCfg);
+
+                try {
+                    Random rnd = new Random(0);
+
+                    // Fill training data.
+                    for (int i = 0; i < SAMPLES_CNT; i++)
+                        trainingSet.put(i, generatePoint(rnd));
+
+                    // Create classification trainer.
+                    DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(4, 0);
+
+                    // Train decision tree model.
+                    DecisionTreeNode mdl = trainer.fit(
+                        new CacheBasedDatasetBuilder<>(ignite, trainingSet),
+                        (k, v) -> new double[]{v.x, v.y},
+                        (k, v) -> v.lb
+                    );
+
+                    // Evaluate on freshly generated points: the same rnd continues past the training samples.
+                    int correctPredictions = 0;
+                    for (int i = 0; i < SAMPLES_CNT; i++) {
+                        LabeledPoint pnt = generatePoint(rnd);
+
+                        double prediction = mdl.apply(new double[]{pnt.x, pnt.y});
+
+                        if (prediction == pnt.lb)
+                            correctPredictions++;
+                    }
+
+                    // Percentage derived from the sample count instead of a hard-coded divisor.
+                    System.out.println(">>> Accuracy: " + correctPredictions * 100.0 / SAMPLES_CNT + "%");
+
+                    System.out.println(">>> Decision tree classification trainer example completed.");
+                }
+                finally {
+                    // createCache() fails if the cache already exists, so remove it to keep the example re-runnable.
+                    trainingSet.destroy();
+                }
+            });
+
+            igniteThread.start();
+
+            igniteThread.join();
+        }
+    }
+
+    /**
+     * Generates a point with {@code x} in [-0.5, 0.5) and {@code y} in the same interval. If {@code x * y > 0} then
+     * the label is 1, otherwise 0.
+     *
+     * @param rnd Random.
+     * @return Point with label.
+     */
+    private static LabeledPoint generatePoint(Random rnd) {
+        // nextDouble() is uniform in [0, 1); shifting by 0.5 gives [-0.5, 0.5).
+        double x = rnd.nextDouble() - 0.5;
+        double y = rnd.nextDouble() - 0.5;
+
+        return new LabeledPoint(x, y, x * y > 0 ? 1 : 0);
+    }
+
+    /** Point data class. */
+    private static class Point {
+        /** X coordinate. */
+        final double x;
+
+        /** Y coordinate. */
+        final double y;
+
+        /**
+         * Constructs a new instance of point.
+         *
+         * @param x X coordinate.
+         * @param y Y coordinate.
+         */
+        Point(double x, double y) {
+            this.x = x;
+            this.y = y;
+        }
+    }
+
+    /** Labeled point data class. */
+    private static class LabeledPoint extends Point {
+        /** Point label. */
+        final double lb;
+
+        /**
+         * Constructs a new instance of labeled point data.
+         *
+         * @param x X coordinate.
+         * @param y Y coordinate.
+         * @param lb Point label.
+         */
+        LabeledPoint(double x, double y, double lb) {
+            super(x, y);
+            this.lb = lb;
+        }
+    }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeRegressionTrainerExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeRegressionTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeRegressionTrainerExample.java
new file mode 100644
index 0000000..61ba5f9
--- /dev/null
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeRegressionTrainerExample.java
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.examples.ml.tree;
+
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.Ignition;
+import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
+import org.apache.ignite.configuration.CacheConfiguration;
+import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
+import org.apache.ignite.ml.tree.DecisionTreeNode;
+import org.apache.ignite.ml.tree.DecisionTreeRegressionTrainer;
+import org.apache.ignite.thread.IgniteThread;
+
+/**
+ * Example of using distributed {@link DecisionTreeRegressionTrainer}.
+ */
+public class DecisionTreeRegressionTrainerExample {
+    /**
+     * Executes example.
+     *
+     * @param args Command line arguments, none required.
+     * @throws InterruptedException If interrupted while waiting for the example thread to finish.
+     */
+    public static void main(String... args) throws InterruptedException {
+        System.out.println(">>> Decision tree regression trainer example started.");
+
+        // Start ignite grid.
+        try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
+            System.out.println(">>> Ignite grid started.");
+
+            // The training logic runs inside an IgniteThread created for the started node's instance name.
+            IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
+                DecisionTreeRegressionTrainerExample.class.getSimpleName(), () -> {
+
+                // Create cache with training data.
+                CacheConfiguration<Integer, Point> trainingSetCfg = new CacheConfiguration<>();
+                trainingSetCfg.setName("TRAINING_SET");
+                trainingSetCfg.setAffinity(new RendezvousAffinityFunction(false, 10));
+
+                IgniteCache<Integer, Point> trainingSet = ignite.createCache(trainingSetCfg);
+
+                try {
+                    // Fill training data.
+                    generatePoints(trainingSet);
+
+                    // Create regression trainer.
+                    DecisionTreeRegressionTrainer trainer = new DecisionTreeRegressionTrainer(10, 0);
+
+                    // Train decision tree model.
+                    DecisionTreeNode mdl = trainer.fit(
+                        new CacheBasedDatasetBuilder<>(ignite, trainingSet),
+                        (k, v) -> new double[] {v.x},
+                        (k, v) -> v.y
+                    );
+
+                    // The trained model is a decision tree, not a linear regression.
+                    System.out.println(">>> Decision tree regression model: " + mdl);
+
+                    System.out.println(">>> ---------------------------------");
+                    System.out.println(">>> | Prediction\t| Ground Truth\t|");
+                    System.out.println(">>> ---------------------------------");
+
+                    // Print predictions next to sin(x) at the integer points of the training interval [0, 10).
+                    for (int x = 0; x < 10; x++) {
+                        double predicted = mdl.apply(new double[] {x});
+
+                        System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", predicted, Math.sin(x));
+                    }
+
+                    System.out.println(">>> ---------------------------------");
+
+                    System.out.println(">>> Decision tree regression trainer example completed.");
+                }
+                finally {
+                    // createCache() fails if the cache already exists, so remove it to keep the example re-runnable.
+                    trainingSet.destroy();
+                }
+            });
+
+            igniteThread.start();
+
+            igniteThread.join();
+        }
+    }
+
+    /**
+     * Generates 1000 samples of {@code sin(x)} with {@code x = i / 100.0}, i.e. on the interval [0, 10), and loads
+     * them into the specified cache.
+     *
+     * @param trainingSet Cache to load the points into; keys are the sample indices 0..999.
+     */
+    private static void generatePoints(IgniteCache<Integer, Point> trainingSet) {
+        for (int i = 0; i < 1000; i++) {
+            double x = i / 100.0;
+            double y = Math.sin(x);
+
+            trainingSet.put(i, new Point(x, y));
+        }
+    }
+
+    /** Point data class. */
+    private static class Point {
+        /** X coordinate. */
+        final double x;
+
+        /** Y coordinate. */
+        final double y;
+
+        /**
+         * Constructs a new instance of point.
+         *
+         * @param x X coordinate.
+         * @param y Y coordinate.
+         */
+        Point(double x, double y) {
+            this.x = x;
+            this.y = y;
+        }
+    }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/examples/src/main/java/org/apache/ignite/examples/ml/tree/package-info.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tree/package-info.java b/examples/src/main/java/org/apache/ignite/examples/ml/tree/package-info.java
new file mode 100644
index 0000000..d8d9de6
--- /dev/null
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tree/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * Decision trees examples.
+ */
+package org.apache.ignite.examples.ml.tree;
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/examples/src/main/java/org/apache/ignite/examples/ml/trees/DecisionTreesExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/trees/DecisionTreesExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/trees/DecisionTreesExample.java
deleted file mode 100644
index b1b2c42..0000000
--- a/examples/src/main/java/org/apache/ignite/examples/ml/trees/DecisionTreesExample.java
+++ /dev/null
@@ -1,354 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.examples.ml.trees;
-
-import java.io.FileInputStream;
-import java.io.FileOutputStream;
-import java.io.IOException;
-import java.net.URL;
-import java.nio.channels.Channels;
-import java.nio.channels.ReadableByteChannel;
-import java.util.Collection;
-import java.util.HashMap;
-import java.util.Iterator;
-import java.util.List;
-import java.util.Map;
-import java.util.Objects;
-import java.util.Random;
-import java.util.Scanner;
-import java.util.function.Function;
-import java.util.stream.Collectors;
-import java.util.stream.Stream;
-import java.util.zip.GZIPInputStream;
-import org.apache.commons.cli.BasicParser;
-import org.apache.commons.cli.CommandLine;
-import org.apache.commons.cli.CommandLineParser;
-import org.apache.commons.cli.Option;
-import org.apache.commons.cli.OptionBuilder;
-import org.apache.commons.cli.Options;
-import org.apache.commons.cli.ParseException;
-import org.apache.ignite.Ignite;
-import org.apache.ignite.IgniteCache;
-import org.apache.ignite.IgniteDataStreamer;
-import org.apache.ignite.Ignition;
-import org.apache.ignite.cache.CacheWriteSynchronizationMode;
-import org.apache.ignite.configuration.CacheConfiguration;
-import org.apache.ignite.examples.ExampleNodeStartup;
-import org.apache.ignite.examples.ml.MLExamplesCommonArgs;
-import org.apache.ignite.internal.util.IgniteUtils;
-import org.apache.ignite.lang.IgniteBiTuple;
-import org.apache.ignite.ml.Model;
-import org.apache.ignite.ml.estimators.Estimators;
-import org.apache.ignite.ml.math.Vector;
-import org.apache.ignite.ml.math.functions.IgniteTriFunction;
-import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
-import org.apache.ignite.ml.trees.models.DecisionTreeModel;
-import org.apache.ignite.ml.trees.trainers.columnbased.BiIndex;
-import org.apache.ignite.ml.trees.trainers.columnbased.BiIndexedCacheColumnDecisionTreeTrainerInput;
-import org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer;
-import org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs.ContinuousSplitCalculators;
-import org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs.GiniSplitCalculator;
-import org.apache.ignite.ml.trees.trainers.columnbased.regcalcs.RegionCalculators;
-import org.apache.ignite.ml.util.MnistUtils;
-import org.jetbrains.annotations.NotNull;
-
-/**
- * <p>
- * Example of usage of decision trees algorithm for MNIST dataset
- * (it can be found here: http://yann.lecun.com/exdb/mnist/). </p>
- * <p>
- * Remote nodes should always be started with special configuration file which
- * enables P2P class loading: {@code 'ignite.{sh|bat} examples/config/example-ignite.xml'}.</p>
- * <p>
- * Alternatively you can run {@link ExampleNodeStartup} in another JVM which will start node
- * with {@code examples/config/example-ignite.xml} configuration.</p>
- * <p>
- * It is recommended to start at least one node prior to launching this example if you intend
- * to run it with default memory settings.</p>
- * <p>
- * This example should be run with program arguments, for example
- * -cfg examples/config/example-ignite.xml.</p>
- * <p>
- * -cfg specifies path to a config path.</p>
- */
-public class DecisionTreesExample {
- /** Name of parameter specifying path of Ignite config. */
- private static final String CONFIG = "cfg";
-
- /** Default config path. */
- private static final String DEFAULT_CONFIG = "examples/config/example-ignite.xml";
-
- /**
- * Folder in which MNIST dataset is expected.
- */
- private static String MNIST_DIR = "examples/src/main/resources/";
-
- /**
- * Key for MNIST training images.
- */
- private static String MNIST_TRAIN_IMAGES = "train_images";
-
- /**
- * Key for MNIST training labels.
- */
- private static String MNIST_TRAIN_LABELS = "train_labels";
-
- /**
- * Key for MNIST test images.
- */
- private static String MNIST_TEST_IMAGES = "test_images";
-
- /**
- * Key for MNIST test labels.
- */
- private static String MNIST_TEST_LABELS = "test_labels";
-
- /**
- * Launches example.
- *
- * @param args Program arguments.
- * @throws IOException If checking for, downloading or unpacking the MNIST files in getMNIST fails.
- */
- public static void main(String[] args) throws IOException {
- System.out.println(">>> Decision trees example started.");
-
- String igniteCfgPath;
-
- CommandLineParser parser = new BasicParser();
-
- String trainingImagesPath;
- String trainingLabelsPath;
-
- String testImagesPath;
- String testLabelsPath;
-
- // Maps the MNIST_* keys to the expected file names inside MNIST_DIR.
- Map<String, String> mnistPaths = new HashMap<>();
-
- mnistPaths.put(MNIST_TRAIN_IMAGES, "train-images-idx3-ubyte");
- mnistPaths.put(MNIST_TRAIN_LABELS, "train-labels-idx1-ubyte");
- mnistPaths.put(MNIST_TEST_IMAGES, "t10k-images-idx3-ubyte");
- mnistPaths.put(MNIST_TEST_LABELS, "t10k-labels-idx1-ubyte");
-
- try {
- // Parse the command line arguments.
- CommandLine line = parser.parse(buildOptions(), args);
-
- if (line.hasOption(MLExamplesCommonArgs.UNATTENDED)) {
- System.out.println(">>> Skipped example execution because 'unattended' mode is used.");
- System.out.println(">>> Decision trees example finished.");
- return;
- }
-
- igniteCfgPath = line.getOptionValue(CONFIG, DEFAULT_CONFIG);
- }
- catch (ParseException e) {
- e.printStackTrace();
- return;
- }
-
- if (!getMNIST(mnistPaths.values())) {
- System.out.println(">>> You should have MNIST dataset in " + MNIST_DIR + " to run this example.");
- return;
- }
-
- trainingImagesPath = Objects.requireNonNull(IgniteUtils.resolveIgnitePath(MNIST_DIR + "/" +
- mnistPaths.get(MNIST_TRAIN_IMAGES))).getPath();
- trainingLabelsPath = Objects.requireNonNull(IgniteUtils.resolveIgnitePath(MNIST_DIR + "/" +
- mnistPaths.get(MNIST_TRAIN_LABELS))).getPath();
- testImagesPath = Objects.requireNonNull(IgniteUtils.resolveIgnitePath(MNIST_DIR + "/" +
- mnistPaths.get(MNIST_TEST_IMAGES))).getPath();
- testLabelsPath = Objects.requireNonNull(IgniteUtils.resolveIgnitePath(MNIST_DIR + "/" +
- mnistPaths.get(MNIST_TEST_LABELS))).getPath();
-
- try (Ignite ignite = Ignition.start(igniteCfgPath)) {
- IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
-
- int ptsCnt = 60000;
- int featCnt = 28 * 28;
-
- Stream<DenseLocalOnHeapVector> trainingMnistStream = MnistUtils.mnistAsStream(trainingImagesPath, trainingLabelsPath,
- new Random(123L), ptsCnt);
-
- Stream<DenseLocalOnHeapVector> testMnistStream = MnistUtils.mnistAsStream(testImagesPath, testLabelsPath,
- new Random(123L), 10_000);
-
- IgniteCache<BiIndex, Double> cache = createBiIndexedCache(ignite);
-
- // Each vector is loaded as featCnt features plus one trailing value (read back below via getX(featCnt)).
- loadVectorsIntoBiIndexedCache(cache.getName(), trainingMnistStream.iterator(), featCnt + 1, ignite);
-
- ColumnDecisionTreeTrainer<GiniSplitCalculator.GiniData> trainer = new ColumnDecisionTreeTrainer<>(10,
- ContinuousSplitCalculators.GINI.apply(ignite),
- RegionCalculators.GINI,
- RegionCalculators.MOST_COMMON,
- ignite);
-
- System.out.println(">>> Training started");
- long before = System.currentTimeMillis();
- DecisionTreeModel mdl = trainer.train(new BiIndexedCacheColumnDecisionTreeTrainerInput(cache, new HashMap<>(), ptsCnt, featCnt));
- System.out.println(">>> Training finished in " + (System.currentTimeMillis() - before));
-
- // NOTE(review): despite the names 'mse' and 'accuracy', this evaluates Estimators.errorsPercentage(),
- // and the result is printed below as "Errs percentage".
- IgniteTriFunction<Model<Vector, Double>, Stream<IgniteBiTuple<Vector, Double>>, Function<Double, Double>, Double> mse =
- Estimators.errorsPercentage();
-
- Double accuracy = mse.apply(mdl, testMnistStream.map(v ->
- new IgniteBiTuple<>(v.viewPart(0, featCnt), v.getX(featCnt))), Function.identity());
-
- System.out.println(">>> Errs percentage: " + accuracy);
- }
- catch (IOException e) {
- e.printStackTrace();
- }
-
- System.out.println(">>> Decision trees example finished.");
- }
-
- /**
- * Get MNIST dataset. Value of predicate 'MNIST dataset is present in expected folder' is returned.
- *
- * @param mnistFileNames File names of MNIST dataset.
- * @return Value of predicate 'MNIST dataset is present in expected folder'.
- * @throws IOException In case of file system errors.
- */
- private static boolean getMNIST(Collection<String> mnistFileNames) throws IOException {
- List<String> missing = mnistFileNames.stream().
- filter(f -> IgniteUtils.resolveIgnitePath(MNIST_DIR + "/" + f) == null).
- collect(Collectors.toList());
-
- if (!missing.isEmpty()) {
- System.out.println(">>> You have not fully downloaded MNIST dataset in directory " + MNIST_DIR +
- ", do you want it to be downloaded? [y]/n");
- Scanner s = new Scanner(System.in);
- String str = s.nextLine();
-
- // An empty answer is treated as "yes" (the default shown in the prompt).
- if (!str.isEmpty() && !str.toLowerCase().equals("y"))
- return false;
- }
-
- for (String s : missing) {
- String f = s + ".gz";
- System.out.println(">>> Downloading " + f + "...");
- // NOTE(review): this path segment ("mnistAsStream") differs from the dataset location cited in the class
- // Javadoc (http://yann.lecun.com/exdb/mnist/); it looks like a rename artifact and downloads would likely fail.
- URL website = new URL("http://yann.lecun.com/exdb/mnistAsStream/" + f);
- // NOTE(review): neither this channel nor the FileOutputStream below is ever closed, so both handles leak.
- ReadableByteChannel rbc = Channels.newChannel(website.openStream());
- FileOutputStream fos = new FileOutputStream(MNIST_DIR + "/" + f);
- fos.getChannel().transferFrom(rbc, 0, Long.MAX_VALUE);
- System.out.println(">>> Done.");
-
- System.out.println(">>> Unzipping " + f + "...");
- unzip(MNIST_DIR + "/" + f, MNIST_DIR + "/" + s);
-
- System.out.println(">>> Deleting gzip " + f + ", status: " +
- Objects.requireNonNull(IgniteUtils.resolveIgnitePath(MNIST_DIR + "/" + f)).delete());
-
- System.out.println(">>> Done.");
- }
-
- return true;
- }
-
- /**
- * Unzip file located in {@code input} to {@code output}.
- *
- * @param input Input file path.
- * @param output Output file path.
- * @throws IOException In case of file system errors.
- */
- private static void unzip(String input, String output) throws IOException {
- byte[] buf = new byte[1024];
-
- try (GZIPInputStream gis = new GZIPInputStream(new FileInputStream(input));
- FileOutputStream out = new FileOutputStream(output)) {
- int sz;
- while ((sz = gis.read(buf)) > 0)
- out.write(buf, 0, sz);
- }
- }
-
- /**
- * Build cli options.
- *
- * @return Options containing the optional {@code cfg} (config path) and unattended-mode flags.
- */
- @NotNull private static Options buildOptions() {
- Options options = new Options();
-
- Option cfgOpt = OptionBuilder
- .withArgName(CONFIG)
- .withLongOpt(CONFIG)
- .hasArg()
- .withDescription("Path to the config.")
- .isRequired(false).create();
-
- Option unattended = OptionBuilder
- .withArgName(MLExamplesCommonArgs.UNATTENDED)
- .withLongOpt(MLExamplesCommonArgs.UNATTENDED)
- .withDescription("Is example run unattended.")
- .isRequired(false).create();
-
- options.addOption(cfgOpt);
- options.addOption(unattended);
-
- return options;
- }
-
- /**
- * Creates cache where data for training is stored.
- *
- * @param ignite Ignite instance.
- * @return cache where data for training is stored.
- */
- private static IgniteCache<BiIndex, Double> createBiIndexedCache(Ignite ignite) {
- CacheConfiguration<BiIndex, Double> cfg = new CacheConfiguration<>();
-
- // Write to primary.
- cfg.setWriteSynchronizationMode(CacheWriteSynchronizationMode.PRIMARY_SYNC);
-
- // No copying of values.
- cfg.setCopyOnRead(false);
-
- cfg.setName("TMP_BI_INDEXED_CACHE");
-
- return ignite.getOrCreateCache(cfg);
- }
-
- /**
- * Loads vectors into cache.
- *
- * @param cacheName Name of cache.
- * @param vectorsIter Iterator over vectors to load.
- * @param vectorSize Size of vector.
- * @param ignite Ignite instance.
- */
- private static void loadVectorsIntoBiIndexedCache(String cacheName, Iterator<? extends Vector> vectorsIter,
- int vectorSize, Ignite ignite) {
- try (IgniteDataStreamer<BiIndex, Double> streamer =
- ignite.dataStreamer(cacheName)) {
- int sampleIdx = 0;
-
- streamer.perNodeBufferSize(10000);
-
- while (vectorsIter.hasNext()) {
- org.apache.ignite.ml.math.Vector next = vectorsIter.next();
-
- // One cache entry per (sample index, component index) pair.
- for (int i = 0; i < vectorSize; i++)
- streamer.addData(new BiIndex(sampleIdx, i), next.getX(i));
-
- sampleIdx++;
-
- if (sampleIdx % 1000 == 0)
- System.out.println(">>> Loaded " + sampleIdx + " vectors.");
- }
- }
- }
-}
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/examples/src/main/java/org/apache/ignite/examples/ml/trees/package-info.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/trees/package-info.java b/examples/src/main/java/org/apache/ignite/examples/ml/trees/package-info.java
deleted file mode 100644
index d944f60..0000000
--- a/examples/src/main/java/org/apache/ignite/examples/ml/trees/package-info.java
+++ /dev/null
@@ -1,22 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-/**
- * <!-- Package description. -->
- * Decision trees examples.
- */
-package org.apache.ignite.examples.ml.trees;
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/Trainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/Trainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/Trainer.java
index 4e0a570..f53b801 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/Trainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/Trainer.java
@@ -17,11 +17,8 @@
package org.apache.ignite.ml;
-import org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer;
-
/**
* Interface for Trainers. Trainer is just a function which produces model from the data.
- * See for example {@link ColumnDecisionTreeTrainer}.
*
* @param <M> Type of produced model.
* @param <T> Type of data needed for model producing.
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java
new file mode 100644
index 0000000..c0b88fc
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java
@@ -0,0 +1,252 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.tree;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import org.apache.ignite.ml.dataset.Dataset;
+import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder;
+import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.trainers.DatasetTrainer;
+import org.apache.ignite.ml.tree.data.DecisionTreeData;
+import org.apache.ignite.ml.tree.data.DecisionTreeDataBuilder;
+import org.apache.ignite.ml.tree.impurity.ImpurityMeasure;
+import org.apache.ignite.ml.tree.impurity.ImpurityMeasureCalculator;
+import org.apache.ignite.ml.tree.impurity.util.StepFunction;
+import org.apache.ignite.ml.tree.impurity.util.StepFunctionCompressor;
+import org.apache.ignite.ml.tree.leaf.DecisionTreeLeafBuilder;
+
+/**
+ * Distributed decision tree trainer that allows to fit trees using row-partitioned dataset.
+ *
+ * @param <T> Type of impurity measure.
+ */
+abstract class DecisionTree<T extends ImpurityMeasure<T>> implements DatasetTrainer<DecisionTreeNode, Double> {
+ /** Max tree deep. */
+ private final int maxDeep;
+
+ /** Min impurity decrease. */
+ private final double minImpurityDecrease;
+
+ /** Step function compressor. */
+ private final StepFunctionCompressor<T> compressor;
+
+ /** Decision tree leaf builder. */
+ private final DecisionTreeLeafBuilder decisionTreeLeafBuilder;
+
+ /**
+ * Constructs a new distributed decision tree trainer.
+ *
+ * @param maxDeep Max tree deep.
+ * @param minImpurityDecrease Min impurity decrease.
+ * @param compressor Impurity function compressor.
+ * @param decisionTreeLeafBuilder Decision tree leaf builder.
+ */
+ DecisionTree(int maxDeep, double minImpurityDecrease, StepFunctionCompressor<T> compressor, DecisionTreeLeafBuilder decisionTreeLeafBuilder) {
+ this.maxDeep = maxDeep;
+ this.minImpurityDecrease = minImpurityDecrease;
+ this.compressor = compressor;
+ this.decisionTreeLeafBuilder = decisionTreeLeafBuilder;
+ }
+
+ /** {@inheritDoc} */
+ @Override public <K, V> DecisionTreeNode fit(DatasetBuilder<K, V> datasetBuilder,
+ IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
+ try (Dataset<EmptyContext, DecisionTreeData> dataset = datasetBuilder.build(
+ new EmptyContextBuilder<>(),
+ new DecisionTreeDataBuilder<>(featureExtractor, lbExtractor)
+ )) {
+ return split(dataset, e -> true, 0, getImpurityMeasureCalculator(dataset));
+ }
+ catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ /**
+ * Returns impurity measure calculator.
+ *
+ * @param dataset Dataset.
+ * @return Impurity measure calculator.
+ */
+ abstract ImpurityMeasureCalculator<T> getImpurityMeasureCalculator(Dataset<EmptyContext, DecisionTreeData> dataset);
+
+ /**
+ * Splits the node specified by the given dataset and predicate and returns decision tree node.
+ *
+ * @param dataset Dataset.
+ * @param filter Decision tree node predicate.
+ * @param deep Current tree deep.
+ * @param impurityCalc Impurity measure calculator.
+ * @return Decision tree node.
+ */
+ private DecisionTreeNode split(Dataset<EmptyContext, DecisionTreeData> dataset, TreeFilter filter, int deep,
+ ImpurityMeasureCalculator<T> impurityCalc) {
+ if (deep >= maxDeep)
+ return decisionTreeLeafBuilder.createLeafNode(dataset, filter);
+
+ StepFunction<T>[] criterionFunctions = calculateImpurityForAllColumns(dataset, filter, impurityCalc);
+
+ if (criterionFunctions == null)
+ return decisionTreeLeafBuilder.createLeafNode(dataset, filter);
+
+ SplitPoint splitPnt = calculateBestSplitPoint(criterionFunctions);
+
+ if (splitPnt == null)
+ return decisionTreeLeafBuilder.createLeafNode(dataset, filter);
+
+ return new DecisionTreeConditionalNode(
+ splitPnt.col,
+ splitPnt.threshold,
+ split(dataset, updatePredicateForThenNode(filter, splitPnt), deep + 1, impurityCalc),
+ split(dataset, updatePredicateForElseNode(filter, splitPnt), deep + 1, impurityCalc)
+ );
+ }
+
+ /**
+ * Calculates impurity measure functions for all columns for the node specified by the given dataset and predicate.
+ *
+ * @param dataset Dataset.
+ * @param filter Decision tree node predicate.
+ * @param impurityCalc Impurity measure calculator.
+ * @return Array of impurity measure functions for all columns.
+ */
+ private StepFunction<T>[] calculateImpurityForAllColumns(Dataset<EmptyContext, DecisionTreeData> dataset,
+ TreeFilter filter, ImpurityMeasureCalculator<T> impurityCalc) {
+ return dataset.compute(
+ part -> {
+ if (compressor != null)
+ return compressor.compress(impurityCalc.calculate(part.filter(filter)));
+ else
+ return impurityCalc.calculate(part.filter(filter));
+ }, this::reduce
+ );
+ }
+
+ /**
+ * Calculates best split point.
+ *
+ * @param criterionFunctions Array of impurity measure functions for all columns.
+ * @return Best split point.
+ */
+ private SplitPoint calculateBestSplitPoint(StepFunction<T>[] criterionFunctions) {
+ SplitPoint<T> res = null;
+
+ for (int col = 0; col < criterionFunctions.length; col++) {
+ StepFunction<T> criterionFunctionForCol = criterionFunctions[col];
+
+ double[] arguments = criterionFunctionForCol.getX();
+ T[] values = criterionFunctionForCol.getY();
+
+ for (int leftSize = 1; leftSize < values.length - 1; leftSize++) {
+ if ((values[0].impurity() - values[leftSize].impurity()) > minImpurityDecrease
+ && (res == null || values[leftSize].compareTo(res.val) < 0))
+ res = new SplitPoint<>(values[leftSize], col, calculateThreshold(arguments, leftSize));
+ }
+ }
+
+ return res;
+ }
+
+ /**
+ * Merges two arrays gotten from two partitions.
+ *
+ * @param a First step function.
+ * @param b Second step function.
+ * @return Merged step function.
+ */
+ private StepFunction<T>[] reduce(StepFunction<T>[] a, StepFunction<T>[] b) {
+ if (a == null)
+ return b;
+ if (b == null)
+ return a;
+ else {
+ StepFunction<T>[] res = Arrays.copyOf(a, a.length);
+
+ for (int i = 0; i < res.length; i++)
+ res[i] = res[i].add(b[i]);
+
+ return res;
+ }
+ }
+
+ /**
+ * Calculates threshold based on the given step function arguments and split point (specified left size).
+ *
+ * @param arguments Step function arguments.
+ * @param leftSize Split point (left size).
+ * @return Threshold.
+ */
+ private double calculateThreshold(double[] arguments, int leftSize) {
+ return (arguments[leftSize] + arguments[leftSize + 1]) / 2.0;
+ }
+
+ /**
+ * Constructs a new predicate for "then" node based on the parent node predicate and split point.
+ *
+ * @param filter Parent node predicate.
+ * @param splitPnt Split point.
+ * @return Predicate for "then" node.
+ */
+ private TreeFilter updatePredicateForThenNode(TreeFilter filter, SplitPoint splitPnt) {
+ return filter.and(f -> f[splitPnt.col] > splitPnt.threshold);
+ }
+
+ /**
+ * Constructs a new predicate for "else" node based on the parent node predicate and split point.
+ *
+ * @param filter Parent node predicate.
+ * @param splitPnt Split point.
+ * @return Predicate for "else" node.
+ */
+ private TreeFilter updatePredicateForElseNode(TreeFilter filter, SplitPoint splitPnt) {
+ return filter.and(f -> f[splitPnt.col] <= splitPnt.threshold);
+ }
+
+ /**
+ * Util class that represents split point.
+ */
+ private static class SplitPoint<T extends ImpurityMeasure<T>> implements Serializable {
+ /** */
+ private static final long serialVersionUID = -1758525953544425043L;
+
+ /** Split point impurity measure value. */
+ private final T val;
+
+ /** Column. */
+ private final int col;
+
+ /** Threshold. */
+ private final double threshold;
+
+ /**
+ * Constructs a new instance of split point.
+ *
+ * @param val Split point impurity measure value.
+ * @param col Column.
+ * @param threshold Threshold.
+ */
+ SplitPoint(T val, int col, double threshold) {
+ this.val = val;
+ this.col = col;
+ this.threshold = threshold;
+ }
+ }
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainer.java
new file mode 100644
index 0000000..ce75190
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainer.java
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.tree;
+
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+import org.apache.ignite.ml.dataset.Dataset;
+import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
+import org.apache.ignite.ml.tree.data.DecisionTreeData;
+import org.apache.ignite.ml.tree.impurity.ImpurityMeasureCalculator;
+import org.apache.ignite.ml.tree.impurity.gini.GiniImpurityMeasure;
+import org.apache.ignite.ml.tree.impurity.gini.GiniImpurityMeasureCalculator;
+import org.apache.ignite.ml.tree.impurity.util.StepFunctionCompressor;
+import org.apache.ignite.ml.tree.leaf.MostCommonDecisionTreeLeafBuilder;
+
+/**
+ * Decision tree classifier based on distributed decision tree trainer that allows to fit trees using row-partitioned
+ * dataset. Uses {@link GiniImpurityMeasure} as the impurity measure and {@link MostCommonDecisionTreeLeafBuilder} to
+ * build leaves.
+ */
+public class DecisionTreeClassificationTrainer extends DecisionTree<GiniImpurityMeasure> {
+    /**
+     * Constructs a new decision tree classifier that does not compress impurity step functions.
+     *
+     * @param maxDeep Max tree depth.
+     * @param minImpurityDecrease Min impurity decrease.
+     */
+    public DecisionTreeClassificationTrainer(int maxDeep, double minImpurityDecrease) {
+        this(maxDeep, minImpurityDecrease, null);
+    }
+
+    /**
+     * Constructs a new instance of decision tree classifier.
+     *
+     * @param maxDeep Max tree depth.
+     * @param minImpurityDecrease Min impurity decrease.
+     * @param compressor Impurity function compressor, or {@code null} to skip compression.
+     */
+    public DecisionTreeClassificationTrainer(int maxDeep, double minImpurityDecrease,
+        StepFunctionCompressor<GiniImpurityMeasure> compressor) {
+        super(maxDeep, minImpurityDecrease, compressor, new MostCommonDecisionTreeLeafBuilder());
+    }
+
+    /**
+     * Collects the distinct labels of all partitions and assigns each one a consecutive integer index, which is
+     * passed to the Gini impurity calculator.
+     *
+     * @param dataset Dataset.
+     * @return Gini impurity measure calculator.
+     */
+    @Override ImpurityMeasureCalculator<GiniImpurityMeasure> getImpurityMeasureCalculator(
+        Dataset<EmptyContext, DecisionTreeData> dataset) {
+        Set<Double> labels = dataset.compute(part -> {
+            if (part.getLabels() == null)
+                return null;
+
+            Set<Double> partLabels = new HashSet<>();
+
+            for (double lb : part.getLabels())
+                partLabels.add(lb);
+
+            return partLabels;
+        }, (a, b) -> {
+            if (a == null)
+                return b;
+
+            if (b == null)
+                return a;
+
+            a.addAll(b);
+
+            return a;
+        });
+
+        Map<Double, Integer> encoder = new HashMap<>();
+
+        // The reduce step yields null when no partition provides labels; encode nothing in that case instead of
+        // failing with a NullPointerException.
+        if (labels != null) {
+            int idx = 0;
+
+            for (Double lb : labels)
+                encoder.put(lb, idx++);
+        }
+
+        return new GiniImpurityMeasureCalculator(encoder);
+    }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeConditionalNode.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeConditionalNode.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeConditionalNode.java
new file mode 100644
index 0000000..9818239
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeConditionalNode.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.tree;
+
+/**
+ * Decision tree conditional (non-leaf) node.
+ */
+public class DecisionTreeConditionalNode implements DecisionTreeNode {
+ /** */
+ private static final long serialVersionUID = 981630737007982172L;
+
+ /** Column of the value to be tested. */
+ private final int col;
+
+ /** Threshold. */
+ private final double threshold;
+
+ /** Node that will be used in case tested value is greater then threshold. */
+ private final DecisionTreeNode thenNode;
+
+ /** Node that will be used in case tested value is not greater then threshold. */
+ private final DecisionTreeNode elseNode;
+
+ /**
+ * Constructs a new instance of decision tree conditional node.
+ *
+ * @param col Column of the value to be tested.
+ * @param threshold Threshold.
+ * @param thenNode Node that will be used in case tested value is greater then threshold.
+ * @param elseNode Node that will be used in case tested value is not greater then threshold.
+ */
+ DecisionTreeConditionalNode(int col, double threshold, DecisionTreeNode thenNode, DecisionTreeNode elseNode) {
+ this.col = col;
+ this.threshold = threshold;
+ this.thenNode = thenNode;
+ this.elseNode = elseNode;
+ }
+
+ /** {@inheritDoc} */
+ @Override public Double apply(double[] features) {
+ return features[col] > threshold ? thenNode.apply(features) : elseNode.apply(features);
+ }
+
+ /** */
+ public int getCol() {
+ return col;
+ }
+
+ /** */
+ public double getThreshold() {
+ return threshold;
+ }
+
+ /** */
+ public DecisionTreeNode getThenNode() {
+ return thenNode;
+ }
+
+ /** */
+ public DecisionTreeNode getElseNode() {
+ return elseNode;
+ }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeLeafNode.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeLeafNode.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeLeafNode.java
new file mode 100644
index 0000000..4c6369d
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeLeafNode.java
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.tree;
+
+/**
+ * Decision tree leaf node: returns the same stored value for every input.
+ */
+public class DecisionTreeLeafNode implements DecisionTreeNode {
+    /** */
+    private static final long serialVersionUID = -472145568088482206L;
+
+    /** Value returned by {@link #apply(double[])}. */
+    private final double val;
+
+    /**
+     * Creates a leaf holding the given value.
+     *
+     * @param val Value of the node.
+     */
+    public DecisionTreeLeafNode(double val) {
+        this.val = val;
+    }
+
+    /**
+     * Ignores the features and returns the stored value.
+     *
+     * @param features Feature vector (unused).
+     * @return Stored node value.
+     */
+    @Override public Double apply(double[] features) {
+        return Double.valueOf(val);
+    }
+
+    /** @return Stored node value. */
+    public double getVal() {
+        return val;
+    }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeNode.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeNode.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeNode.java
new file mode 100644
index 0000000..94878eb
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeNode.java
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.tree;
+
+import org.apache.ignite.ml.Model;
+
+/**
+ * Base interface for decision tree nodes. Every node is itself a model from a feature vector ({@code double[]}) to a
+ * {@code Double} result, so a whole tree is represented by its root node. In this package it is implemented by
+ * {@link DecisionTreeConditionalNode} and {@link DecisionTreeLeafNode}.
+ */
+public interface DecisionTreeNode extends Model<double[], Double> {
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainer.java
new file mode 100644
index 0000000..2bf09d3
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainer.java
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.tree;
+
+import org.apache.ignite.ml.dataset.Dataset;
+import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
+import org.apache.ignite.ml.tree.data.DecisionTreeData;
+import org.apache.ignite.ml.tree.impurity.ImpurityMeasureCalculator;
+import org.apache.ignite.ml.tree.impurity.mse.MSEImpurityMeasure;
+import org.apache.ignite.ml.tree.impurity.mse.MSEImpurityMeasureCalculator;
+import org.apache.ignite.ml.tree.impurity.util.StepFunctionCompressor;
+import org.apache.ignite.ml.tree.leaf.MeanDecisionTreeLeafBuilder;
+
+/**
+ * Decision tree regressor based on distributed decision tree trainer that allows to fit trees using row-partitioned
+ * dataset.
+ */
+public class DecisionTreeRegressionTrainer extends DecisionTree<MSEImpurityMeasure> {
+ /**
+ * Constructs a new decision tree regressor with default impurity function compressor.
+ *
+ * @param maxDeep Max tree deep.
+ * @param minImpurityDecrease Min impurity decrease.
+ */
+ public DecisionTreeRegressionTrainer(int maxDeep, double minImpurityDecrease) {
+ this(maxDeep, minImpurityDecrease, null);
+ }
+
+ /**
+ * Constructs a new decision tree regressor.
+ *
+ * @param maxDeep Max tree deep.
+ * @param minImpurityDecrease Min impurity decrease.
+ */
+ public DecisionTreeRegressionTrainer(int maxDeep, double minImpurityDecrease,
+ StepFunctionCompressor<MSEImpurityMeasure> compressor) {
+ super(maxDeep, minImpurityDecrease, compressor, new MeanDecisionTreeLeafBuilder());
+ }
+
+ /** {@inheritDoc} */
+ @Override ImpurityMeasureCalculator<MSEImpurityMeasure> getImpurityMeasureCalculator(
+ Dataset<EmptyContext, DecisionTreeData> dataset) {
+ return new MSEImpurityMeasureCalculator();
+ }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/tree/TreeFilter.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/TreeFilter.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/TreeFilter.java
new file mode 100644
index 0000000..3e4dc00
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/TreeFilter.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.tree;
+
+import java.io.Serializable;
+import java.util.Objects;
+import java.util.function.Predicate;
+
+/**
+ * Predicate used to define objects that placed in decision tree node.
+ */
+public interface TreeFilter extends Predicate<double[]>, Serializable {
+ /**
+ * Returns a composed predicate.
+ *
+ * @param other Predicate that will be logically-ANDed with this predicate.
+ * @return Returns a composed predicate
+ */
+ default TreeFilter and(TreeFilter other) {
+ Objects.requireNonNull(other);
+ return (t) -> test(t) && other.test(t);
+ }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeData.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeData.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeData.java
new file mode 100644
index 0000000..34deb46
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeData.java
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.tree.data;
+
+import org.apache.ignite.ml.tree.TreeFilter;
+
+/**
+ * A partition {@code data} containing a matrix of features and a vector of labels stored in heap.
+ */
+public class DecisionTreeData implements AutoCloseable {
+    /** Matrix with features (one row per object). */
+    private final double[][] features;
+
+    /** Vector with labels (one label per row of {@link #features}). */
+    private final double[] labels;
+
+    /**
+     * Constructs a new instance of decision tree data. The arrays are stored by reference, not copied.
+     *
+     * @param features Matrix with features.
+     * @param labels Vector with labels.
+     */
+    public DecisionTreeData(double[][] features, double[] labels) {
+        assert features.length == labels.length : "Features and labels have to be the same length";
+
+        this.features = features;
+        this.labels = labels;
+    }
+
+    /**
+     * Filters objects and returns only data that passed filter. The returned instance shares the feature rows with
+     * this one (rows are not copied).
+     *
+     * @param filter Filter.
+     * @return Data passed filter.
+     */
+    public DecisionTreeData filter(TreeFilter filter) {
+        // Filters are composed once per tree level, so evaluate each row only once and remember the outcome.
+        boolean[] passed = new boolean[features.length];
+        int size = 0;
+
+        for (int i = 0; i < features.length; i++) {
+            passed[i] = filter.test(features[i]);
+
+            if (passed[i])
+                size++;
+        }
+
+        double[][] newFeatures = new double[size][];
+        double[] newLabels = new double[size];
+
+        int ptr = 0;
+
+        for (int i = 0; i < features.length; i++) {
+            if (passed[i]) {
+                newFeatures[ptr] = features[i];
+                newLabels[ptr] = labels[i];
+
+                ptr++;
+            }
+        }
+
+        return new DecisionTreeData(newFeatures, newLabels);
+    }
+
+    /**
+     * Sorts data by specified column in ascending order. Feature rows and labels are permuted together; the sort is
+     * not stable.
+     *
+     * @param col Column.
+     */
+    public void sort(int col) {
+        sort(col, 0, features.length - 1);
+    }
+
+    /**
+     * Sorts rows {@code [from, to]} (both inclusive) by the given column using quicksort with two-pointer
+     * (Hoare-style) partitioning. Only the smaller part is sorted recursively and the larger one is handled by the
+     * loop, so the recursion depth stays logarithmic in the range size.
+     *
+     * @param col Column.
+     * @param from First row index (inclusive).
+     * @param to Last row index (inclusive).
+     */
+    private void sort(int col, int from, int to) {
+        while (from < to) {
+            double pivot = features[(from + to) >>> 1][col];
+
+            int i = from, j = to;
+
+            while (i <= j) {
+                while (features[i][col] < pivot)
+                    i++;
+
+                while (features[j][col] > pivot)
+                    j--;
+
+                if (i <= j) {
+                    swap(i, j);
+
+                    i++;
+                    j--;
+                }
+            }
+
+            // [from, j] and [i, to] are now independent: recurse into the smaller one, iterate over the larger one.
+            if (j - from < to - i) {
+                sort(col, from, j);
+                from = i;
+            }
+            else {
+                sort(col, i, to);
+                to = j;
+            }
+        }
+    }
+
+    /**
+     * Swaps rows {@code i} and {@code j} in both the feature matrix and the label vector.
+     *
+     * @param i First row index.
+     * @param j Second row index.
+     */
+    private void swap(int i, int j) {
+        double[] tmpFeature = features[i];
+        features[i] = features[j];
+        features[j] = tmpFeature;
+
+        double tmpLb = labels[i];
+        labels[i] = labels[j];
+        labels[j] = tmpLb;
+    }
+
+    /** @return Matrix with features (the internal array, not a copy). */
+    public double[][] getFeatures() {
+        return features;
+    }
+
+    /** @return Vector with labels (the internal array, not a copy). */
+    public double[] getLabels() {
+        return labels;
+    }
+
+    /** {@inheritDoc} */
+    @Override public void close() {
+        // Do nothing, GC will clean up.
+    }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeDataBuilder.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeDataBuilder.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeDataBuilder.java
new file mode 100644
index 0000000..67109ae
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/DecisionTreeDataBuilder.java
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.tree.data;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.Iterator;
+import org.apache.ignite.ml.dataset.PartitionDataBuilder;
+import org.apache.ignite.ml.dataset.UpstreamEntry;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+
+/**
+ * A partition {@code data} builder that makes {@link DecisionTreeData}.
+ *
+ * @param <K> Type of a key in <tt>upstream</tt> data.
+ * @param <V> Type of a value in <tt>upstream</tt> data.
+ * @param <C> Type of a partition <tt>context</tt>.
+ */
+public class DecisionTreeDataBuilder<K, V, C extends Serializable>
+    implements PartitionDataBuilder<K, V, C, DecisionTreeData> {
+    /** */
+    private static final long serialVersionUID = 3678784980215216039L;
+
+    /** Function that extracts features from an {@code upstream} data. */
+    private final IgniteBiFunction<K, V, double[]> featureExtractor;
+
+    /** Function that extracts labels from an {@code upstream} data. */
+    private final IgniteBiFunction<K, V, Double> lbExtractor;
+
+    /**
+     * Constructs a new instance of decision tree data builder.
+     *
+     * @param featureExtractor Function that extracts features from an {@code upstream} data.
+     * @param lbExtractor Function that extracts labels from an {@code upstream} data.
+     */
+    public DecisionTreeDataBuilder(IgniteBiFunction<K, V, double[]> featureExtractor,
+        IgniteBiFunction<K, V, Double> lbExtractor) {
+        this.featureExtractor = featureExtractor;
+        this.lbExtractor = lbExtractor;
+    }
+
+    /**
+     * {@inheritDoc}
+     *
+     * <p>Extracts one feature row and one label per upstream entry, in iteration order. Throws
+     * {@link ArithmeticException} if {@code upstreamDataSize} does not fit into an {@code int}.</p>
+     */
+    @Override public DecisionTreeData build(Iterator<UpstreamEntry<K, V>> upstreamData, long upstreamDataSize, C ctx) {
+        int size = Math.toIntExact(upstreamDataSize);
+
+        double[][] features = new double[size][];
+        double[] labels = new double[size];
+
+        int ptr = 0;
+        while (upstreamData.hasNext()) {
+            UpstreamEntry<K, V> entry = upstreamData.next();
+
+            features[ptr] = featureExtractor.apply(entry.getKey(), entry.getValue());
+            labels[ptr] = lbExtractor.apply(entry.getKey(), entry.getValue());
+
+            ptr++;
+        }
+
+        // If the iterator produced fewer entries than announced, drop the unused tail so that no null
+        // feature rows (and spurious 0.0 labels) end up in the partition data.
+        if (ptr < size) {
+            features = Arrays.copyOf(features, ptr);
+            labels = Arrays.copyOf(labels, ptr);
+        }
+
+        return new DecisionTreeData(features, labels);
+    }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/package-info.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/package-info.java
new file mode 100644
index 0000000..192b07f
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/data/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * Contains data and data builder required for decision tree trainers built on top of partition based dataset.
+ */
+package org.apache.ignite.ml.tree.data;
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/ImpurityMeasure.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/ImpurityMeasure.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/ImpurityMeasure.java
new file mode 100644
index 0000000..7ad2b80
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/ImpurityMeasure.java
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.tree.impurity;
+
+import java.io.Serializable;
+
+/**
+ * Base interface for impurity measures that can be used in distributed decision tree algorithm.
+ *
+ * @param <T> Type of this impurity measure.
+ */
+public interface ImpurityMeasure<T extends ImpurityMeasure<T>> extends Comparable<T>, Serializable {
+ /**
+ * Calculates impurity measure as a single double value.
+ *
+ * @return Impurity measure value.
+ */
+ public double impurity();
+
+ /**
+ * Adds the given impurity to this.
+ *
+ * @param measure Another impurity.
+ * @return Sum of this and the given impurity.
+ */
+ public T add(T measure);
+
+ /**
+ * Subtracts the given impurity for this.
+ *
+ * @param measure Another impurity.
+ * @return Difference of this and the given impurity.
+ */
+ public T subtract(T measure);
+
+ /** {@inheritDoc} */
+ default public int compareTo(T o) {
+ return Double.compare(impurity(), o.impurity());
+ }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/ImpurityMeasureCalculator.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/ImpurityMeasureCalculator.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/ImpurityMeasureCalculator.java
new file mode 100644
index 0000000..2b69356
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/ImpurityMeasureCalculator.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.tree.impurity;
+
+import java.io.Serializable;
+import org.apache.ignite.ml.tree.data.DecisionTreeData;
+import org.apache.ignite.ml.tree.impurity.util.StepFunction;
+
+/**
+ * Contract for calculators that compute every impurity measure needed to choose the best split.
+ *
+ * @param <T> Type of impurity measure produced by the calculator.
+ */
+public interface ImpurityMeasureCalculator<T extends ImpurityMeasure<T>> extends Serializable {
+    /**
+     * Computes, for every feature column, the impurity measures required to find the best split and returns
+     * them as one {@link StepFunction} per column.
+     *
+     * NOTE(review): the Gini implementation returns {@code null} for data without rows; confirm whether
+     * {@code null} is part of this contract for all implementations.
+     *
+     * @param data Features and labels.
+     * @return Array of {@link StepFunction}s, one per column.
+     */
+    StepFunction<T>[] calculate(DecisionTreeData data);
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/gini/GiniImpurityMeasure.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/gini/GiniImpurityMeasure.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/gini/GiniImpurityMeasure.java
new file mode 100644
index 0000000..817baf5
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/gini/GiniImpurityMeasure.java
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.tree.impurity.gini;
+
+import org.apache.ignite.ml.tree.impurity.ImpurityMeasure;
+
+/**
+ * Gini impurity measure of a split into a left and a right part. Let {@code l_i} and {@code r_i} be the numbers
+ * of elements of class {@code i} in the left and the right part, {@code k} the number of classes,
+ * {@code L = \sum_{i=1}^{k} l_i} and {@code R = \sum_{i=1}^{k} r_i}. The measure is calculated as
+ * {@code -\frac{1}{L}\sum_{i=1}^{k}l_i^2 - \frac{1}{R}\sum_{i=1}^{k}r_i^2}, where the term of an empty part is
+ * zero. This equals {@code L * Gini(left) + R * Gini(right) - (L + R)}, i.e. the weighted Gini impurity shifted
+ * by the total number of elements, so lower values mean purer splits.
+ */
+public class GiniImpurityMeasure implements ImpurityMeasure<GiniImpurityMeasure> {
+    /** */
+    private static final long serialVersionUID = 5338129703395229970L;
+
+    /** Number of elements of each type in the left part. */
+    private final long[] left;
+
+    /** Number of elements of each type in the right part. */
+    private final long[] right;
+
+    /**
+     * Constructs a new instance of Gini impurity measure. The given arrays are stored as is (not copied).
+     *
+     * @param left Number of elements of each type in the left part.
+     * @param right Number of elements of each type in the right part.
+     */
+    GiniImpurityMeasure(long[] left, long[] right) {
+        assert left.length == right.length : "Left and right parts have to be the same length";
+
+        this.left = left;
+        this.right = right;
+    }
+
+    /** {@inheritDoc} */
+    @Override public double impurity() {
+        return -(sumOfSquaresOverTotal(left) + sumOfSquaresOverTotal(right));
+    }
+
+    /**
+     * Calculates {@code \frac{1}{N}\sum_i c_i^2} with {@code N = \sum_i c_i}, or {@code 0} if {@code N} is not
+     * positive (empty part).
+     *
+     * @param cnts Per-class element counts of one part.
+     * @return Sum of squared counts divided by the total count.
+     */
+    private static double sumOfSquaresOverTotal(long[] cnts) {
+        long total = 0;
+
+        for (long e : cnts)
+            total += e;
+
+        double res = 0;
+
+        if (total > 0)
+            for (long e : cnts)
+                res += Math.pow(e, 2) / total;
+
+        return res;
+    }
+
+    /**
+     * {@inheritDoc}
+     *
+     * <p>Adds counts class by class; neither operand is modified, a new instance is returned.</p>
+     */
+    @Override public GiniImpurityMeasure add(GiniImpurityMeasure b) {
+        assert left.length == b.left.length : "Added measure has to have length " + left.length;
+        assert left.length == b.right.length : "Added measure has to have length " + left.length;
+
+        long[] leftRes = new long[left.length];
+        long[] rightRes = new long[left.length];
+
+        for (int i = 0; i < left.length; i++) {
+            leftRes[i] = left[i] + b.left[i];
+            rightRes[i] = right[i] + b.right[i];
+        }
+
+        return new GiniImpurityMeasure(leftRes, rightRes);
+    }
+
+    /**
+     * {@inheritDoc}
+     *
+     * <p>Subtracts counts class by class; neither operand is modified, a new instance is returned.</p>
+     */
+    @Override public GiniImpurityMeasure subtract(GiniImpurityMeasure b) {
+        assert left.length == b.left.length : "Subtracted measure has to have length " + left.length;
+        assert left.length == b.right.length : "Subtracted measure has to have length " + left.length;
+
+        long[] leftRes = new long[left.length];
+        long[] rightRes = new long[left.length];
+
+        for (int i = 0; i < left.length; i++) {
+            leftRes[i] = left[i] - b.left[i];
+            rightRes[i] = right[i] - b.right[i];
+        }
+
+        return new GiniImpurityMeasure(leftRes, rightRes);
+    }
+
+    /** @return Per-class counts of the left part (the internal array, not a copy). */
+    public long[] getLeft() {
+        return left;
+    }
+
+    /** @return Per-class counts of the right part (the internal array, not a copy). */
+    public long[] getRight() {
+        return right;
+    }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/gini/GiniImpurityMeasureCalculator.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/gini/GiniImpurityMeasureCalculator.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/gini/GiniImpurityMeasureCalculator.java
new file mode 100644
index 0000000..0dd0a10
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/gini/GiniImpurityMeasureCalculator.java
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.tree.impurity.gini;
+
+import java.util.Arrays;
+import java.util.Map;
+import org.apache.ignite.ml.tree.data.DecisionTreeData;
+import org.apache.ignite.ml.tree.impurity.ImpurityMeasureCalculator;
+import org.apache.ignite.ml.tree.impurity.util.StepFunction;
+
+/**
+ * Gini impurity measure calculator.
+ */
+public class GiniImpurityMeasureCalculator implements ImpurityMeasureCalculator<GiniImpurityMeasure> {
+ /** */
+ private static final long serialVersionUID = -522995134128519679L;
+
+ /** Label encoder which defines integer value for every label class. */
+ private final Map<Double, Integer> lbEncoder;
+
+ /**
+ * Constructs a new instance of Gini impurity measure calculator.
+ *
+ * @param lbEncoder Label encoder which defines integer value for every label class.
+ */
+ public GiniImpurityMeasureCalculator(Map<Double, Integer> lbEncoder) {
+ this.lbEncoder = lbEncoder;
+ }
+
+ /** {@inheritDoc} */
+ @SuppressWarnings("unchecked")
+ @Override public StepFunction<GiniImpurityMeasure>[] calculate(DecisionTreeData data) {
+ double[][] features = data.getFeatures();
+ double[] labels = data.getLabels();
+
+ if (features.length > 0) {
+ StepFunction<GiniImpurityMeasure>[] res = new StepFunction[features[0].length];
+
+ for (int col = 0; col < res.length; col++) {
+ data.sort(col);
+
+ double[] x = new double[features.length + 1];
+ GiniImpurityMeasure[] y = new GiniImpurityMeasure[features.length + 1];
+
+ int xPtr = 0, yPtr = 0;
+
+ long[] left = new long[lbEncoder.size()];
+ long[] right = new long[lbEncoder.size()];
+
+ for (int i = 0; i < labels.length; i++)
+ right[getLabelCode(labels[i])]++;
+
+ x[xPtr++] = Double.NEGATIVE_INFINITY;
+ y[yPtr++] = new GiniImpurityMeasure(
+ Arrays.copyOf(left, left.length),
+ Arrays.copyOf(right, right.length)
+ );
+
+ for (int i = 0; i < features.length; i++) {
+ left[getLabelCode(labels[i])]++;
+ right[getLabelCode(labels[i])]--;
+
+ if (i < (features.length - 1) && features[i + 1][col] == features[i][col])
+ continue;
+
+ x[xPtr++] = features[i][col];
+ y[yPtr++] = new GiniImpurityMeasure(
+ Arrays.copyOf(left, left.length),
+ Arrays.copyOf(right, right.length)
+ );
+ }
+
+ res[col] = new StepFunction<>(Arrays.copyOf(x, xPtr), Arrays.copyOf(y, yPtr));
+ }
+
+ return res;
+ }
+
+ return null;
+ }
+
+ /**
+ * Returns label code.
+ *
+ * @param lb Label.
+ * @return Label code.
+ */
+ int getLabelCode(double lb) {
+ Integer code = lbEncoder.get(lb);
+
+ assert code != null : "Can't find code for label " + lb;
+
+ return code;
+ }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/gini/package-info.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/gini/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/gini/package-info.java
new file mode 100644
index 0000000..d14cd92
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/impurity/gini/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * Contains Gini impurity measure and calculator.
+ */
+package org.apache.ignite.ml.tree.impurity.gini;
\ No newline at end of file