You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@ignite.apache.org by is...@apache.org on 2017/11/10 15:07:12 UTC
[3/4] ignite git commit: IGNITE-5218: First version of decision
trees. This closes #2936
http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/trees/ContinuousRegionInfo.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/ContinuousRegionInfo.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/ContinuousRegionInfo.java
new file mode 100644
index 0000000..e98bb72
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trees/ContinuousRegionInfo.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trees;
+
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
+
+/**
+ * Information about region used by continuous features.
+ */
+public class ContinuousRegionInfo extends RegionInfo {
+    /** Count of samples in this region. */
+    private int size;
+
+    /**
+     * Construct region info for a continuous feature.
+     *
+     * @param impurity Impurity of the region.
+     * @param size Count of samples in this region.
+     */
+    public ContinuousRegionInfo(double impurity, int size) {
+        super(impurity);
+        this.size = size;
+    }
+
+    /**
+     * No-op constructor for serialization/deserialization.
+     */
+    public ContinuousRegionInfo() {
+        // No-op
+    }
+
+    /**
+     * Get the size of region.
+     *
+     * @return Count of samples in this region.
+     */
+    public int getSize() {
+        return size;
+    }
+
+    /** {@inheritDoc} */
+    @Override public String toString() {
+        // Include the inherited impurity as well: it is part of the externalized state of this object.
+        return "ContinuousRegionInfo [" +
+            "impurity=" + impurity() +
+            ", size=" + size +
+            ']';
+    }
+
+    /** {@inheritDoc} */
+    @Override public void writeExternal(ObjectOutput out) throws IOException {
+        // Parent state (impurity) first, then own state; readExternal must mirror this order.
+        super.writeExternal(out);
+        out.writeInt(size);
+    }
+
+    /** {@inheritDoc} */
+    @Override public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
+        super.readExternal(in);
+        size = in.readInt();
+    }
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/trees/ContinuousSplitCalculator.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/ContinuousSplitCalculator.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/ContinuousSplitCalculator.java
new file mode 100644
index 0000000..f9b81d0
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trees/ContinuousSplitCalculator.java
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trees;
+
+import java.util.stream.DoubleStream;
+import org.apache.ignite.ml.trees.trainers.columnbased.vectors.SplitInfo;
+
+/**
+ * Calculator of the best split of a region by a continuous feature.
+ *
+ * @param <C> Class in which information about region will be stored.
+ */
+public interface ContinuousSplitCalculator<C extends ContinuousRegionInfo> {
+    /**
+     * Calculate region info 'from scratch'.
+     *
+     * @param s Stream of labels in this region.
+     * @param l Index of sample projection on this feature in array sorted by this projection value and intervals
+     * bitsets (see {@link org.apache.ignite.ml.trees.trainers.columnbased.vectors.ContinuousFeatureProcessor}).
+     * @return Region info.
+     */
+    C calculateRegionInfo(DoubleStream s, int l);
+
+    /**
+     * Calculate split info of best split of region given information about this region.
+     *
+     * @param sampleIndexes Indexes of samples of this region.
+     * @param values All values of this feature.
+     * @param labels Labels of all samples.
+     * @param regionIdx Index of region being split.
+     * @param data Information about region being split which can be used for computations.
+     * @return Information about best split of region with index given by {@code regionIdx}.
+     */
+    SplitInfo<C> splitRegion(Integer[] sampleIndexes, double[] values, double[] labels, int regionIdx, C data);
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/trees/RegionInfo.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/RegionInfo.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/RegionInfo.java
new file mode 100644
index 0000000..8ec7db3
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trees/RegionInfo.java
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trees;
+
+import java.io.Externalizable;
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
+
+/**
+ * Base information about a region of the sample space; holds the region impurity
+ * and supports externalization.
+ */
+public class RegionInfo implements Externalizable {
+    /** Impurity in this region. */
+    private double impurity;
+
+    /**
+     * No-op constructor for serialization/deserialization.
+     */
+    public RegionInfo() {
+        // No-op
+    }
+
+    /**
+     * Create region info with the given impurity.
+     *
+     * @param impurity Impurity of this region.
+     */
+    public RegionInfo(double impurity) {
+        this.impurity = impurity;
+    }
+
+    /**
+     * Get impurity in this region.
+     *
+     * @return Impurity of this region.
+     */
+    public double impurity() {
+        return this.impurity;
+    }
+
+    /** {@inheritDoc} */
+    @Override public void writeExternal(ObjectOutput out) throws IOException {
+        // The only persisted state is the impurity, written as a single double.
+        out.writeDouble(this.impurity);
+    }
+
+    /** {@inheritDoc} */
+    @Override public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
+        this.impurity = in.readDouble();
+    }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/trees/models/DecisionTreeModel.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/models/DecisionTreeModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/models/DecisionTreeModel.java
new file mode 100644
index 0000000..86e9326
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trees/models/DecisionTreeModel.java
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trees.models;
+
+import java.util.Objects;
+import org.apache.ignite.ml.Model;
+import org.apache.ignite.ml.math.Vector;
+import org.apache.ignite.ml.trees.nodes.DecisionTreeNode;
+
+/**
+ * Model for decision tree.
+ */
+public class DecisionTreeModel implements Model<Vector, Double> {
+    /** Root node of the decision tree. */
+    private final DecisionTreeNode root;
+
+    /**
+     * Construct decision tree model.
+     *
+     * @param root Root of decision tree; must not be {@code null}.
+     * @throws NullPointerException If {@code root} is {@code null}.
+     */
+    public DecisionTreeModel(DecisionTreeNode root) {
+        // Fail fast: the field is final, so a null root would make every predict() call throw later.
+        this.root = Objects.requireNonNull(root, "root");
+    }
+
+    /**
+     * Predict a value for the given vector by delegating to the root node.
+     *
+     * @param val Vector to predict a value for.
+     * @return Value assigned to {@code val} by the tree.
+     */
+    @Override public Double predict(Vector val) {
+        return root.process(val);
+    }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/trees/models/package-info.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/models/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/models/package-info.java
new file mode 100644
index 0000000..ce8418e
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trees/models/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * Contains decision tree models.
+ */
+package org.apache.ignite.ml.trees.models;
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/CategoricalSplitNode.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/CategoricalSplitNode.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/CategoricalSplitNode.java
new file mode 100644
index 0000000..cae6d4a
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/CategoricalSplitNode.java
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trees.nodes;
+
+import java.util.BitSet;
+import org.apache.ignite.ml.math.Vector;
+
+/**
+ * Split node by categorical feature.
+ */
+public class CategoricalSplitNode extends SplitNode {
+    /** Bitset specifying which categories belong to left subregion. */
+    private final BitSet bs;
+
+    /**
+     * Construct categorical split node.
+     *
+     * @param featureIdx Index of feature by which split is done.
+     * @param bs Bitset specifying which categories go to the left subtree.
+     */
+    public CategoricalSplitNode(int featureIdx, BitSet bs) {
+        super(featureIdx);
+        this.bs = bs;
+    }
+
+    /**
+     * {@inheritDoc}
+     * <p>
+     * The feature value is truncated to {@code int} and used as the category index into the bitset;
+     * a negative index makes {@link BitSet#get(int)} throw {@link IndexOutOfBoundsException}.
+     */
+    @Override public boolean goLeft(Vector v) {
+        return bs.get((int)v.getX(featureIdx));
+    }
+
+    /** {@inheritDoc} */
+    @Override public String toString() {
+        // Include the feature index so that splits on different features are distinguishable.
+        return "CategoricalSplitNode [" +
+            "featureIdx=" + featureIdx +
+            ", bs=" + bs +
+            ']';
+    }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/ContinuousSplitNode.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/ContinuousSplitNode.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/ContinuousSplitNode.java
new file mode 100644
index 0000000..285cfcd
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/ContinuousSplitNode.java
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trees.nodes;
+
+import org.apache.ignite.ml.math.Vector;
+
+/**
+ * Split node representing split of continuous feature.
+ */
+public class ContinuousSplitNode extends SplitNode {
+    /** Threshold. Values which are less than or equal to threshold are assigned to the left subregion. */
+    private final double threshold;
+
+    /**
+     * Construct ContinuousSplitNode by threshold and feature index.
+     *
+     * @param threshold Threshold.
+     * @param featureIdx Feature index.
+     */
+    public ContinuousSplitNode(double threshold, int featureIdx) {
+        super(featureIdx);
+        this.threshold = threshold;
+    }
+
+    /**
+     * {@inheritDoc}
+     * <p>
+     * A {@code NaN} feature value never satisfies {@code <=}, so such vectors are routed to the right subtree.
+     */
+    @Override public boolean goLeft(Vector v) {
+        return v.getX(featureIdx) <= threshold;
+    }
+
+    /**
+     * Threshold. Values which are less than or equal to threshold are assigned to the left subregion.
+     *
+     * @return Threshold.
+     */
+    public double threshold() {
+        return threshold;
+    }
+
+    /** {@inheritDoc} */
+    @Override public String toString() {
+        // Include the feature index so that splits on different features are distinguishable.
+        return "ContinuousSplitNode [" +
+            "featureIdx=" + featureIdx +
+            ", threshold=" + threshold +
+            ']';
+    }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/DecisionTreeNode.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/DecisionTreeNode.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/DecisionTreeNode.java
new file mode 100644
index 0000000..d31623d
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/DecisionTreeNode.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trees.nodes;
+
+import org.apache.ignite.ml.math.Vector;
+
+/**
+ * A node of a decision tree: either a split that routes vectors further down, or a terminal leaf.
+ */
+public interface DecisionTreeNode {
+    /**
+     * Compute the value that the (sub)tree rooted at this node assigns to the given vector.
+     *
+     * @param v Vector to evaluate.
+     * @return Value assigned to {@code v}.
+     */
+    double process(Vector v);
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/Leaf.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/Leaf.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/Leaf.java
new file mode 100644
index 0000000..79b441f
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/Leaf.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trees.nodes;
+
+import org.apache.ignite.ml.math.Vector;
+
+/**
+ * Terminal node of the decision tree.
+ */
+public class Leaf implements DecisionTreeNode {
+    /** Value in subregion represented by this node. */
+    private final double val;
+
+    /**
+     * Construct the leaf of decision tree.
+     *
+     * @param val Value in subregion represented by this node.
+     */
+    public Leaf(double val) {
+        this.val = val;
+    }
+
+    /**
+     * Return value in subregion represented by this node. The given vector is not inspected.
+     *
+     * @param v Vector.
+     * @return Value in subregion represented by this node.
+     */
+    @Override public double process(Vector v) {
+        return val;
+    }
+
+    /** {@inheritDoc} */
+    @Override public String toString() {
+        // Mirrors the toString format of the split nodes so whole trees print readably.
+        return "Leaf [" +
+            "val=" + val +
+            ']';
+    }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/SplitNode.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/SplitNode.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/SplitNode.java
new file mode 100644
index 0000000..4c258d1
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/SplitNode.java
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trees.nodes;
+
+import org.apache.ignite.ml.math.Vector;
+
+/**
+ * Node in decision tree representing a split.
+ */
+public abstract class SplitNode implements DecisionTreeNode {
+    /** Left subtree. */
+    protected DecisionTreeNode l;
+
+    /** Right subtree. */
+    protected DecisionTreeNode r;
+
+    /** Feature index. */
+    protected final int featureIdx;
+
+    /**
+     * Constructs SplitNode with a given feature index.
+     *
+     * @param featureIdx Feature index.
+     */
+    public SplitNode(int featureIdx) {
+        this.featureIdx = featureIdx;
+    }
+
+    /**
+     * Indicates if the given vector is in left subtree.
+     * <p>
+     * Declared {@code protected} rather than package-private so that subclasses outside of this package
+     * are able to implement it.
+     *
+     * @param v Vector.
+     * @return {@code true} if the given vector should be routed to the left subtree.
+     */
+    protected abstract boolean goLeft(Vector v);
+
+    /**
+     * Left subtree.
+     *
+     * @return Left subtree.
+     */
+    public DecisionTreeNode left() {
+        return l;
+    }
+
+    /**
+     * Right subtree.
+     *
+     * @return Right subtree.
+     */
+    public DecisionTreeNode right() {
+        return r;
+    }
+
+    /**
+     * Set the left subtree.
+     *
+     * @param n Left subtree.
+     */
+    public void setLeft(DecisionTreeNode n) {
+        l = n;
+    }
+
+    /**
+     * Set the right subtree.
+     *
+     * @param n Right subtree.
+     */
+    public void setRight(DecisionTreeNode n) {
+        r = n;
+    }
+
+    /**
+     * Delegates processing to subtrees.
+     * <p>
+     * If the left subtree is not set, every vector is routed to the right subtree without consulting
+     * {@link #goLeft(Vector)}. A missing right subtree results in a {@link NullPointerException}.
+     *
+     * @param v Vector.
+     * @return Value assigned to the given vector.
+     */
+    @Override public double process(Vector v) {
+        if (left() != null && goLeft(v))
+            return left().process(v);
+        else
+            return right().process(v);
+    }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/package-info.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/package-info.java
new file mode 100644
index 0000000..d6deb9d
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trees/nodes/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * Contains classes representing decision tree nodes.
+ */
+package org.apache.ignite.ml.trees.nodes;
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/trees/package-info.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/package-info.java
new file mode 100644
index 0000000..b07ba4a
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trees/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * Contains decision tree algorithms.
+ */
+package org.apache.ignite.ml.trees;
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/BiIndex.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/BiIndex.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/BiIndex.java
new file mode 100644
index 0000000..0d27c8a
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/BiIndex.java
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trees.trainers.columnbased;
+
+import java.io.Externalizable;
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
+import org.apache.ignite.cache.affinity.AffinityKeyMapped;
+
+/**
+ * Index of a cell in a 2d matrix, in the form (row, col).
+ */
+public class BiIndex implements Externalizable {
+    /** Row. */
+    private int row;
+
+    /** Column; annotated with {@link AffinityKeyMapped}, so this field serves as the affinity key. */
+    @AffinityKeyMapped
+    private int col;
+
+    /**
+     * No-op constructor for serialization/deserialization.
+     */
+    public BiIndex() {
+        // No-op.
+    }
+
+    /**
+     * Create an index pointing at the given row and column.
+     *
+     * @param row Row.
+     * @param col Column.
+     */
+    public BiIndex(int row, int col) {
+        this.row = row;
+        this.col = col;
+    }
+
+    /**
+     * Returns row.
+     *
+     * @return Row.
+     */
+    public int row() {
+        return this.row;
+    }
+
+    /**
+     * Returns column.
+     *
+     * @return Column.
+     */
+    public int col() {
+        return this.col;
+    }
+
+    /** {@inheritDoc} */
+    @Override public boolean equals(Object o) {
+        if (o == this)
+            return true;
+
+        if (o == null || o.getClass() != getClass())
+            return false;
+
+        BiIndex other = (BiIndex)o;
+
+        return row == other.row && col == other.col;
+    }
+
+    /** {@inheritDoc} */
+    @Override public int hashCode() {
+        return 31 * row + col;
+    }
+
+    /** {@inheritDoc} */
+    @Override public String toString() {
+        return "BiIndex [row=" + row + ", col=" + col + ']';
+    }
+
+    /** {@inheritDoc} */
+    @Override public void writeExternal(ObjectOutput out) throws IOException {
+        // Row first, then column; readExternal reads them back in the same order.
+        out.writeInt(this.row);
+        out.writeInt(this.col);
+    }
+
+    /** {@inheritDoc} */
+    @Override public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
+        this.row = in.readInt();
+        this.col = in.readInt();
+    }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/BiIndexedCacheColumnDecisionTreeTrainerInput.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/BiIndexedCacheColumnDecisionTreeTrainerInput.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/BiIndexedCacheColumnDecisionTreeTrainerInput.java
new file mode 100644
index 0000000..04281fb
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/BiIndexedCacheColumnDecisionTreeTrainerInput.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trees.trainers.columnbased;
+
+import java.util.Map;
+import java.util.stream.DoubleStream;
+import java.util.stream.IntStream;
+import java.util.stream.Stream;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.lang.IgniteBiTuple;
+
+/**
+ * Adapter of a cache keyed by {@link BiIndex} to the input of the column decision tree trainer.
+ * <p>
+ * Feature {@code j} of sample {@code i} is looked up under {@code BiIndex(i, j)}. The keys supplied by the first
+ * supplier argument are {@code BiIndex(i, featuresCnt)}, i.e. the column right after the last feature
+ * (presumably the labels column; verify against {@link CacheColumnDecisionTreeTrainerInput}).
+ */
+public class BiIndexedCacheColumnDecisionTreeTrainerInput extends CacheColumnDecisionTreeTrainerInput<BiIndex, Double> {
+    /**
+     * Construct an input for {@link ColumnDecisionTreeTrainer}.
+     *
+     * @param cache Bi-indexed cache.
+     * @param catFeaturesInfo Information about categorical feature in the form (feature index -> number of
+     * categories).
+     * @param samplesCnt Count of samples.
+     * @param featuresCnt Count of features.
+     */
+    public BiIndexedCacheColumnDecisionTreeTrainerInput(IgniteCache<BiIndex, Double> cache,
+        Map<Integer, Integer> catFeaturesInfo, int samplesCnt, int featuresCnt) {
+        super(cache,
+            () -> IntStream.range(0, samplesCnt).mapToObj(row -> new BiIndex(row, featuresCnt)),
+            entry -> Stream.of(new IgniteBiTuple<>(entry.getKey().row(), entry.getValue())),
+            DoubleStream::of,
+            col -> IntStream.range(0, samplesCnt).mapToObj(row -> new BiIndex(row, col)),
+            catFeaturesInfo,
+            featuresCnt,
+            samplesCnt);
+    }
+
+    /** {@inheritDoc} */
+    @Override public Object affinityKey(int idx, Ignite ignite) {
+        // BiIndex marks its column as the affinity key, so the index is returned as is.
+        return idx;
+    }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/CacheColumnDecisionTreeTrainerInput.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/CacheColumnDecisionTreeTrainerInput.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/CacheColumnDecisionTreeTrainerInput.java
new file mode 100644
index 0000000..9518caf
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/CacheColumnDecisionTreeTrainerInput.java
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trees.trainers.columnbased;
+
+import java.util.Map;
+import java.util.stream.Collectors;
+import java.util.stream.DoubleStream;
+import java.util.stream.Stream;
+import javax.cache.Cache;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.Ignition;
+import org.apache.ignite.internal.processors.cache.CacheEntryImpl;
+import org.apache.ignite.lang.IgniteBiTuple;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+import org.apache.ignite.ml.math.functions.IgniteSupplier;
+
+/**
+ * Adapter of a given cache to {@link ColumnDecisionTreeTrainerInput}.
+ *
+ * @param <K> Class of keys of the cache.
+ * @param <V> Class of values of the cache.
+ */
+public abstract class CacheColumnDecisionTreeTrainerInput<K, V> implements ColumnDecisionTreeTrainerInput {
+    /** Supplier of keys of the cache entries which hold labels. */
+    private final IgniteSupplier<Stream<K>> labelsKeys;
+
+    /** Count of features. */
+    private final int featuresCnt;
+
+    /** Function which maps feature index to Stream of keys corresponding to this feature index. */
+    private final IgniteFunction<Integer, Stream<K>> keyMapper;
+
+    /** Information about which features are categorical in form of feature index -> number of categories. */
+    private final Map<Integer, Integer> catFeaturesInfo;
+
+    /** Cache name. */
+    private final String cacheName;
+
+    /** Count of samples. */
+    private final int samplesCnt;
+
+    /** Function used for mapping cache entries to stream of (sample index, value) tuples. */
+    private final IgniteFunction<Cache.Entry<K, V>, Stream<IgniteBiTuple<Integer, Double>>> valuesMapper;
+
+    /**
+     * Function which maps the value of an entry with a label key to DoubleStream.
+     * See {@link #labels(Ignite)} for how {@code labelsKeys} and {@code labelsMapper} interact.
+     */
+    private final IgniteFunction<V, DoubleStream> labelsMapper;
+
+    /**
+     * Constructs input for {@link ColumnDecisionTreeTrainer}.
+     *
+     * @param c Cache.
+     * @param labelsKeys Supplier of keys of the cache entries which hold labels.
+     * @param valuesMapper Function for mapping cache entry to stream used by {@link ColumnDecisionTreeTrainer}.
+     * @param labelsMapper Function used for mapping cache value to labels array.
+     * @param keyMapper Function used for mapping feature index to the stream of cache keys of this feature.
+     * @param catFeaturesInfo Information about which features are categorical in form of feature index -> number of
+     * categories.
+     * @param featuresCnt Count of features.
+     * @param samplesCnt Count of samples.
+     */
+    // TODO: IGNITE-5724 think about boxing/unboxing
+    public CacheColumnDecisionTreeTrainerInput(IgniteCache<K, V> c,
+        IgniteSupplier<Stream<K>> labelsKeys,
+        IgniteFunction<Cache.Entry<K, V>, Stream<IgniteBiTuple<Integer, Double>>> valuesMapper,
+        IgniteFunction<V, DoubleStream> labelsMapper,
+        IgniteFunction<Integer, Stream<K>> keyMapper,
+        Map<Integer, Integer> catFeaturesInfo,
+        int featuresCnt, int samplesCnt) {
+
+        cacheName = c.getName();
+        this.labelsKeys = labelsKeys;
+        this.valuesMapper = valuesMapper;
+        this.labelsMapper = labelsMapper;
+        this.keyMapper = keyMapper;
+        this.catFeaturesInfo = catFeaturesInfo;
+        this.samplesCnt = samplesCnt;
+        this.featuresCnt = featuresCnt;
+    }
+
+    /**
+     * {@inheritDoc}
+     *
+     * Loads all keys of the feature with a single {@code getAll} on the local node's cache and maps each entry with
+     * {@code valuesMapper}.
+     */
+    @Override public Stream<IgniteBiTuple<Integer, Double>> values(int idx) {
+        return cache(Ignition.localIgnite()).getAll(keyMapper.apply(idx).collect(Collectors.toSet())).
+            entrySet().
+            stream().
+            flatMap(ent -> valuesMapper.apply(new CacheEntryImpl<>(ent.getKey(), ent.getValue())));
+    }
+
+    /**
+     * {@inheritDoc}
+     *
+     * Reads the value of every key produced by {@code labelsKeys} (local peek first, then a regular cache get), maps
+     * each value to a DoubleStream with {@code labelsMapper} and concatenates the results in key order.
+     */
+    @Override public double[] labels(Ignite ignite) {
+        return labelsKeys.get().map(k -> get(k, ignite)).flatMapToDouble(labelsMapper).toArray();
+    }
+
+    /** {@inheritDoc} */
+    @Override public Map<Integer, Integer> catFeaturesInfo() {
+        return catFeaturesInfo;
+    }
+
+    /** {@inheritDoc} */
+    @Override public int featuresCount() {
+        return featuresCnt;
+    }
+
+    /**
+     * {@inheritDoc}
+     *
+     * Uses the cache affinity key of the first key of the feature, so the result depends only on {@code idx}.
+     * Falls back to the feature index itself when the feature has no keys.
+     */
+    @Override public Object affinityKey(int idx, Ignite ignite) {
+        // keyMapper yields a *stream* of keys; the stream object itself must never be used as an affinity key
+        // (it is neither serializable nor equal across calls).
+        return keyMapper.apply(idx).
+            findFirst().
+            map(k -> ignite.affinity(cacheName).affinityKey(k)).
+            orElse(idx);
+    }
+
+    /**
+     * Get value for the given key, preferring a local peek over a (possibly remote) cache get.
+     *
+     * @param k Key.
+     * @param ignite Ignite instance.
+     * @return Value for the key.
+     */
+    private V get(K k, Ignite ignite) {
+        IgniteCache<K, V> c = cache(ignite);
+
+        V res = c.localPeek(k);
+
+        return res != null ? res : c.get(k);
+    }
+
+    /**
+     * @param ignite Ignite instance.
+     * @return Cache this input is backed by.
+     */
+    private IgniteCache<K, V> cache(Ignite ignite) {
+        return ignite.getOrCreateCache(cacheName);
+    }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/ColumnDecisionTreeTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/ColumnDecisionTreeTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/ColumnDecisionTreeTrainer.java
new file mode 100644
index 0000000..32e33f3
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/ColumnDecisionTreeTrainer.java
@@ -0,0 +1,557 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trees.trainers.columnbased;
+
+import com.zaxxer.sparsebits.SparseBitSet;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.UUID;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.function.Consumer;
+import java.util.stream.Collectors;
+import java.util.stream.DoubleStream;
+import java.util.stream.IntStream;
+import java.util.stream.Stream;
+import javax.cache.Cache;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.Ignition;
+import org.apache.ignite.cache.CachePeekMode;
+import org.apache.ignite.cache.affinity.Affinity;
+import org.apache.ignite.cluster.ClusterNode;
+import org.apache.ignite.internal.processors.cache.CacheEntryImpl;
+import org.apache.ignite.internal.util.typedef.X;
+import org.apache.ignite.lang.IgniteBiTuple;
+import org.apache.ignite.ml.Trainer;
+import org.apache.ignite.ml.math.Vector;
+import org.apache.ignite.ml.math.distributed.CacheUtils;
+import org.apache.ignite.ml.math.functions.Functions;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.functions.IgniteCurriedBiFunction;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+import org.apache.ignite.ml.math.functions.IgniteSupplier;
+import org.apache.ignite.ml.trees.ContinuousRegionInfo;
+import org.apache.ignite.ml.trees.ContinuousSplitCalculator;
+import org.apache.ignite.ml.trees.models.DecisionTreeModel;
+import org.apache.ignite.ml.trees.nodes.DecisionTreeNode;
+import org.apache.ignite.ml.trees.nodes.Leaf;
+import org.apache.ignite.ml.trees.nodes.SplitNode;
+import org.apache.ignite.ml.trees.trainers.columnbased.caches.ContextCache;
+import org.apache.ignite.ml.trees.trainers.columnbased.caches.FeaturesCache;
+import org.apache.ignite.ml.trees.trainers.columnbased.caches.FeaturesCache.FeatureKey;
+import org.apache.ignite.ml.trees.trainers.columnbased.caches.ProjectionsCache;
+import org.apache.ignite.ml.trees.trainers.columnbased.caches.ProjectionsCache.RegionKey;
+import org.apache.ignite.ml.trees.trainers.columnbased.caches.SplitCache;
+import org.apache.ignite.ml.trees.trainers.columnbased.caches.SplitCache.SplitKey;
+import org.apache.ignite.ml.trees.trainers.columnbased.vectors.FeatureProcessor;
+import org.apache.ignite.ml.trees.trainers.columnbased.vectors.SplitInfo;
+import org.jetbrains.annotations.NotNull;
+
+import static org.apache.ignite.ml.trees.trainers.columnbased.caches.FeaturesCache.getFeatureCacheKey;
+
+/**
+ * This trainer stores observations as columns and features as rows.
+ * Ideas from https://github.com/fabuzaid21/yggdrasil are used here.
+ *
+ * @param <D> Type of region information used by the continuous split calculator.
+ */
+public class ColumnDecisionTreeTrainer<D extends ContinuousRegionInfo> implements
+    Trainer<DecisionTreeModel, ColumnDecisionTreeTrainerInput> {
+    /**
+     * Function used to assign a value to a region.
+     */
+    private final IgniteFunction<DoubleStream, Double> regCalc;
+
+    /**
+     * Provider of calculator of splits for region projections on continuous features.
+     */
+    private final IgniteFunction<ColumnDecisionTreeTrainerInput, ? extends ContinuousSplitCalculator<D>> continuousCalculatorProvider;
+
+    /**
+     * Provider of calculator of splits for region projections on categorical features.
+     */
+    private final IgniteFunction<ColumnDecisionTreeTrainerInput, IgniteFunction<DoubleStream, Double>> categoricalCalculatorProvider;
+
+    /**
+     * Cache of region projections used during training (assigned in {@link #train}).
+     */
+    private IgniteCache<RegionKey, List<RegionProjection>> prjsCache;
+
+    /**
+     * Minimal information gain. Training stops when the best split does not exceed it.
+     */
+    private static final double MIN_INFO_GAIN = 1E-10;
+
+    /**
+     * Maximal depth of the decision tree.
+     */
+    private final int maxDepth;
+
+    /**
+     * Size of block which is used for storing regions in cache.
+     */
+    private static final int BLOCK_SIZE = 1 << 4;
+
+    /** Ignite instance. */
+    private final Ignite ignite;
+
+    /**
+     * Construct {@link ColumnDecisionTreeTrainer}.
+     *
+     * @param maxDepth Maximal depth of the decision tree.
+     * @param continuousCalculatorProvider Provider of calculator of splits for region projection on continuous
+     * features.
+     * @param categoricalCalculatorProvider Provider of calculator of splits for region projection on categorical
+     * features.
+     * @param regCalc Function used to assign a value to a region.
+     * @param ignite Ignite instance.
+     */
+    public ColumnDecisionTreeTrainer(int maxDepth,
+        IgniteFunction<ColumnDecisionTreeTrainerInput, ? extends ContinuousSplitCalculator<D>> continuousCalculatorProvider,
+        IgniteFunction<ColumnDecisionTreeTrainerInput, IgniteFunction<DoubleStream, Double>> categoricalCalculatorProvider,
+        IgniteFunction<DoubleStream, Double> regCalc,
+        Ignite ignite) {
+        this.maxDepth = maxDepth;
+        this.continuousCalculatorProvider = continuousCalculatorProvider;
+        this.categoricalCalculatorProvider = categoricalCalculatorProvider;
+        this.regCalc = regCalc;
+        this.ignite = ignite;
+    }
+
+    /**
+     * Utility class used to get index of feature by which split is done and split info.
+     */
+    private static class IndexAndSplitInfo {
+        /**
+         * Index of feature by which split is done.
+         */
+        private final int featureIdx;
+
+        /**
+         * Split information.
+         */
+        private final SplitInfo info;
+
+        /**
+         * @param featureIdx Index of feature by which split is done.
+         * @param info Split information.
+         */
+        IndexAndSplitInfo(int featureIdx, SplitInfo info) {
+            this.featureIdx = featureIdx;
+            this.info = info;
+        }
+
+        /** {@inheritDoc} */
+        @Override public String toString() {
+            return "IndexAndSplitInfo [featureIdx=" + featureIdx + ", info=" + info + ']';
+        }
+    }
+
+    /**
+     * Utility class used to build decision tree. Basically it is pointer to leaf node.
+     */
+    private static class TreeTip {
+        /** Setter which attaches a node in place of this leaf. */
+        private Consumer<DecisionTreeNode> leafSetter;
+
+        /** Depth of this leaf in the tree. */
+        private int depth;
+
+        /**
+         * @param leafSetter Setter which attaches a node in place of this leaf.
+         * @param depth Depth of this leaf.
+         */
+        TreeTip(Consumer<DecisionTreeNode> leafSetter, int depth) {
+            this.leafSetter = leafSetter;
+            this.depth = depth;
+        }
+    }
+
+    /**
+     * Utility class used as decision tree root node.
+     */
+    private static class RootNode implements DecisionTreeNode {
+        /** Node attached as the root (a split node, or a leaf if no split was made). */
+        private DecisionTreeNode s;
+
+        /**
+         * {@inheritDoc}
+         */
+        @Override public double process(Vector v) {
+            return s.process(v);
+        }
+
+        /** */
+        void setSplit(DecisionTreeNode s) {
+            this.s = s;
+        }
+    }
+
+    /**
+     * {@inheritDoc}
+     */
+    @Override public DecisionTreeModel train(ColumnDecisionTreeTrainerInput i) {
+        prjsCache = ProjectionsCache.getOrCreate(ignite);
+        IgniteCache<UUID, TrainingContext<D>> ctxtCache = ContextCache.getOrCreate(ignite);
+        SplitCache.getOrCreate(ignite);
+
+        UUID trainingUUID = UUID.randomUUID();
+
+        TrainingContext<D> ct = new TrainingContext<>(i, continuousCalculatorProvider.apply(i), categoricalCalculatorProvider.apply(i), trainingUUID, ignite);
+        ctxtCache.put(trainingUUID, ct);
+
+        // Each node loads the features whose initial region keys map to it and creates the initial
+        // (single-region) projection for each of them.
+        CacheUtils.bcast(prjsCache.getName(), ignite, () -> {
+            Ignite ignite = Ignition.localIgnite();
+            IgniteCache<RegionKey, List<RegionProjection>> projCache = ProjectionsCache.getOrCreate(ignite);
+            IgniteCache<FeatureKey, double[]> featuresCache = FeaturesCache.getOrCreate(ignite);
+
+            Affinity<RegionKey> targetAffinity = ignite.affinity(ProjectionsCache.CACHE_NAME);
+
+            ClusterNode locNode = ignite.cluster().localNode();
+
+            Map<FeatureKey, double[]> fm = new ConcurrentHashMap<>();
+            Map<RegionKey, List<RegionProjection>> pm = new ConcurrentHashMap<>();
+
+            targetAffinity.
+                mapKeysToNodes(IntStream.range(0, i.featuresCount()).
+                    mapToObj(idx -> ProjectionsCache.key(idx, 0, i.affinityKey(idx, ignite), trainingUUID)).
+                    collect(Collectors.toSet())).getOrDefault(locNode, Collections.emptyList()).
+                forEach(k -> {
+                    FeatureProcessor vec;
+
+                    int featureIdx = k.featureIdx();
+
+                    IgniteCache<UUID, TrainingContext<D>> ctxCache = ContextCache.getOrCreate(ignite);
+                    TrainingContext ctx = ctxCache.get(trainingUUID);
+                    double[] vals = new double[ctx.labels().length];
+
+                    vec = ctx.featureProcessor(featureIdx);
+                    i.values(featureIdx).forEach(t -> vals[t.get1()] = t.get2());
+
+                    fm.put(getFeatureCacheKey(featureIdx, trainingUUID, i.affinityKey(featureIdx, ignite)), vals);
+
+                    List<RegionProjection> newReg = new ArrayList<>(BLOCK_SIZE);
+                    newReg.add(vec.createInitialRegion(getSamples(i.values(featureIdx), ctx.labels().length), vals, ctx.labels()));
+                    pm.put(k, newReg);
+                });
+
+            featuresCache.putAll(fm);
+            projCache.putAll(pm);
+
+            return null;
+        });
+
+        return doTrain(i, trainingUUID);
+    }
+
+    /**
+     * Get samples array. Static so that closures calling it do not capture the trainer instance.
+     *
+     * @param values Stream of tuples in the form of (index, value).
+     * @param size size of stream.
+     * @return Samples array; position {@code j} holds {@code j} if sample {@code j} is present in {@code values},
+     * {@code null} otherwise.
+     */
+    private static Integer[] getSamples(Stream<IgniteBiTuple<Integer, Double>> values, int size) {
+        Integer[] res = new Integer[size];
+
+        values.forEach(v -> res[v.get1()] = v.get1());
+
+        return res;
+    }
+
+    /**
+     * Grow the tree by repeatedly applying the globally best split until no split has information gain greater than
+     * {@link #MIN_INFO_GAIN}, then assign values to leaves and clear the training caches.
+     *
+     * @param input Input for training.
+     * @param uuid UUID of the current training.
+     * @return Trained decision tree model.
+     */
+    @NotNull
+    private DecisionTreeModel doTrain(ColumnDecisionTreeTrainerInput input, UUID uuid) {
+        RootNode root = new RootNode();
+
+        // List containing setters of leaves of the tree; list index is the region index.
+        List<TreeTip> tips = new LinkedList<>();
+        tips.add(new TreeTip(root::setSplit, 0));
+
+        int curDepth = 0;
+        int regsCnt = 1;
+
+        int featuresCnt = input.featuresCount();
+        IntStream.range(0, featuresCnt).mapToObj(fIdx -> SplitCache.key(fIdx, input.affinityKey(fIdx, ignite), uuid)).
+            forEach(k -> SplitCache.getOrCreate(ignite).put(k, new IgniteBiTuple<>(0, 0.0)));
+        updateSplitCache(0, regsCnt, featuresCnt, ig -> i -> input.affinityKey(i, ig), uuid);
+
+        // TODO: IGNITE-5893 Currently if the best split makes tree deeper than max depth process will be terminated, but actually we should
+        // only stop when *any* improving split makes tree deeper than max depth. Can be fixed if we will store which
+        // regions cannot be split more and split only those that can.
+        while (true) {
+            long before = System.currentTimeMillis();
+
+            IgniteBiTuple<Integer, IgniteBiTuple<Integer, Double>> b = findBestSplitIndexForFeatures(featuresCnt, input::affinityKey, uuid);
+
+            long findBestRegIdx = System.currentTimeMillis() - before;
+
+            Integer bestFeatureIdx = b.get1();
+
+            Integer regIdx = b.get2().get1();
+            Double bestInfoGain = b.get2().get2();
+
+            if (regIdx >= 0 && bestInfoGain > MIN_INFO_GAIN) {
+                before = System.currentTimeMillis();
+
+                SplitInfo bi = ignite.compute().affinityCall(ProjectionsCache.CACHE_NAME,
+                    input.affinityKey(bestFeatureIdx, ignite),
+                    () -> {
+                        // Resolve the node-local Ignite first: the context must be read through it, not through the
+                        // trainer's field (which would also pull the trainer into this closure).
+                        Ignite ignite = Ignition.localIgnite();
+                        TrainingContext<ContinuousRegionInfo> ctx = ContextCache.getOrCreate(ignite).get(uuid);
+                        RegionKey key = ProjectionsCache.key(bestFeatureIdx,
+                            regIdx / BLOCK_SIZE,
+                            input.affinityKey(bestFeatureIdx, Ignition.localIgnite()),
+                            uuid);
+                        RegionProjection reg = ProjectionsCache.getOrCreate(ignite).localPeek(key).get(regIdx % BLOCK_SIZE);
+                        return ctx.featureProcessor(bestFeatureIdx).findBestSplit(reg, ctx.values(bestFeatureIdx, ignite), ctx.labels(), regIdx);
+                    });
+
+                long findBestSplit = System.currentTimeMillis() - before;
+
+                IndexAndSplitInfo best = new IndexAndSplitInfo(bestFeatureIdx, bi);
+
+                regsCnt++;
+
+                X.println(">>> Globally best: " + best.info + " idx time: " + findBestRegIdx + ", calculate best: " + findBestSplit + " fi: " + best.featureIdx + ", regs: " + regsCnt);
+                // Request bitset for split region.
+                int ind = best.info.regionIndex();
+
+                SparseBitSet bs = ignite.compute().affinityCall(ProjectionsCache.CACHE_NAME,
+                    input.affinityKey(bestFeatureIdx, ignite),
+                    () -> {
+                        Ignite ignite = Ignition.localIgnite();
+                        IgniteCache<FeatureKey, double[]> featuresCache = FeaturesCache.getOrCreate(ignite);
+                        IgniteCache<UUID, TrainingContext<D>> ctxCache = ContextCache.getOrCreate(ignite);
+                        TrainingContext ctx = ctxCache.localPeek(uuid);
+
+                        double[] values = featuresCache.localPeek(getFeatureCacheKey(bestFeatureIdx, uuid, input.affinityKey(bestFeatureIdx, Ignition.localIgnite())));
+                        RegionKey key = ProjectionsCache.key(bestFeatureIdx,
+                            regIdx / BLOCK_SIZE,
+                            input.affinityKey(bestFeatureIdx, Ignition.localIgnite()),
+                            uuid);
+                        RegionProjection reg = ProjectionsCache.getOrCreate(ignite).localPeek(key).get(regIdx % BLOCK_SIZE);
+                        return ctx.featureProcessor(bestFeatureIdx).calculateOwnershipBitSet(reg, values, best.info);
+                    });
+
+                SplitNode sn = best.info.createSplitNode(best.featureIdx);
+
+                TreeTip tipToSplit = tips.get(ind);
+                tipToSplit.leafSetter.accept(sn);
+                tipToSplit.leafSetter = sn::setLeft;
+                // Both children of the split node are one level deeper than the leaf that was split.
+                int d = ++tipToSplit.depth;
+                tips.add(new TreeTip(sn::setRight, d));
+
+                if (d > curDepth) {
+                    curDepth = d;
+                    X.println(">>> Depth: " + curDepth);
+                    X.println(">>> Cache size: " + prjsCache.size(CachePeekMode.PRIMARY));
+                }
+
+                before = System.currentTimeMillis();
+                // Perform split on all feature vectors.
+                IgniteSupplier<Set<RegionKey>> bestRegsKeys = () -> IntStream.range(0, featuresCnt).
+                    mapToObj(fIdx -> ProjectionsCache.key(fIdx, ind / BLOCK_SIZE, input.affinityKey(fIdx, Ignition.localIgnite()), uuid)).
+                    collect(Collectors.toSet());
+
+                int rc = regsCnt;
+
+                // Perform split: the left part replaces the split region in place, the right part is appended as
+                // region (rc - 1), possibly into a new block.
+                CacheUtils.update(prjsCache.getName(), ignite,
+                    (Ignite ign, Cache.Entry<RegionKey, List<RegionProjection>> e) -> {
+                        RegionKey k = e.getKey();
+
+                        List<RegionProjection> leftBlock = e.getValue();
+
+                        int fIdx = k.featureIdx();
+                        int idxInBlock = ind % BLOCK_SIZE;
+
+                        IgniteCache<UUID, TrainingContext<D>> ctxCache = ContextCache.getOrCreate(ign);
+                        TrainingContext<D> ctx = ctxCache.get(uuid);
+
+                        RegionProjection targetRegProj = leftBlock.get(idxInBlock);
+
+                        IgniteBiTuple<RegionProjection, RegionProjection> regs = ctx.
+                            performSplit(input, bs, fIdx, best.featureIdx, targetRegProj, best.info.leftData(), best.info.rightData(), ign);
+
+                        RegionProjection left = regs.get1();
+                        RegionProjection right = regs.get2();
+
+                        leftBlock.set(idxInBlock, left);
+                        RegionKey rightKey = ProjectionsCache.key(fIdx, (rc - 1) / BLOCK_SIZE, input.affinityKey(fIdx, ign), uuid);
+
+                        IgniteCache<RegionKey, List<RegionProjection>> c = ProjectionsCache.getOrCreate(ign);
+
+                        List<RegionProjection> rightBlock = rightKey.equals(k) ? leftBlock : c.localPeek(rightKey);
+
+                        if (rightBlock == null) {
+                            List<RegionProjection> newBlock = new ArrayList<>(BLOCK_SIZE);
+                            newBlock.add(right);
+                            return Stream.of(new CacheEntryImpl<>(k, leftBlock), new CacheEntryImpl<>(rightKey, newBlock));
+                        }
+                        else {
+                            rightBlock.add(right);
+                            // When both parts live in the same block, that block is emitted only once.
+                            return rightKey.equals(k) ?
+                                Stream.of(new CacheEntryImpl<>(k, leftBlock)) :
+                                Stream.of(new CacheEntryImpl<>(k, leftBlock), new CacheEntryImpl<>(rightKey, rightBlock));
+                        }
+                    },
+                    bestRegsKeys);
+
+                X.println(">>> Update of projs cache took " + (System.currentTimeMillis() - before));
+
+                before = System.currentTimeMillis();
+
+                updateSplitCache(ind, rc, featuresCnt, ig -> i -> input.affinityKey(i, ig), uuid);
+
+                X.println(">>> Update of split cache took " + (System.currentTimeMillis() - before));
+            }
+            else {
+                X.println(">>> Best feature index: " + bestFeatureIdx + ", best infoGain " + bestInfoGain);
+                break;
+            }
+        }
+
+        int rc = regsCnt;
+
+        // Region values are computed from the projections of feature 0 only; the split step above updates the
+        // projections of every feature, so each feature presumably holds the same regions.
+        IgniteSupplier<Iterable<Cache.Entry<RegionKey, List<RegionProjection>>>> featZeroRegs = () -> {
+            IgniteCache<RegionKey, List<RegionProjection>> projsCache = ProjectionsCache.getOrCreate(Ignition.localIgnite());
+
+            return () -> IntStream.range(0, (rc - 1) / BLOCK_SIZE + 1).
+                mapToObj(rBIdx -> ProjectionsCache.key(0, rBIdx, input.affinityKey(0, Ignition.localIgnite()), uuid)).
+                map(k -> (Cache.Entry<RegionKey, List<RegionProjection>>)new CacheEntryImpl<>(k, projsCache.localPeek(k))).iterator();
+        };
+
+        Map<Integer, Double> vals = CacheUtils.reduce(prjsCache.getName(), ignite,
+            (TrainingContext ctx, Cache.Entry<RegionKey, List<RegionProjection>> e, Map<Integer, Double> m) -> {
+                int regBlockIdx = e.getKey().regionBlockIndex();
+
+                if (e.getValue() != null) {
+                    for (int i = 0; i < e.getValue().size(); i++) {
+                        int regIdx = regBlockIdx * BLOCK_SIZE + i;
+                        RegionProjection reg = e.getValue().get(i);
+
+                        Double res = regCalc.apply(Arrays.stream(reg.sampleIndexes()).mapToDouble(s -> ctx.labels()[s]));
+                        m.put(regIdx, res);
+                    }
+                }
+
+                return m;
+            },
+            () -> ContextCache.getOrCreate(Ignition.localIgnite()).get(uuid),
+            featZeroRegs,
+            (infos, infos2) -> {
+                Map<Integer, Double> res = new HashMap<>();
+                res.putAll(infos);
+                res.putAll(infos2);
+                return res;
+            },
+            HashMap::new
+        );
+
+        int i = 0;
+        for (TreeTip tip : tips) {
+            tip.leafSetter.accept(new Leaf(vals.get(i)));
+            i++;
+        }
+
+        ProjectionsCache.clear(featuresCnt, rc, input::affinityKey, uuid, ignite);
+        ContextCache.getOrCreate(ignite).remove(uuid);
+        FeaturesCache.clear(featuresCnt, input::affinityKey, uuid, ignite);
+        SplitCache.clear(featuresCnt, input::affinityKey, uuid, ignite);
+
+        return new DecisionTreeModel(root.s);
+    }
+
+    /**
+     * Find the best split in the form (feature index, (index of region with the best split, impurity of region with the
+     * best split)).
+     *
+     * @param featuresCnt Count of features.
+     * @param affinity Affinity function.
+     * @param trainingUUID UUID of training.
+     * @return Best split in the form (feature index, (index of region with the best split, impurity of region with the
+     * best split)).
+     */
+    private IgniteBiTuple<Integer, IgniteBiTuple<Integer, Double>> findBestSplitIndexForFeatures(int featuresCnt,
+        IgniteBiFunction<Integer, Ignite, Object> affinity,
+        UUID trainingUUID) {
+        Set<Integer> featureIndexes = IntStream.range(0, featuresCnt).boxed().collect(Collectors.toSet());
+
+        return CacheUtils.reduce(SplitCache.CACHE_NAME, ignite,
+            (Object ctx, Cache.Entry<SplitKey, IgniteBiTuple<Integer, Double>> e, IgniteBiTuple<Integer, IgniteBiTuple<Integer, Double>> r) ->
+                Functions.MAX_GENERIC(new IgniteBiTuple<>(e.getKey().featureIdx(), e.getValue()), r, comparator()),
+            () -> null,
+            () -> SplitCache.localEntries(featureIndexes, affinity, trainingUUID),
+            (i1, i2) -> Functions.MAX_GENERIC(i1, i2, Comparator.comparingDouble(bt -> bt.get2().get2())),
+            () -> new IgniteBiTuple<>(-1, new IgniteBiTuple<>(-1, Double.NEGATIVE_INFINITY))
+        );
+    }
+
+    /**
+     * @return Comparator by information gain which treats {@code null} tuples as negative infinity.
+     */
+    private static Comparator<IgniteBiTuple<Integer, IgniteBiTuple<Integer, Double>>> comparator() {
+        return Comparator.comparingDouble(bt -> bt != null && bt.get2() != null ? bt.get2().get2() : Double.NEGATIVE_INFINITY);
+    }
+
+    /**
+     * Update split cache.
+     *
+     * @param lastSplitRegionIdx Index of region which had last best split.
+     * @param regsCnt Count of regions.
+     * @param featuresCnt Count of features.
+     * @param affinity Affinity function.
+     * @param trainingUUID UUID of current training.
+     */
+    private void updateSplitCache(int lastSplitRegionIdx, int regsCnt, int featuresCnt,
+        IgniteCurriedBiFunction<Ignite, Integer, Object> affinity,
+        UUID trainingUUID) {
+        CacheUtils.update(SplitCache.CACHE_NAME, ignite,
+            (Ignite ign, Cache.Entry<SplitKey, IgniteBiTuple<Integer, Double>> e) -> {
+                Integer bestRegIdx = e.getValue().get1();
+                int fIdx = e.getKey().featureIdx();
+                TrainingContext ctx = ContextCache.getOrCreate(ign).get(trainingUUID);
+
+                Map<Integer, RegionProjection> toCompare;
+
+                // Fully recalculate best.
+                if (bestRegIdx == lastSplitRegionIdx)
+                    toCompare = ProjectionsCache.projectionsOfFeature(fIdx, maxDepth, regsCnt, BLOCK_SIZE, affinity.apply(ign), trainingUUID, ign);
+                // Just compare previous best and two regions which are produced by split.
+                else
+                    toCompare = ProjectionsCache.projectionsOfRegions(fIdx, maxDepth,
+                        IntStream.of(bestRegIdx, lastSplitRegionIdx, regsCnt - 1), BLOCK_SIZE, affinity.apply(ign), trainingUUID, ign);
+
+                double[] values = ctx.values(fIdx, ign);
+                double[] labels = ctx.labels();
+
+                IgniteBiTuple<Integer, Double> max = toCompare.entrySet().stream().
+                    map(ent -> {
+                        SplitInfo bestSplit = ctx.featureProcessor(fIdx).findBestSplit(ent.getValue(), values, labels, ent.getKey());
+                        return new IgniteBiTuple<>(ent.getKey(), bestSplit != null ? bestSplit.infoGain() : Double.NEGATIVE_INFINITY);
+                    }).
+                    max(Comparator.comparingDouble(IgniteBiTuple::get2)).
+                    get();
+
+                return Stream.of(new CacheEntryImpl<>(e.getKey(), max));
+            },
+            () -> IntStream.range(0, featuresCnt).mapToObj(fIdx -> SplitCache.key(fIdx, affinity.apply(ignite).apply(fIdx), trainingUUID)).collect(Collectors.toSet())
+        );
+    }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/ColumnDecisionTreeTrainerInput.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/ColumnDecisionTreeTrainerInput.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/ColumnDecisionTreeTrainerInput.java
new file mode 100644
index 0000000..94331f7
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/ColumnDecisionTreeTrainerInput.java
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trees.trainers.columnbased;
+
+import java.util.Map;
+import java.util.stream.Stream;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.lang.IgniteBiTuple;
+
+/**
+ * Input for {@link ColumnDecisionTreeTrainer}.
+ */
+public interface ColumnDecisionTreeTrainerInput {
+    /**
+     * Projection of data on feature with the given index.
+     *
+     * @param idx Feature index.
+     * @return Projection of data on feature with the given index, as (sample index, feature value) tuples.
+     */
+    Stream<IgniteBiTuple<Integer, Double>> values(int idx);
+
+    /**
+     * Labels.
+     *
+     * @param ignite Ignite instance.
+     * @return Array of labels.
+     */
+    double[] labels(Ignite ignite);
+
+    /**
+     * @return Information about which features are categorical in the form of feature index -> number of categories.
+     */
+    Map<Integer, Integer> catFeaturesInfo();
+
+    /**
+     * @return Number of features.
+     */
+    int featuresCount();
+
+    /**
+     * Get affinity key for the given column index.
+     * The affinity key must depend only on {@code idx}: calls with the same index must return equal keys.
+     *
+     * @param idx Column (feature) index.
+     * @param ignite Ignite instance.
+     * @return Affinity key for the column.
+     */
+    Object affinityKey(int idx, Ignite ignite);
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/MatrixColumnDecisionTreeTrainerInput.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/MatrixColumnDecisionTreeTrainerInput.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/MatrixColumnDecisionTreeTrainerInput.java
new file mode 100644
index 0000000..9a11902
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/MatrixColumnDecisionTreeTrainerInput.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trees.trainers.columnbased;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.UUID;
+import java.util.stream.DoubleStream;
+import java.util.stream.IntStream;
+import java.util.stream.Stream;
+import javax.cache.Cache;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.lang.IgniteBiTuple;
+import org.apache.ignite.ml.math.distributed.keys.RowColMatrixKey;
+import org.apache.ignite.ml.math.distributed.keys.impl.SparseMatrixKey;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+import org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix;
+import org.apache.ignite.ml.math.impls.storage.matrix.SparseDistributedMatrixStorage;
+import org.jetbrains.annotations.NotNull;
+
+/**
+ * Adapter of SparseDistributedMatrix to ColumnDecisionTreeTrainerInput.
+ * Sparse SparseDistributedMatrix should be in {@see org.apache.ignite.ml.math.StorageConstants#COLUMN_STORAGE_MODE} and
+ * should contain samples in rows last position in row being label of this sample.
+ */
+public class MatrixColumnDecisionTreeTrainerInput extends CacheColumnDecisionTreeTrainerInput<RowColMatrixKey, Map<Integer, Double>> {
+ /**
+ * @param m Sparse SparseDistributedMatrix should be in {@see org.apache.ignite.ml.math.StorageConstants#COLUMN_STORAGE_MODE}
+ * containing samples in rows last position in row being label of this sample.
+ * @param catFeaturesInfo Information about which features are categorical in form of feature index -> number of
+ * categories.
+ */
+ public MatrixColumnDecisionTreeTrainerInput(SparseDistributedMatrix m, Map<Integer, Integer> catFeaturesInfo) {
+ super(((SparseDistributedMatrixStorage)m.getStorage()).cache(),
+ () -> Stream.of(new SparseMatrixKey(m.columnSize() - 1, m.getUUID(), m.columnSize() - 1)),
+ valuesMapper(m),
+ labels(m),
+ keyMapper(m),
+ catFeaturesInfo,
+ m.columnSize() - 1,
+ m.rowSize());
+ }
+
+ /** Values mapper. See {@link CacheColumnDecisionTreeTrainerInput#valuesMapper} */
+ @NotNull
+ private static IgniteFunction<Cache.Entry<RowColMatrixKey, Map<Integer, Double>>, Stream<IgniteBiTuple<Integer, Double>>> valuesMapper(
+ SparseDistributedMatrix m) {
+ return ent -> {
+ Map<Integer, Double> map = ent.getValue() != null ? ent.getValue() : new HashMap<>();
+ return IntStream.range(0, m.rowSize()).mapToObj(k -> new IgniteBiTuple<>(k, map.getOrDefault(k, 0.0)));
+ };
+ }
+
+ /** Key mapper. See {@link CacheColumnDecisionTreeTrainerInput#keyMapper} */
+ @NotNull private static IgniteFunction<Integer, Stream<RowColMatrixKey>> keyMapper(SparseDistributedMatrix m) {
+ return i -> Stream.of(new SparseMatrixKey(i, ((SparseDistributedMatrixStorage)m.getStorage()).getUUID(), i));
+ }
+
+ /** Labels mapper. See {@link CacheColumnDecisionTreeTrainerInput#labelsMapper} */
+ @NotNull private static IgniteFunction<Map<Integer, Double>, DoubleStream> labels(SparseDistributedMatrix m) {
+ return mp -> IntStream.range(0, m.rowSize()).mapToDouble(k -> mp.getOrDefault(k, 0.0));
+ }
+
+ /** {@inheritDoc} */
+ @Override public Object affinityKey(int idx, Ignite ignite) {
+ return idx;
+ }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/RegionProjection.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/RegionProjection.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/RegionProjection.java
new file mode 100644
index 0000000..e95f57b
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/RegionProjection.java
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trees.trainers.columnbased;
+
+import java.io.Externalizable;
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
+import org.apache.ignite.ml.trees.RegionInfo;
+
+/**
+ * Projection of region on given feature.
+ *
+ * @param <D> Data of region.
+ */
+public class RegionProjection<D extends RegionInfo> implements Externalizable {
+    /** Indexes of the samples belonging to this region. */
+    protected Integer[] sampleIndexes;
+
+    /** Region data. */
+    protected D data;
+
+    /** Depth of this region. */
+    protected int depth;
+
+    /**
+     * @param sampleIndexes Samples indexes.
+     * @param data Region data.
+     * @param depth Depth of this region.
+     */
+    public RegionProjection(Integer[] sampleIndexes, D data, int depth) {
+        this.data = data;
+        this.depth = depth;
+        this.sampleIndexes = sampleIndexes;
+    }
+
+    /**
+     * No-op constructor used for serialization/deserialization.
+     */
+    public RegionProjection() {
+        // No-op.
+    }
+
+    /**
+     * Get samples indexes.
+     *
+     * @return Samples indexes (the internal array, not a copy).
+     */
+    public Integer[] sampleIndexes() {
+        return sampleIndexes;
+    }
+
+    /**
+     * Get region data.
+     *
+     * @return Region data.
+     */
+    public D data() {
+        return data;
+    }
+
+    /**
+     * Get region depth.
+     *
+     * @return Region depth.
+     */
+    public int depth() {
+        return depth;
+    }
+
+    /** {@inheritDoc} */
+    @Override public void writeExternal(ObjectOutput out) throws IOException {
+        // Format: count of indexes, each index as int, region data as object, depth.
+        // Indexes are unboxed here, so sampleIndexes must be non-null and contain no null elements.
+        out.writeInt(sampleIndexes.length);
+
+        for (Integer sampleIndex : sampleIndexes)
+            out.writeInt(sampleIndex);
+
+        out.writeObject(data);
+        out.writeInt(depth);
+    }
+
+    /** {@inheritDoc} */
+    @SuppressWarnings("unchecked")
+    @Override public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
+        int size = in.readInt();
+
+        sampleIndexes = new Integer[size];
+
+        for (int i = 0; i < size; i++)
+            sampleIndexes[i] = in.readInt();
+
+        // Unchecked, but safe as long as the stream was written by writeExternal of a RegionProjection<D>.
+        data = (D)in.readObject();
+        depth = in.readInt();
+    }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/TrainingContext.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/TrainingContext.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/TrainingContext.java
new file mode 100644
index 0000000..6415dab
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/TrainingContext.java
@@ -0,0 +1,166 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trees.trainers.columnbased;
+
+import com.zaxxer.sparsebits.SparseBitSet;
+import java.util.Map;
+import java.util.UUID;
+import java.util.stream.DoubleStream;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.Ignition;
+import org.apache.ignite.lang.IgniteBiTuple;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+import org.apache.ignite.ml.trees.ContinuousRegionInfo;
+import org.apache.ignite.ml.trees.ContinuousSplitCalculator;
+import org.apache.ignite.ml.trees.RegionInfo;
+import org.apache.ignite.ml.trees.trainers.columnbased.caches.FeaturesCache;
+import org.apache.ignite.ml.trees.trainers.columnbased.vectors.CategoricalFeatureProcessor;
+import org.apache.ignite.ml.trees.trainers.columnbased.vectors.ContinuousFeatureProcessor;
+import org.apache.ignite.ml.trees.trainers.columnbased.vectors.FeatureProcessor;
+
+import static org.apache.ignite.ml.trees.trainers.columnbased.caches.FeaturesCache.COLUMN_DECISION_TREE_TRAINER_FEATURES_CACHE_NAME;
+
+/**
+ * Context of training with {@link ColumnDecisionTreeTrainer}.
+ *
+ * @param <D> Class for storing of information used in calculation of impurity of continuous feature region.
+ */
+public class TrainingContext<D extends ContinuousRegionInfo> {
+ /** Input for training with {@link ColumnDecisionTreeTrainer}. */
+ private final ColumnDecisionTreeTrainerInput input;
+
+ /** Labels. */
+ private final double[] labels;
+
+ /** Calculator used for finding splits of region of continuous features. */
+ private final ContinuousSplitCalculator<D> continuousSplitCalculator;
+
+ /** Calculator used for finding splits of region of categorical feature. */
+ private final IgniteFunction<DoubleStream, Double> categoricalSplitCalculator;
+
+ /** UUID of current training. */
+ private final UUID trainingUUID;
+
+ /**
+ * Construct context for training with {@link ColumnDecisionTreeTrainer}.
+ *
+ * @param input Input for training.
+ * @param continuousSplitCalculator Calculator used for calculations of splits of continuous features regions.
+ * @param categoricalSplitCalculator Calculator used for calculations of splits of categorical features regions.
+ * @param trainingUUID UUID of the current training.
+ * @param ignite Ignite instance.
+ */
+ public TrainingContext(ColumnDecisionTreeTrainerInput input,
+ ContinuousSplitCalculator<D> continuousSplitCalculator,
+ IgniteFunction<DoubleStream, Double> categoricalSplitCalculator,
+ UUID trainingUUID,
+ Ignite ignite) {
+ this.input = input;
+ this.labels = input.labels(ignite);
+ this.continuousSplitCalculator = continuousSplitCalculator;
+ this.categoricalSplitCalculator = categoricalSplitCalculator;
+ this.trainingUUID = trainingUUID;
+ }
+
+ /**
+ * Get processor used for calculating splits of categorical features.
+ *
+ * @param catsCnt Count of categories.
+ * @return Processor used for calculating splits of categorical features.
+ */
+ public CategoricalFeatureProcessor categoricalFeatureProcessor(int catsCnt) {
+ return new CategoricalFeatureProcessor(categoricalSplitCalculator, catsCnt);
+ }
+
+ /**
+ * Get processor used for calculating splits of continuous features.
+ *
+ * @return Processor used for calculating splits of continuous features.
+ */
+ public ContinuousFeatureProcessor<D> continuousFeatureProcessor() {
+ return new ContinuousFeatureProcessor<>(continuousSplitCalculator);
+ }
+
+ /**
+ * Get labels.
+ *
+ * @return Labels.
+ */
+ public double[] labels() {
+ return labels;
+ }
+
+ /**
+ * Get values of feature with given index.
+ *
+ * @param featIdx Feature index.
+ * @param ignite Ignite instance.
+ * @return Values of feature with given index.
+ */
+ public double[] values(int featIdx, Ignite ignite) {
+ IgniteCache<FeaturesCache.FeatureKey, double[]> featuresCache = ignite.getOrCreateCache(COLUMN_DECISION_TREE_TRAINER_FEATURES_CACHE_NAME);
+ return featuresCache.localPeek(FeaturesCache.getFeatureCacheKey(featIdx, trainingUUID, input.affinityKey(featIdx, ignite)));
+ }
+
+ /**
+ * Perform best split on the given region projection.
+ *
+ * @param input Input of {@link ColumnDecisionTreeTrainer} performing split.
+ * @param bitSet Bit set specifying split.
+ * @param targetFeatIdx Index of feature for performing split.
+ * @param bestFeatIdx Index of feature with best split.
+ * @param targetRegionPrj Projection of region to split on feature with index {@code featureIdx}.
+ * @param leftData Data of left region of split.
+ * @param rightData Data of right region of split.
+ * @param ignite Ignite instance.
+ * @return Perform best split on the given region projection.
+ */
+ public IgniteBiTuple<RegionProjection, RegionProjection> performSplit(ColumnDecisionTreeTrainerInput input,
+ SparseBitSet bitSet, int targetFeatIdx, int bestFeatIdx, RegionProjection targetRegionPrj, RegionInfo leftData,
+ RegionInfo rightData, Ignite ignite) {
+
+ Map<Integer, Integer> catFeaturesInfo = input.catFeaturesInfo();
+
+ if (!catFeaturesInfo.containsKey(targetFeatIdx) && !catFeaturesInfo.containsKey(bestFeatIdx))
+ return continuousFeatureProcessor().performSplit(bitSet, targetRegionPrj, (D)leftData, (D)rightData);
+ else if (catFeaturesInfo.containsKey(targetFeatIdx))
+ return categoricalFeatureProcessor(catFeaturesInfo.get(targetFeatIdx)).performSplitGeneric(bitSet, values(targetFeatIdx, ignite), targetRegionPrj, leftData, rightData);
+ return continuousFeatureProcessor().performSplitGeneric(bitSet, labels, targetRegionPrj, leftData, rightData);
+ }
+
+ /**
+ * Processor used for calculating splits for feature with the given index.
+ *
+ * @param featureIdx Index of feature to process.
+ * @return Processor used for calculating splits for feature with the given index.
+ */
+ public FeatureProcessor featureProcessor(int featureIdx) {
+ return input.catFeaturesInfo().containsKey(featureIdx) ? categoricalFeatureProcessor(input.catFeaturesInfo().get(featureIdx)) : continuousFeatureProcessor();
+ }
+
+ /**
+ * Shortcut for affinity key.
+ *
+ * @param idx Feature index.
+ * @return Affinity key.
+ */
+ public Object affinityKey(int idx) {
+ return input.affinityKey(idx, Ignition.localIgnite());
+ }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/db7697b1/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/caches/ContextCache.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/caches/ContextCache.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/caches/ContextCache.java
new file mode 100644
index 0000000..51ea359
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/caches/ContextCache.java
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trees.trainers.columnbased.caches;
+
+import java.util.UUID;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.cache.CacheAtomicityMode;
+import org.apache.ignite.cache.CacheMode;
+import org.apache.ignite.cache.CacheWriteSynchronizationMode;
+import org.apache.ignite.configuration.CacheConfiguration;
+import org.apache.ignite.ml.trees.ContinuousRegionInfo;
+import org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer;
+import org.apache.ignite.ml.trees.trainers.columnbased.TrainingContext;
+
+/**
+ * Helper for the cache holding {@link TrainingContext} instances of {@link ColumnDecisionTreeTrainer}.
+ */
+public class ContextCache {
+    /** Name of the cache holding training contexts of {@link ColumnDecisionTreeTrainer}. */
+    public static final String COLUMN_DECISION_TREE_TRAINER_CONTEXT_CACHE_NAME = "COLUMN_DECISION_TREE_TRAINER_CONTEXT_CACHE_NAME";
+
+    /**
+     * Returns the training context cache, creating it if it does not exist yet.
+     * The cache is replicated, atomic, fully synchronous and on-heap, with copy-on-read disabled.
+     *
+     * @param ignite Ignite instance.
+     * @param <D> Class storing information about continuous regions.
+     * @return Cache of training contexts keyed by UUID (presumably the training UUID).
+     */
+    public static <D extends ContinuousRegionInfo> IgniteCache<UUID, TrainingContext<D>> getOrCreate(Ignite ignite) {
+        CacheConfiguration<UUID, TrainingContext<D>> cfg =
+            new CacheConfiguration<>(COLUMN_DECISION_TREE_TRAINER_CONTEXT_CACHE_NAME);
+
+        // Topology: every node keeps a full copy, and reads may be served from backups.
+        cfg.setCacheMode(CacheMode.REPLICATED);
+        cfg.setReadFromBackup(true);
+
+        // Consistency: atomic updates that wait for all participating nodes.
+        cfg.setAtomicityMode(CacheAtomicityMode.ATOMIC);
+        cfg.setWriteSynchronizationMode(CacheWriteSynchronizationMode.FULL_SYNC);
+
+        // Storage: on-heap, no eviction policy, values handed out without copying.
+        cfg.setOnheapCacheEnabled(true);
+        cfg.setEvictionPolicy(null);
+        cfg.setCopyOnRead(false);
+
+        return ignite.getOrCreateCache(cfg);
+    }
+}