You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@ignite.apache.org by ag...@apache.org on 2018/04/13 09:33:28 UTC
[16/54] [abbrv] ignite git commit: IGNITE-8059: Integrate decision
tree with partition based dataset.
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/ColumnDecisionTreeTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/ColumnDecisionTreeTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/ColumnDecisionTreeTrainer.java
deleted file mode 100644
index fec0a83..0000000
--- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/ColumnDecisionTreeTrainer.java
+++ /dev/null
@@ -1,568 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.trees.trainers.columnbased;
-
-import com.zaxxer.sparsebits.SparseBitSet;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.Comparator;
-import java.util.HashMap;
-import java.util.LinkedList;
-import java.util.List;
-import java.util.Map;
-import java.util.Optional;
-import java.util.Set;
-import java.util.UUID;
-import java.util.concurrent.ConcurrentHashMap;
-import java.util.function.Consumer;
-import java.util.stream.Collectors;
-import java.util.stream.DoubleStream;
-import java.util.stream.IntStream;
-import java.util.stream.Stream;
-import javax.cache.Cache;
-import org.apache.ignite.Ignite;
-import org.apache.ignite.IgniteCache;
-import org.apache.ignite.IgniteLogger;
-import org.apache.ignite.Ignition;
-import org.apache.ignite.cache.CachePeekMode;
-import org.apache.ignite.cache.affinity.Affinity;
-import org.apache.ignite.cluster.ClusterNode;
-import org.apache.ignite.internal.processors.cache.CacheEntryImpl;
-import org.apache.ignite.lang.IgniteBiTuple;
-import org.apache.ignite.ml.Trainer;
-import org.apache.ignite.ml.math.Vector;
-import org.apache.ignite.ml.math.distributed.CacheUtils;
-import org.apache.ignite.ml.math.functions.Functions;
-import org.apache.ignite.ml.math.functions.IgniteBiFunction;
-import org.apache.ignite.ml.math.functions.IgniteCurriedBiFunction;
-import org.apache.ignite.ml.math.functions.IgniteFunction;
-import org.apache.ignite.ml.math.functions.IgniteSupplier;
-import org.apache.ignite.ml.trees.ContinuousRegionInfo;
-import org.apache.ignite.ml.trees.ContinuousSplitCalculator;
-import org.apache.ignite.ml.trees.models.DecisionTreeModel;
-import org.apache.ignite.ml.trees.nodes.DecisionTreeNode;
-import org.apache.ignite.ml.trees.nodes.Leaf;
-import org.apache.ignite.ml.trees.nodes.SplitNode;
-import org.apache.ignite.ml.trees.trainers.columnbased.caches.ContextCache;
-import org.apache.ignite.ml.trees.trainers.columnbased.caches.FeaturesCache;
-import org.apache.ignite.ml.trees.trainers.columnbased.caches.FeaturesCache.FeatureKey;
-import org.apache.ignite.ml.trees.trainers.columnbased.caches.ProjectionsCache;
-import org.apache.ignite.ml.trees.trainers.columnbased.caches.ProjectionsCache.RegionKey;
-import org.apache.ignite.ml.trees.trainers.columnbased.caches.SplitCache;
-import org.apache.ignite.ml.trees.trainers.columnbased.caches.SplitCache.SplitKey;
-import org.apache.ignite.ml.trees.trainers.columnbased.vectors.FeatureProcessor;
-import org.apache.ignite.ml.trees.trainers.columnbased.vectors.SplitInfo;
-import org.jetbrains.annotations.NotNull;
-
-import static org.apache.ignite.ml.trees.trainers.columnbased.caches.FeaturesCache.getFeatureCacheKey;
-
-/**
- * This trainer stores observations as columns and features as rows.
- * Ideas from https://github.com/fabuzaid21/yggdrasil are used here.
- */
-public class ColumnDecisionTreeTrainer<D extends ContinuousRegionInfo> implements
- Trainer<DecisionTreeModel, ColumnDecisionTreeTrainerInput> {
- /**
- * Function used to assign a value to a region.
- */
- private final IgniteFunction<DoubleStream, Double> regCalc;
-
- /**
- * Provider of calculator of splits for region projection on continuous features.
- */
- private final IgniteFunction<ColumnDecisionTreeTrainerInput, ? extends ContinuousSplitCalculator<D>> continuousCalculatorProvider;
-
- /**
- * Categorical calculator provider.
- **/
- private final IgniteFunction<ColumnDecisionTreeTrainerInput, IgniteFunction<DoubleStream, Double>> categoricalCalculatorProvider;
-
- /**
- * Cache used for storing data for training.
- */
- private IgniteCache<RegionKey, List<RegionProjection>> prjsCache;
-
- /**
- * Minimal information gain.
- */
- private static final double MIN_INFO_GAIN = 1E-10;
-
- /**
- * Maximal depth of the decision tree.
- */
- private final int maxDepth;
-
- /**
- * Size of block which is used for storing regions in cache.
- */
- private static final int BLOCK_SIZE = 1 << 4;
-
- /** Ignite instance. */
- private final Ignite ignite;
-
- /** Logger. */
- private final IgniteLogger log;
-
- /**
- * Construct {@link ColumnDecisionTreeTrainer}.
- *
- * @param maxDepth Maximal depth of the decision tree.
- * @param continuousCalculatorProvider Provider of calculator of splits for region projection on continuous
- * features.
- * @param categoricalCalculatorProvider Provider of calculator of splits for region projection on categorical
- * features.
- * @param regCalc Function used to assign a value to a region.
- */
- public ColumnDecisionTreeTrainer(int maxDepth,
- IgniteFunction<ColumnDecisionTreeTrainerInput, ? extends ContinuousSplitCalculator<D>> continuousCalculatorProvider,
- IgniteFunction<ColumnDecisionTreeTrainerInput, IgniteFunction<DoubleStream, Double>> categoricalCalculatorProvider,
- IgniteFunction<DoubleStream, Double> regCalc,
- Ignite ignite) {
- this.maxDepth = maxDepth;
- this.continuousCalculatorProvider = continuousCalculatorProvider;
- this.categoricalCalculatorProvider = categoricalCalculatorProvider;
- this.regCalc = regCalc;
- this.ignite = ignite;
- this.log = ignite.log();
- }
-
- /**
- * Utility class used to get index of feature by which split is done and split info.
- */
- private static class IndexAndSplitInfo {
- /**
- * Index of feature by which split is done.
- */
- private final int featureIdx;
-
- /**
- * Split information.
- */
- private final SplitInfo info;
-
- /**
- * @param featureIdx Index of feature by which split is done.
- * @param info Split information.
- */
- IndexAndSplitInfo(int featureIdx, SplitInfo info) {
- this.featureIdx = featureIdx;
- this.info = info;
- }
-
- /** {@inheritDoc} */
- @Override public String toString() {
- return "IndexAndSplitInfo [featureIdx=" + featureIdx + ", info=" + info + ']';
- }
- }
-
- /**
- * Utility class used to build a decision tree. Basically, it is a pointer to a leaf node.
- */
- private static class TreeTip {
- /** */
- private Consumer<DecisionTreeNode> leafSetter;
-
- /** */
- private int depth;
-
- /** */
- TreeTip(Consumer<DecisionTreeNode> leafSetter, int depth) {
- this.leafSetter = leafSetter;
- this.depth = depth;
- }
- }
-
- /**
- * Utility class used as decision tree root node.
- */
- private static class RootNode implements DecisionTreeNode {
- /** */
- private DecisionTreeNode s;
-
- /**
- * {@inheritDoc}
- */
- @Override public double process(Vector v) {
- return s.process(v);
- }
-
- /** */
- void setSplit(DecisionTreeNode s) {
- this.s = s;
- }
- }
-
- /**
- * {@inheritDoc}
- */
- @Override public DecisionTreeModel train(ColumnDecisionTreeTrainerInput i) {
- prjsCache = ProjectionsCache.getOrCreate(ignite);
- IgniteCache<UUID, TrainingContext<D>> ctxtCache = ContextCache.getOrCreate(ignite);
- SplitCache.getOrCreate(ignite);
-
- UUID trainingUUID = UUID.randomUUID();
-
- TrainingContext<D> ct = new TrainingContext<>(i, continuousCalculatorProvider.apply(i), categoricalCalculatorProvider.apply(i), trainingUUID, ignite);
- ctxtCache.put(trainingUUID, ct);
-
- CacheUtils.bcast(prjsCache.getName(), ignite, () -> {
- Ignite ignite = Ignition.localIgnite();
- IgniteCache<RegionKey, List<RegionProjection>> projCache = ProjectionsCache.getOrCreate(ignite);
- IgniteCache<FeatureKey, double[]> featuresCache = FeaturesCache.getOrCreate(ignite);
-
- Affinity<RegionKey> targetAffinity = ignite.affinity(ProjectionsCache.CACHE_NAME);
-
- ClusterNode locNode = ignite.cluster().localNode();
-
- Map<FeatureKey, double[]> fm = new ConcurrentHashMap<>();
- Map<RegionKey, List<RegionProjection>> pm = new ConcurrentHashMap<>();
-
- targetAffinity.
- mapKeysToNodes(IntStream.range(0, i.featuresCount()).
- mapToObj(idx -> ProjectionsCache.key(idx, 0, i.affinityKey(idx, ignite), trainingUUID)).
- collect(Collectors.toSet())).getOrDefault(locNode, Collections.emptyList()).
- forEach(k -> {
- FeatureProcessor vec;
-
- int featureIdx = k.featureIdx();
-
- IgniteCache<UUID, TrainingContext<D>> ctxCache = ContextCache.getOrCreate(ignite);
- TrainingContext ctx = ctxCache.get(trainingUUID);
- double[] vals = new double[ctx.labels().length];
-
- vec = ctx.featureProcessor(featureIdx);
- i.values(featureIdx).forEach(t -> vals[t.get1()] = t.get2());
-
- fm.put(getFeatureCacheKey(featureIdx, trainingUUID, i.affinityKey(featureIdx, ignite)), vals);
-
- List<RegionProjection> newReg = new ArrayList<>(BLOCK_SIZE);
- newReg.add(vec.createInitialRegion(getSamples(i.values(featureIdx), ctx.labels().length), vals, ctx.labels()));
- pm.put(k, newReg);
- });
-
- featuresCache.putAll(fm);
- projCache.putAll(pm);
-
- return null;
- });
-
- return doTrain(i, trainingUUID);
- }
-
- /**
- * Get samples array.
- *
- * @param values Stream of tuples in the form of (index, value).
- * @param size size of stream.
- * @return Samples array.
- */
- private Integer[] getSamples(Stream<IgniteBiTuple<Integer, Double>> values, int size) {
- Integer[] res = new Integer[size];
-
- values.forEach(v -> res[v.get1()] = v.get1());
-
- return res;
- }
-
- /** */
- @NotNull
- private DecisionTreeModel doTrain(ColumnDecisionTreeTrainerInput input, UUID uuid) {
- RootNode root = new RootNode();
-
- // List containing setters of leaves of the tree.
- List<TreeTip> tips = new LinkedList<>();
- tips.add(new TreeTip(root::setSplit, 0));
-
- int curDepth = 0;
- int regsCnt = 1;
-
- int featuresCnt = input.featuresCount();
- IntStream.range(0, featuresCnt).mapToObj(fIdx -> SplitCache.key(fIdx, input.affinityKey(fIdx, ignite), uuid)).
- forEach(k -> SplitCache.getOrCreate(ignite).put(k, new IgniteBiTuple<>(0, 0.0)));
- updateSplitCache(0, regsCnt, featuresCnt, ig -> i -> input.affinityKey(i, ig), uuid);
-
- // TODO: IGNITE-5893 Currently if the best split makes tree deeper than max depth process will be terminated, but actually we should
- // only stop when *any* improving split makes tree deeper than max depth. Can be fixed if we will store which
- // regions cannot be split more and split only those that can.
- while (true) {
- long before = System.currentTimeMillis();
-
- IgniteBiTuple<Integer, IgniteBiTuple<Integer, Double>> b = findBestSplitIndexForFeatures(featuresCnt, input::affinityKey, uuid);
-
- long findBestRegIdx = System.currentTimeMillis() - before;
-
- Integer bestFeatureIdx = b.get1();
-
- Integer regIdx = b.get2().get1();
- Double bestInfoGain = b.get2().get2();
-
- if (regIdx >= 0 && bestInfoGain > MIN_INFO_GAIN) {
- before = System.currentTimeMillis();
-
- SplitInfo bi = ignite.compute().affinityCall(ProjectionsCache.CACHE_NAME,
- input.affinityKey(bestFeatureIdx, ignite),
- () -> {
- TrainingContext<ContinuousRegionInfo> ctx = ContextCache.getOrCreate(ignite).get(uuid);
- Ignite ignite = Ignition.localIgnite();
- RegionKey key = ProjectionsCache.key(bestFeatureIdx,
- regIdx / BLOCK_SIZE,
- input.affinityKey(bestFeatureIdx, Ignition.localIgnite()),
- uuid);
- RegionProjection reg = ProjectionsCache.getOrCreate(ignite).localPeek(key).get(regIdx % BLOCK_SIZE);
- return ctx.featureProcessor(bestFeatureIdx).findBestSplit(reg, ctx.values(bestFeatureIdx, ignite), ctx.labels(), regIdx);
- });
-
- long findBestSplit = System.currentTimeMillis() - before;
-
- IndexAndSplitInfo best = new IndexAndSplitInfo(bestFeatureIdx, bi);
-
- regsCnt++;
-
- if (log.isDebugEnabled())
- log.debug("Globally best: " + best.info + " idx time: " + findBestRegIdx + ", calculate best: " + findBestSplit + " fi: " + best.featureIdx + ", regs: " + regsCnt);
- // Request bitset for split region.
- int ind = best.info.regionIndex();
-
- SparseBitSet bs = ignite.compute().affinityCall(ProjectionsCache.CACHE_NAME,
- input.affinityKey(bestFeatureIdx, ignite),
- () -> {
- Ignite ignite = Ignition.localIgnite();
- IgniteCache<FeatureKey, double[]> featuresCache = FeaturesCache.getOrCreate(ignite);
- IgniteCache<UUID, TrainingContext<D>> ctxCache = ContextCache.getOrCreate(ignite);
- TrainingContext ctx = ctxCache.localPeek(uuid);
-
- double[] values = featuresCache.localPeek(getFeatureCacheKey(bestFeatureIdx, uuid, input.affinityKey(bestFeatureIdx, Ignition.localIgnite())));
- RegionKey key = ProjectionsCache.key(bestFeatureIdx,
- regIdx / BLOCK_SIZE,
- input.affinityKey(bestFeatureIdx, Ignition.localIgnite()),
- uuid);
- RegionProjection reg = ProjectionsCache.getOrCreate(ignite).localPeek(key).get(regIdx % BLOCK_SIZE);
- return ctx.featureProcessor(bestFeatureIdx).calculateOwnershipBitSet(reg, values, best.info);
-
- });
-
- SplitNode sn = best.info.createSplitNode(best.featureIdx);
-
- TreeTip tipToSplit = tips.get(ind);
- tipToSplit.leafSetter.accept(sn);
- tipToSplit.leafSetter = sn::setLeft;
- int d = tipToSplit.depth++;
- tips.add(new TreeTip(sn::setRight, d));
-
- if (d > curDepth) {
- curDepth = d;
- if (log.isDebugEnabled()) {
- log.debug("Depth: " + curDepth);
- log.debug("Cache size: " + prjsCache.size(CachePeekMode.PRIMARY));
- }
- }
-
- before = System.currentTimeMillis();
- // Perform split on all feature vectors.
- IgniteSupplier<Set<RegionKey>> bestRegsKeys = () -> IntStream.range(0, featuresCnt).
- mapToObj(fIdx -> ProjectionsCache.key(fIdx, ind / BLOCK_SIZE, input.affinityKey(fIdx, Ignition.localIgnite()), uuid)).
- collect(Collectors.toSet());
-
- int rc = regsCnt;
-
- // Perform split.
- CacheUtils.update(prjsCache.getName(), ignite,
- (Ignite ign, Cache.Entry<RegionKey, List<RegionProjection>> e) -> {
- RegionKey k = e.getKey();
-
- List<RegionProjection> leftBlock = e.getValue();
-
- int fIdx = k.featureIdx();
- int idxInBlock = ind % BLOCK_SIZE;
-
- IgniteCache<UUID, TrainingContext<D>> ctxCache = ContextCache.getOrCreate(ign);
- TrainingContext<D> ctx = ctxCache.get(uuid);
-
- RegionProjection targetRegProj = leftBlock.get(idxInBlock);
-
- IgniteBiTuple<RegionProjection, RegionProjection> regs = ctx.
- performSplit(input, bs, fIdx, best.featureIdx, targetRegProj, best.info.leftData(), best.info.rightData(), ign);
-
- RegionProjection left = regs.get1();
- RegionProjection right = regs.get2();
-
- leftBlock.set(idxInBlock, left);
- RegionKey rightKey = ProjectionsCache.key(fIdx, (rc - 1) / BLOCK_SIZE, input.affinityKey(fIdx, ign), uuid);
-
- IgniteCache<RegionKey, List<RegionProjection>> c = ProjectionsCache.getOrCreate(ign);
-
- List<RegionProjection> rightBlock = rightKey.equals(k) ? leftBlock : c.localPeek(rightKey);
-
- if (rightBlock == null) {
- List<RegionProjection> newBlock = new ArrayList<>(BLOCK_SIZE);
- newBlock.add(right);
- return Stream.of(new CacheEntryImpl<>(k, leftBlock), new CacheEntryImpl<>(rightKey, newBlock));
- }
- else {
- rightBlock.add(right);
- return rightBlock.equals(k) ?
- Stream.of(new CacheEntryImpl<>(k, leftBlock)) :
- Stream.of(new CacheEntryImpl<>(k, leftBlock), new CacheEntryImpl<>(rightKey, rightBlock));
- }
- },
- bestRegsKeys);
-
- if (log.isDebugEnabled())
- log.debug("Update of projections cache time: " + (System.currentTimeMillis() - before));
-
- before = System.currentTimeMillis();
-
- updateSplitCache(ind, rc, featuresCnt, ig -> i -> input.affinityKey(i, ig), uuid);
-
- if (log.isDebugEnabled())
- log.debug("Update of split cache time: " + (System.currentTimeMillis() - before));
- }
- else {
- if (log.isDebugEnabled())
- log.debug("Best split [bestFeatureIdx=" + bestFeatureIdx + ", bestInfoGain=" + bestInfoGain + "]");
- break;
- }
- }
-
- int rc = regsCnt;
-
- IgniteSupplier<Iterable<Cache.Entry<RegionKey, List<RegionProjection>>>> featZeroRegs = () -> {
- IgniteCache<RegionKey, List<RegionProjection>> projsCache = ProjectionsCache.getOrCreate(Ignition.localIgnite());
-
- return () -> IntStream.range(0, (rc - 1) / BLOCK_SIZE + 1).
- mapToObj(rBIdx -> ProjectionsCache.key(0, rBIdx, input.affinityKey(0, Ignition.localIgnite()), uuid)).
- map(k -> (Cache.Entry<RegionKey, List<RegionProjection>>)new CacheEntryImpl<>(k, projsCache.localPeek(k))).iterator();
- };
-
- Map<Integer, Double> vals = CacheUtils.reduce(prjsCache.getName(), ignite,
- (TrainingContext ctx, Cache.Entry<RegionKey, List<RegionProjection>> e, Map<Integer, Double> m) -> {
- int regBlockIdx = e.getKey().regionBlockIndex();
-
- if (e.getValue() != null) {
- for (int i = 0; i < e.getValue().size(); i++) {
- int regIdx = regBlockIdx * BLOCK_SIZE + i;
- RegionProjection reg = e.getValue().get(i);
-
- Double res = regCalc.apply(Arrays.stream(reg.sampleIndexes()).mapToDouble(s -> ctx.labels()[s]));
- m.put(regIdx, res);
- }
- }
-
- return m;
- },
- () -> ContextCache.getOrCreate(Ignition.localIgnite()).get(uuid),
- featZeroRegs,
- (infos, infos2) -> {
- Map<Integer, Double> res = new HashMap<>();
- res.putAll(infos);
- res.putAll(infos2);
- return res;
- },
- HashMap::new
- );
-
- int i = 0;
- for (TreeTip tip : tips) {
- tip.leafSetter.accept(new Leaf(vals.get(i)));
- i++;
- }
-
- ProjectionsCache.clear(featuresCnt, rc, input::affinityKey, uuid, ignite);
- ContextCache.getOrCreate(ignite).remove(uuid);
- FeaturesCache.clear(featuresCnt, input::affinityKey, uuid, ignite);
- SplitCache.clear(featuresCnt, input::affinityKey, uuid, ignite);
-
- return new DecisionTreeModel(root.s);
- }
-
- /**
- * Find the best split in the form (feature index, (index of region with the best split, information gain of the
- * best split)).
- *
- * @param featuresCnt Count of features.
- * @param affinity Affinity function.
- * @param trainingUUID UUID of training.
- * @return Best split in the form (feature index, (index of region with the best split, information gain of the
- * best split)).
- */
- private IgniteBiTuple<Integer, IgniteBiTuple<Integer, Double>> findBestSplitIndexForFeatures(int featuresCnt,
- IgniteBiFunction<Integer, Ignite, Object> affinity,
- UUID trainingUUID) {
- Set<Integer> featureIndexes = IntStream.range(0, featuresCnt).boxed().collect(Collectors.toSet());
-
- return CacheUtils.reduce(SplitCache.CACHE_NAME, ignite,
- (Object ctx, Cache.Entry<SplitKey, IgniteBiTuple<Integer, Double>> e, IgniteBiTuple<Integer, IgniteBiTuple<Integer, Double>> r) ->
- Functions.MAX_GENERIC(new IgniteBiTuple<>(e.getKey().featureIdx(), e.getValue()), r, comparator()),
- () -> null,
- () -> SplitCache.localEntries(featureIndexes, affinity, trainingUUID),
- (i1, i2) -> Functions.MAX_GENERIC(i1, i2, Comparator.comparingDouble(bt -> bt.get2().get2())),
- () -> new IgniteBiTuple<>(-1, new IgniteBiTuple<>(-1, Double.NEGATIVE_INFINITY))
- );
- }
-
- /** */
- private static Comparator<IgniteBiTuple<Integer, IgniteBiTuple<Integer, Double>>> comparator() {
- return Comparator.comparingDouble(bt -> bt != null && bt.get2() != null ? bt.get2().get2() : Double.NEGATIVE_INFINITY);
- }
-
- /**
- * Update split cache.
- *
- * @param lastSplitRegionIdx Index of region which had last best split.
- * @param regsCnt Count of regions.
- * @param featuresCnt Count of features.
- * @param affinity Affinity function.
- * @param trainingUUID UUID of current training.
- */
- private void updateSplitCache(int lastSplitRegionIdx, int regsCnt, int featuresCnt,
- IgniteCurriedBiFunction<Ignite, Integer, Object> affinity,
- UUID trainingUUID) {
- CacheUtils.update(SplitCache.CACHE_NAME, ignite,
- (Ignite ign, Cache.Entry<SplitKey, IgniteBiTuple<Integer, Double>> e) -> {
- Integer bestRegIdx = e.getValue().get1();
- int fIdx = e.getKey().featureIdx();
- TrainingContext ctx = ContextCache.getOrCreate(ign).get(trainingUUID);
-
- Map<Integer, RegionProjection> toCompare;
-
- // Fully recalculate best.
- if (bestRegIdx == lastSplitRegionIdx)
- toCompare = ProjectionsCache.projectionsOfFeature(fIdx, maxDepth, regsCnt, BLOCK_SIZE, affinity.apply(ign), trainingUUID, ign);
- // Just compare previous best and two regions which are produced by split.
- else
- toCompare = ProjectionsCache.projectionsOfRegions(fIdx, maxDepth,
- IntStream.of(bestRegIdx, lastSplitRegionIdx, regsCnt - 1), BLOCK_SIZE, affinity.apply(ign), trainingUUID, ign);
-
- double[] values = ctx.values(fIdx, ign);
- double[] labels = ctx.labels();
-
- Optional<IgniteBiTuple<Integer, Double>> max = toCompare.entrySet().stream().
- map(ent -> {
- SplitInfo bestSplit = ctx.featureProcessor(fIdx).findBestSplit(ent.getValue(), values, labels, ent.getKey());
- return new IgniteBiTuple<>(ent.getKey(), bestSplit != null ? bestSplit.infoGain() : Double.NEGATIVE_INFINITY);
- }).
- max(Comparator.comparingDouble(IgniteBiTuple::get2));
-
- return max.<Stream<Cache.Entry<SplitKey, IgniteBiTuple<Integer, Double>>>>
- map(objects -> Stream.of(new CacheEntryImpl<>(e.getKey(), objects))).orElseGet(Stream::empty);
- },
- () -> IntStream.range(0, featuresCnt).mapToObj(fIdx -> SplitCache.key(fIdx, affinity.apply(ignite).apply(fIdx), trainingUUID)).collect(Collectors.toSet())
- );
- }
-}
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/ColumnDecisionTreeTrainerInput.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/ColumnDecisionTreeTrainerInput.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/ColumnDecisionTreeTrainerInput.java
deleted file mode 100644
index bf8790b..0000000
--- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/ColumnDecisionTreeTrainerInput.java
+++ /dev/null
@@ -1,55 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.trees.trainers.columnbased;
-
-import java.util.Map;
-import java.util.stream.Stream;
-import org.apache.ignite.Ignite;
-import org.apache.ignite.lang.IgniteBiTuple;
-
-/**
- * Input for {@link ColumnDecisionTreeTrainer}.
- */
-public interface ColumnDecisionTreeTrainerInput {
- /**
- * Projection of data on feature with the given index.
- *
- * @param idx Feature index.
- * @return Projection of data on feature with the given index.
- */
- Stream<IgniteBiTuple<Integer, Double>> values(int idx);
-
- /**
- * Labels.
- *
- * @param ignite Ignite instance.
- */
- double[] labels(Ignite ignite);
-
- /** Information about which features are categorical in the form of feature index -> number of categories. */
- Map<Integer, Integer> catFeaturesInfo();
-
- /** Number of features. */
- int featuresCount();
-
- /**
- * Get affinity key for the given column index.
- * The affinity key should be a pure function of idx.
- */
- Object affinityKey(int idx, Ignite ignite);
-}
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/MatrixColumnDecisionTreeTrainerInput.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/MatrixColumnDecisionTreeTrainerInput.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/MatrixColumnDecisionTreeTrainerInput.java
deleted file mode 100644
index 3da6bad..0000000
--- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/MatrixColumnDecisionTreeTrainerInput.java
+++ /dev/null
@@ -1,83 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.trees.trainers.columnbased;
-
-import java.util.HashMap;
-import java.util.Map;
-import java.util.stream.DoubleStream;
-import java.util.stream.IntStream;
-import java.util.stream.Stream;
-import javax.cache.Cache;
-import org.apache.ignite.Ignite;
-import org.apache.ignite.lang.IgniteBiTuple;
-import org.apache.ignite.ml.math.distributed.keys.RowColMatrixKey;
-import org.apache.ignite.ml.math.distributed.keys.impl.SparseMatrixKey;
-import org.apache.ignite.ml.math.functions.IgniteFunction;
-import org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix;
-import org.apache.ignite.ml.math.impls.storage.matrix.SparseDistributedMatrixStorage;
-import org.apache.ignite.ml.math.StorageConstants;
-import org.jetbrains.annotations.NotNull;
-
-/**
- * Adapter of SparseDistributedMatrix to ColumnDecisionTreeTrainerInput.
- * The SparseDistributedMatrix should be in {@link StorageConstants#COLUMN_STORAGE_MODE} and
- * should contain samples in rows, with the last position in each row being the label of that sample.
- */
-public class MatrixColumnDecisionTreeTrainerInput extends CacheColumnDecisionTreeTrainerInput<RowColMatrixKey, Map<Integer, Double>> {
- /**
- * @param m SparseDistributedMatrix in {@link StorageConstants#COLUMN_STORAGE_MODE}
- * containing samples in rows, with the last position in each row being the label of that sample.
- * @param catFeaturesInfo Information about which features are categorical in form of feature index -> number of
- * categories.
- */
- public MatrixColumnDecisionTreeTrainerInput(SparseDistributedMatrix m, Map<Integer, Integer> catFeaturesInfo) {
- super(((SparseDistributedMatrixStorage)m.getStorage()).cache(),
- () -> Stream.of(new SparseMatrixKey(m.columnSize() - 1, m.getUUID(), m.columnSize() - 1)),
- valuesMapper(m),
- labels(m),
- keyMapper(m),
- catFeaturesInfo,
- m.columnSize() - 1,
- m.rowSize());
- }
-
- /** Values mapper. See {@link CacheColumnDecisionTreeTrainerInput#valuesMapper} */
- @NotNull
- private static IgniteFunction<Cache.Entry<RowColMatrixKey, Map<Integer, Double>>, Stream<IgniteBiTuple<Integer, Double>>> valuesMapper(
- SparseDistributedMatrix m) {
- return ent -> {
- Map<Integer, Double> map = ent.getValue() != null ? ent.getValue() : new HashMap<>();
- return IntStream.range(0, m.rowSize()).mapToObj(k -> new IgniteBiTuple<>(k, map.getOrDefault(k, 0.0)));
- };
- }
-
- /** Key mapper. See {@link CacheColumnDecisionTreeTrainerInput#keyMapper} */
- @NotNull private static IgniteFunction<Integer, Stream<RowColMatrixKey>> keyMapper(SparseDistributedMatrix m) {
- return i -> Stream.of(new SparseMatrixKey(i, ((SparseDistributedMatrixStorage)m.getStorage()).getUUID(), i));
- }
-
- /** Labels mapper. See {@link CacheColumnDecisionTreeTrainerInput#labelsMapper} */
- @NotNull private static IgniteFunction<Map<Integer, Double>, DoubleStream> labels(SparseDistributedMatrix m) {
- return mp -> IntStream.range(0, m.rowSize()).mapToDouble(k -> mp.getOrDefault(k, 0.0));
- }
-
- /** {@inheritDoc} */
- @Override public Object affinityKey(int idx, Ignite ignite) {
- return idx;
- }
-}
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/RegionProjection.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/RegionProjection.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/RegionProjection.java
deleted file mode 100644
index e95f57b..0000000
--- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/RegionProjection.java
+++ /dev/null
@@ -1,109 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.trees.trainers.columnbased;
-
-import java.io.Externalizable;
-import java.io.IOException;
-import java.io.ObjectInput;
-import java.io.ObjectOutput;
-import org.apache.ignite.ml.trees.RegionInfo;
-
-/**
- * Projection of region on given feature.
- *
- * @param <D> Data of region.
- */
-public class RegionProjection<D extends RegionInfo> implements Externalizable {
- /** Samples projections. */
- protected Integer[] sampleIndexes;
-
- /** Region data */
- protected D data;
-
- /** Depth of this region. */
- protected int depth;
-
- /**
- * @param sampleIndexes Samples indexes.
- * @param data Region data.
- * @param depth Depth of this region.
- */
- public RegionProjection(Integer[] sampleIndexes, D data, int depth) {
- this.data = data;
- this.depth = depth;
- this.sampleIndexes = sampleIndexes;
- }
-
- /**
- * No-op constructor used for serialization/deserialization.
- */
- public RegionProjection() {
- // No-op.
- }
-
- /**
- * Get samples indexes.
- *
- * @return Samples indexes.
- */
- public Integer[] sampleIndexes() {
- return sampleIndexes;
- }
-
- /**
- * Get region data.
- *
- * @return Region data.
- */
- public D data() {
- return data;
- }
-
- /**
- * Get region depth.
- *
- * @return Region depth.
- */
- public int depth() {
- return depth;
- }
-
- /** {@inheritDoc} */
- @Override public void writeExternal(ObjectOutput out) throws IOException {
- out.writeInt(sampleIndexes.length);
-
- for (Integer sampleIndex : sampleIndexes)
- out.writeInt(sampleIndex);
-
- out.writeObject(data);
- out.writeInt(depth);
- }
-
- /** {@inheritDoc} */
- @Override public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
- int size = in.readInt();
-
- sampleIndexes = new Integer[size];
-
- for (int i = 0; i < size; i++)
- sampleIndexes[i] = in.readInt();
-
- data = (D)in.readObject();
- depth = in.readInt();
- }
-}
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/TrainingContext.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/TrainingContext.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/TrainingContext.java
deleted file mode 100644
index 6415dab..0000000
--- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/TrainingContext.java
+++ /dev/null
@@ -1,166 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.trees.trainers.columnbased;
-
-import com.zaxxer.sparsebits.SparseBitSet;
-import java.util.Map;
-import java.util.UUID;
-import java.util.stream.DoubleStream;
-import org.apache.ignite.Ignite;
-import org.apache.ignite.IgniteCache;
-import org.apache.ignite.Ignition;
-import org.apache.ignite.lang.IgniteBiTuple;
-import org.apache.ignite.ml.math.functions.IgniteFunction;
-import org.apache.ignite.ml.trees.ContinuousRegionInfo;
-import org.apache.ignite.ml.trees.ContinuousSplitCalculator;
-import org.apache.ignite.ml.trees.RegionInfo;
-import org.apache.ignite.ml.trees.trainers.columnbased.caches.FeaturesCache;
-import org.apache.ignite.ml.trees.trainers.columnbased.vectors.CategoricalFeatureProcessor;
-import org.apache.ignite.ml.trees.trainers.columnbased.vectors.ContinuousFeatureProcessor;
-import org.apache.ignite.ml.trees.trainers.columnbased.vectors.FeatureProcessor;
-
-import static org.apache.ignite.ml.trees.trainers.columnbased.caches.FeaturesCache.COLUMN_DECISION_TREE_TRAINER_FEATURES_CACHE_NAME;
-
-/**
- * Context of training with {@link ColumnDecisionTreeTrainer}.
- *
- * @param <D> Class for storing of information used in calculation of impurity of continuous feature region.
- */
-public class TrainingContext<D extends ContinuousRegionInfo> {
-    /** Trainer input; queried for labels, categorical features info and affinity keys. */
-    private final ColumnDecisionTreeTrainerInput input;
-
-    /** Labels, read from the input once when the context is constructed. */
-    private final double[] labels;
-
-    /** Split calculator handed to every {@link ContinuousFeatureProcessor} created by this context. */
-    private final ContinuousSplitCalculator<D> continuousSplitCalculator;
-
-    /** Split calculator handed to every {@link CategoricalFeatureProcessor} created by this context. */
-    private final IgniteFunction<DoubleStream, Double> categoricalSplitCalculator;
-
-    /** UUID of the training this context belongs to. */
-    private final UUID trainingUUID;
-
-    /**
-     * Creates a context for a single training with {@link ColumnDecisionTreeTrainer}.
-     *
-     * @param input Input for training.
-     * @param continuousSplitCalculator Calculator of splits for regions of continuous features.
-     * @param categoricalSplitCalculator Calculator of splits for regions of categorical features.
-     * @param trainingUUID UUID of the current training.
-     * @param ignite Ignite instance, used only to fetch the labels from the input.
-     */
-    public TrainingContext(ColumnDecisionTreeTrainerInput input,
-        ContinuousSplitCalculator<D> continuousSplitCalculator,
-        IgniteFunction<DoubleStream, Double> categoricalSplitCalculator,
-        UUID trainingUUID,
-        Ignite ignite) {
-        this.continuousSplitCalculator = continuousSplitCalculator;
-        this.categoricalSplitCalculator = categoricalSplitCalculator;
-        this.trainingUUID = trainingUUID;
-        this.input = input;
-        this.labels = input.labels(ignite);
-    }
-
-    /**
-     * Creates a new processor for a categorical feature.
-     *
-     * @param catsCnt Count of categories.
-     * @return Processor used for calculating splits of categorical features.
-     */
-    public CategoricalFeatureProcessor categoricalFeatureProcessor(int catsCnt) {
-        return new CategoricalFeatureProcessor(categoricalSplitCalculator, catsCnt);
-    }
-
-    /**
-     * Creates a new processor for a continuous feature.
-     *
-     * @return Processor used for calculating splits of continuous features.
-     */
-    public ContinuousFeatureProcessor<D> continuousFeatureProcessor() {
-        return new ContinuousFeatureProcessor<>(continuousSplitCalculator);
-    }
-
-    /**
-     * Returns the labels array held by this context (the array itself, not a copy).
-     *
-     * @return Labels.
-     */
-    public double[] labels() {
-        return labels;
-    }
-
-    /**
-     * Peeks the values of the given feature from the features cache on the local node.
-     *
-     * @param featIdx Feature index.
-     * @param ignite Ignite instance.
-     * @return Values of feature with given index, as returned by the local peek.
-     */
-    public double[] values(int featIdx, Ignite ignite) {
-        IgniteCache<FeaturesCache.FeatureKey, double[]> cache = ignite.getOrCreateCache(COLUMN_DECISION_TREE_TRAINER_FEATURES_CACHE_NAME);
-
-        FeaturesCache.FeatureKey featKey = FeaturesCache.getFeatureCacheKey(featIdx, trainingUUID, input.affinityKey(featIdx, ignite));
-
-        return cache.localPeek(featKey);
-    }
-
-    /**
-     * Applies the best split to the projection of a region on the target feature.
-     * <p>
-     * The processor is chosen from the categorical features info of the passed {@code input}: a categorical target
-     * feature is split by a categorical processor over the target feature values; a continuous target feature is split
-     * generically over labels when the best feature is categorical, and by the continuous-only split otherwise.
-     *
-     * @param input Input of {@link ColumnDecisionTreeTrainer} performing split.
-     * @param bitSet Bit set specifying split.
-     * @param targetFeatIdx Index of feature for performing split.
-     * @param bestFeatIdx Index of feature with best split.
-     * @param targetRegionPrj Projection of region to split on feature with index {@code targetFeatIdx}.
-     * @param leftData Data of left region of split.
-     * @param rightData Data of right region of split.
-     * @param ignite Ignite instance.
-     * @return Pair of region projections produced by the split.
-     */
-    public IgniteBiTuple<RegionProjection, RegionProjection> performSplit(ColumnDecisionTreeTrainerInput input,
-        SparseBitSet bitSet, int targetFeatIdx, int bestFeatIdx, RegionProjection targetRegionPrj, RegionInfo leftData,
-        RegionInfo rightData, Ignite ignite) {
-        Map<Integer, Integer> catInfo = input.catFeaturesInfo();
-
-        // Categorical target feature: split over the values of that feature.
-        if (catInfo.containsKey(targetFeatIdx))
-            return categoricalFeatureProcessor(catInfo.get(targetFeatIdx)).performSplitGeneric(bitSet, values(targetFeatIdx, ignite), targetRegionPrj, leftData, rightData);
-
-        // Continuous target feature, categorical best feature: generic split over labels.
-        if (catInfo.containsKey(bestFeatIdx))
-            return continuousFeatureProcessor().performSplitGeneric(bitSet, labels, targetRegionPrj, leftData, rightData);
-
-        // Both features are continuous.
-        return continuousFeatureProcessor().performSplit(bitSet, targetRegionPrj, (D)leftData, (D)rightData);
-    }
-
-    /**
-     * Creates the processor matching the kind (categorical or continuous) of the given feature.
-     *
-     * @param featureIdx Index of feature to process.
-     * @return Processor used for calculating splits for feature with the given index.
-     */
-    public FeatureProcessor featureProcessor(int featureIdx) {
-        if (input.catFeaturesInfo().containsKey(featureIdx))
-            return categoricalFeatureProcessor(input.catFeaturesInfo().get(featureIdx));
-
-        return continuousFeatureProcessor();
-    }
-
-    /**
-     * Shortcut for the affinity key of a feature, resolved against the local Ignite instance.
-     *
-     * @param idx Feature index.
-     * @return Affinity key.
-     */
-    public Object affinityKey(int idx) {
-        return input.affinityKey(idx, Ignition.localIgnite());
-    }
-}
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/caches/ContextCache.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/caches/ContextCache.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/caches/ContextCache.java
deleted file mode 100644
index 51ea359..0000000
--- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/caches/ContextCache.java
+++ /dev/null
@@ -1,68 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.trees.trainers.columnbased.caches;
-
-import java.util.UUID;
-import org.apache.ignite.Ignite;
-import org.apache.ignite.IgniteCache;
-import org.apache.ignite.cache.CacheAtomicityMode;
-import org.apache.ignite.cache.CacheMode;
-import org.apache.ignite.cache.CacheWriteSynchronizationMode;
-import org.apache.ignite.configuration.CacheConfiguration;
-import org.apache.ignite.ml.trees.ContinuousRegionInfo;
-import org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer;
-import org.apache.ignite.ml.trees.trainers.columnbased.TrainingContext;
-
-/**
- * Class for operations related to cache containing training context for {@link ColumnDecisionTreeTrainer}.
- */
-public class ContextCache {
-    /**
-     * Name of cache containing training context for {@link ColumnDecisionTreeTrainer}.
-     */
-    public static final String COLUMN_DECISION_TREE_TRAINER_CONTEXT_CACHE_NAME = "COLUMN_DECISION_TREE_TRAINER_CONTEXT_CACHE_NAME";
-
-    /**
-     * Get or create cache for training context, keyed by training UUID.
-     *
-     * @param ignite Ignite instance.
-     * @param <D> Class storing information about continuous regions.
-     * @return Cache for training context.
-     */
-    public static <D extends ContinuousRegionInfo> IgniteCache<UUID, TrainingContext<D>> getOrCreate(Ignite ignite) {
-        CacheConfiguration<UUID, TrainingContext<D>> ctxCfg = new CacheConfiguration<>();
-
-        ctxCfg.setName(COLUMN_DECISION_TREE_TRAINER_CONTEXT_CACHE_NAME);
-
-        // Replicated cache that may be read from backups.
-        ctxCfg.setCacheMode(CacheMode.REPLICATED);
-        ctxCfg.setReadFromBackup(true);
-
-        // Atomic cache with fully synchronous writes.
-        ctxCfg.setAtomicityMode(CacheAtomicityMode.ATOMIC);
-        ctxCfg.setWriteSynchronizationMode(CacheWriteSynchronizationMode.FULL_SYNC);
-
-        // On-heap storage, no eviction policy, no copying of values on read.
-        ctxCfg.setOnheapCacheEnabled(true);
-        ctxCfg.setEvictionPolicy(null);
-        ctxCfg.setCopyOnRead(false);
-
-        return ignite.getOrCreateCache(ctxCfg);
-    }
-}
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/caches/FeaturesCache.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/caches/FeaturesCache.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/caches/FeaturesCache.java
deleted file mode 100644
index fcc1f16..0000000
--- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/caches/FeaturesCache.java
+++ /dev/null
@@ -1,151 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.trees.trainers.columnbased.caches;
-
-import java.util.Set;
-import java.util.UUID;
-import java.util.stream.Collectors;
-import java.util.stream.IntStream;
-import org.apache.ignite.Ignite;
-import org.apache.ignite.IgniteCache;
-import org.apache.ignite.cache.CacheAtomicityMode;
-import org.apache.ignite.cache.CacheMode;
-import org.apache.ignite.cache.CacheWriteSynchronizationMode;
-import org.apache.ignite.cache.affinity.AffinityKeyMapped;
-import org.apache.ignite.configuration.CacheConfiguration;
-import org.apache.ignite.ml.math.functions.IgniteBiFunction;
-import org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer;
-
-/**
- * Cache storing features for {@link ColumnDecisionTreeTrainer}.
- */
-public class FeaturesCache {
-    /**
-     * Name of cache which is used for storing features for {@link ColumnDecisionTreeTrainer}.
-     */
-    public static final String COLUMN_DECISION_TREE_TRAINER_FEATURES_CACHE_NAME = "COLUMN_DECISION_TREE_TRAINER_FEATURES_CACHE_NAME";
-
-    /**
-     * Key of features cache. Two keys are equal when their feature index and training UUID are equal;
-     * {@code parentColKey} takes no part in {@link #equals(Object)} or {@link #hashCode()}.
-     */
-    public static class FeatureKey {
-        /** Column key of cache used as input for {@link ColumnDecisionTreeTrainer}. */
-        @AffinityKeyMapped
-        private final Object parentColKey;
-
-        /** Index of feature. */
-        private final int featureIdx;
-
-        /** UUID of training. */
-        private final UUID trainingUUID;
-
-        /**
-         * Construct FeatureKey.
-         *
-         * @param featureIdx Feature index.
-         * @param trainingUUID UUID of training.
-         * @param parentColKey Column key of cache used as input.
-         */
-        public FeatureKey(int featureIdx, UUID trainingUUID, Object parentColKey) {
-            this.featureIdx = featureIdx;
-            this.trainingUUID = trainingUUID;
-            this.parentColKey = parentColKey;
-        }
-
-        /** {@inheritDoc} */
-        @Override public boolean equals(Object o) {
-            if (this == o)
-                return true;
-            if (o == null || getClass() != o.getClass())
-                return false;
-
-            FeatureKey key = (FeatureKey)o;
-
-            if (featureIdx != key.featureIdx)
-                return false;
-            return trainingUUID != null ? trainingUUID.equals(key.trainingUUID) : key.trainingUUID == null;
-        }
-
-        /** {@inheritDoc} */
-        @Override public int hashCode() {
-            int res = trainingUUID != null ? trainingUUID.hashCode() : 0;
-            res = 31 * res + featureIdx;
-            return res;
-        }
-    }
-
-    /**
-     * Get or create the features cache for {@link ColumnDecisionTreeTrainer}.
-     *
-     * @param ignite Ignite instance.
-     * @return Features cache.
-     */
-    public static IgniteCache<FeatureKey, double[]> getOrCreate(Ignite ignite) {
-        CacheConfiguration<FeatureKey, double[]> cfg = new CacheConfiguration<>();
-
-        // Write to primary.
-        cfg.setWriteSynchronizationMode(CacheWriteSynchronizationMode.PRIMARY_SYNC);
-
-        // Atomic transactions only.
-        cfg.setAtomicityMode(CacheAtomicityMode.ATOMIC);
-
-        // No eviction.
-        cfg.setEvictionPolicy(null);
-
-        // No copying of values.
-        cfg.setCopyOnRead(false);
-
-        // Cache is partitioned.
-        cfg.setCacheMode(CacheMode.PARTITIONED);
-
-        cfg.setOnheapCacheEnabled(true);
-
-        // No backups.
-        cfg.setBackups(0);
-
-        cfg.setName(COLUMN_DECISION_TREE_TRAINER_FEATURES_CACHE_NAME);
-
-        return ignite.getOrCreateCache(cfg);
-    }
-
-    /**
-     * Construct FeatureKey from index, uuid and affinity key.
-     *
-     * @param idx Feature index.
-     * @param uuid UUID of training.
-     * @param aff Affinity key.
-     * @return FeatureKey.
-     */
-    public static FeatureKey getFeatureCacheKey(int idx, UUID uuid, Object aff) {
-        return new FeatureKey(idx, uuid, aff);
-    }
-
-    /**
-     * Remove from the features cache the entries of features {@code 0 .. featuresCnt - 1} of the given training.
-     *
-     * @param featuresCnt Count of features.
-     * @param affinity Affinity function.
-     * @param uuid Training uuid.
-     * @param ignite Ignite instance.
-     */
-    public static void clear(int featuresCnt, IgniteBiFunction<Integer, Ignite, Object> affinity, UUID uuid,
-        Ignite ignite) {
-        Set<FeatureKey> toRmv = IntStream.range(0, featuresCnt)
-            .mapToObj(fIdx -> getFeatureCacheKey(fIdx, uuid, affinity.apply(fIdx, ignite)))
-            .collect(Collectors.toSet());
-
-        getOrCreate(ignite).removeAll(toRmv);
-    }
-}
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/caches/ProjectionsCache.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/caches/ProjectionsCache.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/caches/ProjectionsCache.java
deleted file mode 100644
index 080cb66..0000000
--- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/caches/ProjectionsCache.java
+++ /dev/null
@@ -1,286 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.trees.trainers.columnbased.caches;
-
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.PrimitiveIterator;
-import java.util.Set;
-import java.util.UUID;
-import java.util.stream.Collectors;
-import java.util.stream.IntStream;
-import org.apache.ignite.Ignite;
-import org.apache.ignite.IgniteCache;
-import org.apache.ignite.Ignition;
-import org.apache.ignite.cache.CacheAtomicityMode;
-import org.apache.ignite.cache.CacheMode;
-import org.apache.ignite.cache.CacheWriteSynchronizationMode;
-import org.apache.ignite.cache.affinity.Affinity;
-import org.apache.ignite.cache.affinity.AffinityKeyMapped;
-import org.apache.ignite.configuration.CacheConfiguration;
-import org.apache.ignite.lang.IgniteBiTuple;
-import org.apache.ignite.ml.math.functions.IgniteBiFunction;
-import org.apache.ignite.ml.math.functions.IgniteFunction;
-import org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer;
-import org.apache.ignite.ml.trees.trainers.columnbased.RegionProjection;
-
-/**
- * Cache used for storing data of region projections on features.
- */
-public class ProjectionsCache {
-    /**
-     * Name of the cache holding region projections on features of {@link ColumnDecisionTreeTrainer}.
-     */
-    public static final String CACHE_NAME = "COLUMN_DECISION_TREE_TRAINER_PROJECTIONS_CACHE_NAME";
-
-    /**
-     * Key of region projections cache. Equality and hashing use the feature index, the block index and the
-     * training UUID; {@code parentColKey} is not compared.
-     */
-    public static class RegionKey {
-        /** Column key of cache used as input for {@link ColumnDecisionTreeTrainer}. */
-        @AffinityKeyMapped
-        private final Object parentColKey;
-
-        /** Feature index. */
-        private final int featureIdx;
-
-        /** Index of the block of regions. */
-        private final int regBlockIdx;
-
-        /** Training UUID. */
-        private final UUID trainingUUID;
-
-        /**
-         * Creates a key from feature index, block index, key of column in input cache and UUID of training.
-         *
-         * @param featureIdx Feature index.
-         * @param regBlockIdx Index of block.
-         * @param parentColKey Key of column in input cache.
-         * @param trainingUUID UUID of training.
-         */
-        public RegionKey(int featureIdx, int regBlockIdx, Object parentColKey, UUID trainingUUID) {
-            this.parentColKey = parentColKey;
-            this.trainingUUID = trainingUUID;
-            this.featureIdx = featureIdx;
-            this.regBlockIdx = regBlockIdx;
-        }
-
-        /**
-         * Feature index.
-         *
-         * @return Feature index.
-         */
-        public int featureIdx() {
-            return featureIdx;
-        }
-
-        /**
-         * Region block index.
-         *
-         * @return Region block index.
-         */
-        public int regionBlockIndex() {
-            return regBlockIdx;
-        }
-
-        /**
-         * UUID of training.
-         *
-         * @return UUID of training.
-         */
-        public UUID trainingUUID() {
-            return trainingUUID;
-        }
-
-        /** {@inheritDoc} */
-        @Override public boolean equals(Object o) {
-            if (o == this)
-                return true;
-
-            if (o == null || o.getClass() != getClass())
-                return false;
-
-            RegionKey that = (RegionKey)o;
-
-            if (featureIdx != that.featureIdx || regBlockIdx != that.regBlockIdx)
-                return false;
-
-            if (trainingUUID == null)
-                return that.trainingUUID == null;
-
-            return trainingUUID.equals(that.trainingUUID);
-        }
-
-        /** {@inheritDoc} */
-        @Override public int hashCode() {
-            int uuidHash = trainingUUID == null ? 0 : trainingUUID.hashCode();
-
-            return 31 * (31 * uuidHash + featureIdx) + regBlockIdx;
-        }
-
-        /** {@inheritDoc} */
-        @Override public String toString() {
-            return "RegionKey [parentColKey=" + parentColKey +
-                ", featureIdx=" + featureIdx +
-                ", regBlockIdx=" + regBlockIdx +
-                ", trainingUUID=" + trainingUUID +
-                ']';
-        }
-    }
-
-    /**
-     * Affinity service for region projections cache, taken from the local Ignite instance.
-     *
-     * @return Affinity service for region projections cache.
-     */
-    public static Affinity<RegionKey> affinity() {
-        Ignite locIgnite = Ignition.localIgnite();
-
-        return locIgnite.affinity(CACHE_NAME);
-    }
-
-    /**
-     * Get or create region projections cache.
-     *
-     * @param ignite Ignite instance.
-     * @return Region projections cache.
-     */
-    public static IgniteCache<RegionKey, List<RegionProjection>> getOrCreate(Ignite ignite) {
-        CacheConfiguration<RegionKey, List<RegionProjection>> prjCfg = new CacheConfiguration<>();
-
-        prjCfg.setName(CACHE_NAME);
-
-        // Partitioned cache without backups.
-        prjCfg.setCacheMode(CacheMode.PARTITIONED);
-        prjCfg.setBackups(0);
-
-        // Atomic cache, writes are synchronous on primary only.
-        prjCfg.setAtomicityMode(CacheAtomicityMode.ATOMIC);
-        prjCfg.setWriteSynchronizationMode(CacheWriteSynchronizationMode.PRIMARY_SYNC);
-
-        // On-heap storage, no eviction policy, no copying of values on read.
-        prjCfg.setOnheapCacheEnabled(true);
-        prjCfg.setEvictionPolicy(null);
-        prjCfg.setCopyOnRead(false);
-
-        return ignite.getOrCreateCache(prjCfg);
-    }
-
-    /**
-     * Get region projections in the form of map (regionIndex -> regionProjection), keeping only the regions whose
-     * depth is less than {@code maxDepth}. Blocks are peeked locally; a block is re-read only when the block index
-     * of the current region differs from that of the previous one.
-     *
-     * @param featureIdx Feature index.
-     * @param maxDepth Max depth of decision tree.
-     * @param regionIndexes Indexes of regions for which we want get projections.
-     * @param blockSize Size of regions block.
-     * @param affinity Affinity function.
-     * @param trainingUUID UUID of training.
-     * @param ignite Ignite instance.
-     * @return Region projections in the form of map (regionIndex -> regionProjection).
-     * @throws IllegalStateException If the local peek of a required block returns {@code null}.
-     */
-    public static Map<Integer, RegionProjection> projectionsOfRegions(int featureIdx, int maxDepth,
-        IntStream regionIndexes, int blockSize, IgniteFunction<Integer, Object> affinity, UUID trainingUUID,
-        Ignite ignite) {
-        Map<Integer, RegionProjection> found = new HashMap<>();
-        IgniteCache<RegionKey, List<RegionProjection>> prjCache = getOrCreate(ignite);
-
-        PrimitiveIterator.OfInt regIter = regionIndexes.iterator();
-
-        // No block has been loaded yet.
-        int loadedBlockIdx = -1;
-        List<RegionProjection> loadedBlock = null;
-
-        Object colKey = affinity.apply(featureIdx);
-
-        while (regIter.hasNext()) {
-            int regIdx = regIter.nextInt();
-
-            int wantedBlockIdx = regIdx / blockSize;
-
-            if (wantedBlockIdx != loadedBlockIdx) {
-                loadedBlockIdx = wantedBlockIdx;
-                loadedBlock = prjCache.localPeek(key(featureIdx, wantedBlockIdx, colKey, trainingUUID));
-            }
-
-            if (loadedBlock == null)
-                throw new IllegalStateException("Unexpected null block at index " + regIdx);
-
-            RegionProjection prj = loadedBlock.get(regIdx % blockSize);
-
-            // Regions that already reached the maximal depth are left out.
-            if (prj.depth() >= maxDepth)
-                continue;
-
-            found.put(regIdx, prj);
-        }
-
-        return found;
-    }
-
-    /**
-     * Returns projections of regions {@code 0 .. regsCnt - 1} on given feature filtered by maximal depth in the form
-     * of (region index -> region projection).
-     *
-     * @param featureIdx Feature index.
-     * @param maxDepth Maximal depth of the tree.
-     * @param regsCnt Count of regions.
-     * @param blockSize Size of regions blocks.
-     * @param affinity Affinity function.
-     * @param trainingUUID UUID of training.
-     * @param ignite Ignite instance.
-     * @return Projections of regions on given feature filtered by maximal depth in the form of (region index ->
-     * region projection).
-     */
-    public static Map<Integer, RegionProjection> projectionsOfFeature(int featureIdx, int maxDepth, int regsCnt,
-        int blockSize, IgniteFunction<Integer, Object> affinity, UUID trainingUUID, Ignite ignite) {
-        IntStream allRegs = IntStream.range(0, regsCnt);
-
-        return projectionsOfRegions(featureIdx, maxDepth, allRegs, blockSize, affinity, trainingUUID, ignite);
-    }
-
-    /**
-     * Construct key for projections cache.
-     *
-     * @param featureIdx Feature index.
-     * @param regBlockIdx Region block index.
-     * @param parentColKey Column key of cache used as input for {@link ColumnDecisionTreeTrainer}.
-     * @param uuid UUID of training.
-     * @return Key for projections cache.
-     */
-    public static RegionKey key(int featureIdx, int regBlockIdx, Object parentColKey, UUID uuid) {
-        return new RegionKey(featureIdx, regBlockIdx, parentColKey, uuid);
-    }
-
-    /**
-     * Clear data from projections cache related to given training: removes the key of every
-     * (feature index, {@code 0 .. regs - 1}) pair.
-     *
-     * @param featuresCnt Features count.
-     * @param regs Regions count.
-     * @param aff Affinity function.
-     * @param uuid UUID of training.
-     * @param ignite Ignite instance.
-     */
-    public static void clear(int featuresCnt, int regs, IgniteBiFunction<Integer, Ignite, Object> aff, UUID uuid,
-        Ignite ignite) {
-        Set<RegionKey> toRmv = IntStream.range(0, featuresCnt)
-            .boxed()
-            .flatMap(fIdx -> IntStream.range(0, regs).mapToObj(reg -> new IgniteBiTuple<Integer, Integer>(fIdx, reg)))
-            .map(pair -> key(pair.get1(), pair.get2(), aff.apply(pair.get1(), ignite), uuid))
-            .collect(Collectors.toSet());
-
-        getOrCreate(ignite).removeAll(toRmv);
-    }
-}
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/caches/SplitCache.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/caches/SplitCache.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/caches/SplitCache.java
deleted file mode 100644
index ecbc861..0000000
--- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/caches/SplitCache.java
+++ /dev/null
@@ -1,206 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.trees.trainers.columnbased.caches;
-
-import java.util.Collection;
-import java.util.Collections;
-import java.util.Set;
-import java.util.UUID;
-import java.util.function.Function;
-import java.util.stream.Collectors;
-import java.util.stream.IntStream;
-import javax.cache.Cache;
-import org.apache.ignite.Ignite;
-import org.apache.ignite.IgniteCache;
-import org.apache.ignite.Ignition;
-import org.apache.ignite.cache.CacheAtomicityMode;
-import org.apache.ignite.cache.CacheMode;
-import org.apache.ignite.cache.CacheWriteSynchronizationMode;
-import org.apache.ignite.cache.affinity.Affinity;
-import org.apache.ignite.cache.affinity.AffinityKeyMapped;
-import org.apache.ignite.configuration.CacheConfiguration;
-import org.apache.ignite.internal.processors.cache.CacheEntryImpl;
-import org.apache.ignite.lang.IgniteBiTuple;
-import org.apache.ignite.ml.math.functions.IgniteBiFunction;
-import org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer;
-
-/**
- * Class for working with cache used for storing of best splits during training with {@link ColumnDecisionTreeTrainer}.
- */
-public class SplitCache {
-    /** Name of splits cache. */
-    public static final String CACHE_NAME = "COLUMN_DECISION_TREE_TRAINER_SPLIT_CACHE_NAME";
-
-    /**
-     * Class used for keys in the splits cache. Equality and hashing use the training UUID and the feature index;
-     * {@code parentColKey} is not compared.
-     */
-    public static class SplitKey {
-        /** UUID of current training. */
-        private final UUID trainingUUID;
-
-        /** Affinity key of input data. */
-        @AffinityKeyMapped
-        private final Object parentColKey;
-
-        /** Index of feature by which the split is made. */
-        private final int featureIdx;
-
-        /**
-         * Construct SplitKey.
-         *
-         * @param trainingUUID UUID of the training.
-         * @param parentColKey Affinity key used to ensure that cache entry for given feature will be on the same node
-         * as column with that feature in input.
-         * @param featureIdx Feature index.
-         */
-        public SplitKey(UUID trainingUUID, Object parentColKey, int featureIdx) {
-            this.trainingUUID = trainingUUID;
-            this.featureIdx = featureIdx;
-            this.parentColKey = parentColKey;
-        }
-
-        /**
-         * Get UUID of current training.
-         *
-         * @return UUID of current training.
-         */
-        public UUID trainingUUID() {
-            return trainingUUID;
-        }
-
-        /**
-         * Get feature index.
-         *
-         * @return Feature index.
-         */
-        public int featureIdx() {
-            return featureIdx;
-        }
-
-        /** {@inheritDoc} */
-        @Override public boolean equals(Object o) {
-            if (this == o)
-                return true;
-            if (o == null || getClass() != o.getClass())
-                return false;
-
-            SplitKey splitKey = (SplitKey)o;
-
-            if (featureIdx != splitKey.featureIdx)
-                return false;
-            return trainingUUID != null ? trainingUUID.equals(splitKey.trainingUUID) : splitKey.trainingUUID == null;
-        }
-
-        /** {@inheritDoc} */
-        @Override public int hashCode() {
-            int res = trainingUUID != null ? trainingUUID.hashCode() : 0;
-            res = 31 * res + featureIdx;
-            return res;
-        }
-    }
-
-    /**
-     * Construct the key for splits cache.
-     *
-     * @param featureIdx Feature index.
-     * @param parentColKey Affinity key used to ensure that cache entry for given feature will be on the same node as
-     * column with that feature in input.
-     * @param uuid UUID of current training.
-     * @return Key for splits cache.
-     */
-    public static SplitKey key(int featureIdx, Object parentColKey, UUID uuid) {
-        return new SplitKey(uuid, parentColKey, featureIdx);
-    }
-
-    /**
-     * Get or create splits cache.
-     *
-     * @param ignite Ignite instance.
-     * @return Splits cache.
-     */
-    public static IgniteCache<SplitKey, IgniteBiTuple<Integer, Double>> getOrCreate(Ignite ignite) {
-        CacheConfiguration<SplitKey, IgniteBiTuple<Integer, Double>> cfg = new CacheConfiguration<>();
-
-        // Write to primary.
-        cfg.setWriteSynchronizationMode(CacheWriteSynchronizationMode.PRIMARY_SYNC);
-
-        // Atomic transactions only.
-        cfg.setAtomicityMode(CacheAtomicityMode.ATOMIC);
-
-        // No eviction.
-        cfg.setEvictionPolicy(null);
-
-        // No copying of values.
-        cfg.setCopyOnRead(false);
-
-        // Cache is partitioned.
-        cfg.setCacheMode(CacheMode.PARTITIONED);
-
-        // No backups.
-        cfg.setBackups(0);
-
-        cfg.setOnheapCacheEnabled(true);
-
-        cfg.setName(CACHE_NAME);
-
-        return ignite.getOrCreateCache(cfg);
-    }
-
-    /**
-     * Affinity function used in splits cache, taken from the local Ignite instance.
-     *
-     * @return Affinity function used in splits cache.
-     */
-    public static Affinity<SplitKey> affinity() {
-        return Ignition.localIgnite().affinity(CACHE_NAME);
-    }
-
-    /**
-     * Returns local entries for keys corresponding to {@code featureIndexes}.
-     * <p>
-     * The set of keys mapped to the local node is computed eagerly; the values are peeked lazily, while the returned
-     * iterable is being iterated. An entry value is whatever the local peek returns for its key.
-     *
-     * @param featureIndexes Index of features.
-     * @param affinity Affinity function.
-     * @param trainingUUID UUID of training.
-     * @return local entries for keys corresponding to {@code featureIndexes}.
-     */
-    public static Iterable<Cache.Entry<SplitKey, IgniteBiTuple<Integer, Double>>> localEntries(
-        Set<Integer> featureIndexes,
-        IgniteBiFunction<Integer, Ignite, Object> affinity,
-        UUID trainingUUID) {
-        Ignite ignite = Ignition.localIgnite();
-
-        Set<SplitKey> keys = featureIndexes.stream()
-            .map(fIdx -> key(fIdx, affinity.apply(fIdx, ignite), trainingUUID))
-            .collect(Collectors.toSet());
-
-        Collection<SplitKey> locKeys = affinity().mapKeysToNodes(keys).getOrDefault(ignite.cluster().localNode(), Collections.emptyList());
-
-        return () -> {
-            // Resolve the cache once per iteration pass rather than once per key.
-            IgniteCache<SplitKey, IgniteBiTuple<Integer, Double>> cache = getOrCreate(ignite);
-
-            Function<SplitKey, Cache.Entry<SplitKey, IgniteBiTuple<Integer, Double>>> toEntry =
-                k -> new CacheEntryImpl<>(k, cache.localPeek(k));
-
-            return locKeys.stream().map(toEntry).iterator();
-        };
-    }
-
-    /**
-     * Removes from the splits cache the entries of features {@code 0 .. featuresCnt - 1} of the given training.
-     *
-     * @param featuresCnt Count of features.
-     * @param affinity Affinity function.
-     * @param uuid UUID of the given training.
-     * @param ignite Ignite instance.
-     */
-    public static void clear(int featuresCnt, IgniteBiFunction<Integer, Ignite, Object> affinity, UUID uuid,
-        Ignite ignite) {
-        Set<SplitKey> toRmv = IntStream.range(0, featuresCnt)
-            .mapToObj(fIdx -> key(fIdx, affinity.apply(fIdx, ignite), uuid))
-            .collect(Collectors.toSet());
-
-        getOrCreate(ignite).removeAll(toRmv);
-    }
-}
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/caches/package-info.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/caches/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/caches/package-info.java
deleted file mode 100644
index 0a488ab..0000000
--- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/caches/package-info.java
+++ /dev/null
@@ -1,22 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-/**
- * <!-- Package description. -->
- * Contains cache configurations for columnbased decision tree trainer with some related logic.
- */
-package org.apache.ignite.ml.trees.trainers.columnbased.caches;
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/contsplitcalcs/ContinuousSplitCalculators.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/contsplitcalcs/ContinuousSplitCalculators.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/contsplitcalcs/ContinuousSplitCalculators.java
deleted file mode 100644
index 9fd4c66..0000000
--- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/contsplitcalcs/ContinuousSplitCalculators.java
+++ /dev/null
@@ -1,34 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs;
-
-import org.apache.ignite.Ignite;
-import org.apache.ignite.ml.math.functions.IgniteCurriedBiFunction;
-import org.apache.ignite.ml.math.functions.IgniteFunction;
-import org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainerInput;
-
-/**
- * Continuous Split Calculators.
- * <p>
- * Holds ready-made factory functions that create a continuous split calculator for a given trainer input.
- * The factories are constants: they are declared {@code final} so that they cannot be reassigned.
- */
-public class ContinuousSplitCalculators {
-    /** Variance split calculator factory; the trainer input is ignored. */
-    public static final IgniteFunction<ColumnDecisionTreeTrainerInput, VarianceSplitCalculator> VARIANCE = input ->
-        new VarianceSplitCalculator();
-
-    /**
-     * Gini split calculator factory. Curried: given an Ignite instance it yields a function that builds the
-     * calculator from the labels obtained via {@code input.labels(ignite)}.
-     */
-    public static final IgniteCurriedBiFunction<Ignite, ColumnDecisionTreeTrainerInput, GiniSplitCalculator> GINI =
-        ignite -> input -> new GiniSplitCalculator(input.labels(ignite));
-}
http://git-wip-us.apache.org/repos/asf/ignite/blob/139c2af6/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/contsplitcalcs/GiniSplitCalculator.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/contsplitcalcs/GiniSplitCalculator.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/contsplitcalcs/GiniSplitCalculator.java
deleted file mode 100644
index 259c84c..0000000
--- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/contsplitcalcs/GiniSplitCalculator.java
+++ /dev/null
@@ -1,234 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs;
-
-import it.unimi.dsi.fastutil.doubles.Double2IntArrayMap;
-import java.io.IOException;
-import java.io.ObjectInput;
-import java.io.ObjectOutput;
-import java.util.HashMap;
-import java.util.Map;
-import java.util.PrimitiveIterator;
-import java.util.stream.DoubleStream;
-import org.apache.ignite.ml.trees.ContinuousRegionInfo;
-import org.apache.ignite.ml.trees.ContinuousSplitCalculator;
-import org.apache.ignite.ml.trees.trainers.columnbased.vectors.ContinuousSplitInfo;
-import org.apache.ignite.ml.trees.trainers.columnbased.vectors.SplitInfo;
-
-/**
- * Calculator for Gini impurity.
- * <p>
- * Every distinct label passed to the constructor gets a dense index; per-region label counts are stored in
- * {@code int[]} arrays addressed by that index.
- */
-public class GiniSplitCalculator implements ContinuousSplitCalculator<GiniSplitCalculator.GiniData> {
-    /** Mapping assigning an index to each distinct label value, in order of first occurrence. */
-    private final Map<Double, Integer> mapping = new Double2IntArrayMap();
-
-    /**
-     * Create Gini split calculator from labels.
-     *
-     * @param labels Labels; each distinct value is assigned the next free index.
-     */
-    public GiniSplitCalculator(double[] labels) {
-        int i = 0;
-
-        for (double label : labels) {
-            if (!mapping.containsKey(label)) {
-                mapping.put(label, i);
-                i++;
-            }
-        }
-    }
-
-    /**
-     * {@inheritDoc}
-     * <p>
-     * Counts the occurrences of every label in the stream and derives the Gini impurity
-     * {@code 1 - sum(count^2) / size^2} (or {@code 0.0} for an empty stream).
-     *
-     * @param s Stream of the labels of the region.
-     * @param l Not used by this implementation.
-     */
-    @Override public GiniData calculateRegionInfo(DoubleStream s, int l) {
-        PrimitiveIterator.OfDouble itr = s.iterator();
-
-        // Number of occurrences of each label in the region.
-        Map<Double, Integer> m = new HashMap<>();
-
-        int size = 0;
-
-        while (itr.hasNext()) {
-            size++;
-            m.merge(itr.next(), 1, Integer::sum);
-        }
-
-        // Squares are taken in double arithmetic: an int product silently overflows once a count exceeds 46340.
-        double c2 = m.values().stream().mapToDouble(v -> (double)v * v).sum();
-
-        int[] cnts = new int[mapping.size()];
-
-        m.forEach((key, value) -> cnts[mapping.get(key)] = value);
-
-        return new GiniData(size != 0 ? 1 - c2 / ((double)size * size) : 0.0, size, cnts, c2);
-    }
-
-    /**
-     * {@inheritDoc}
-     * <p>
-     * Scans the samples in the order given by {@code s}, moving them one at a time from the right part to the left
-     * part, and remembers the boundary between two different feature values that yields the smallest
-     * size-weighted sum of left and right impurities.
-     *
-     * @param s Indexes into {@code values} and {@code labels} of the samples of the region. The scan only compares
-     * neighbouring values for inequality, so it presumably expects them ordered by value (TODO confirm with caller).
-     * @param values Feature values, addressed by sample index.
-     * @param labels Labels, addressed by sample index.
-     * @param regionIdx Index of the region being split.
-     * @param d Gini data of the region being split.
-     * @return Split info, or {@code null} if no boundary gave a weighted impurity smaller than the unsplit region.
-     */
-    @Override public SplitInfo<GiniData> splitRegion(Integer[] s, double[] values, double[] labels, int regionIdx,
-        GiniData d) {
-        int size = d.getSize();
-
-        double lg = 0.0;
-        double rg = d.impurity();
-
-        double lc2 = 0.0;
-        double rc2 = d.c2;
-        int lSize = 0;
-
-        double minImpurity = d.impurity() * size;
-        double curThreshold;
-        double curImpurity;
-        double threshold = Double.NEGATIVE_INFINITY;
-
-        int i = 0;
-        int nextIdx = s[0];
-        i++;
-        double[] lrImps = new double[] {0.0, d.impurity(), lc2, rc2}; // {left impurity, right impurity, left c2, right c2}
-
-        int[] lMapCur = new int[d.counts().length];
-        int[] rMapCur = new int[d.counts().length];
-
-        System.arraycopy(d.counts(), 0, rMapCur, 0, d.counts().length);
-
-        int[] lMap = new int[d.counts().length];
-        int[] rMap = new int[d.counts().length];
-
-        System.arraycopy(d.counts(), 0, rMap, 0, d.counts().length);
-
-        do {
-            // Process all values equal to prev.
-            while (i < s.length) {
-                moveLeft(labels[nextIdx], i, size - i, lMapCur, rMapCur, lrImps);
-                curImpurity = (i * lrImps[0] + (size - i) * lrImps[1]);
-                curThreshold = values[nextIdx];
-
-                if (values[nextIdx] != values[(nextIdx = s[i++])]) { // Left operand is read before nextIdx advances.
-                    if (curImpurity < minImpurity) {
-                        lSize = i - 1; // i has already been advanced past the boundary.
-
-                        lg = lrImps[0];
-                        rg = lrImps[1];
-
-                        lc2 = lrImps[2];
-                        rc2 = lrImps[3];
-
-                        System.arraycopy(lMapCur, 0, lMap, 0, lMapCur.length);
-                        System.arraycopy(rMapCur, 0, rMap, 0, rMapCur.length);
-
-                        minImpurity = curImpurity;
-                        threshold = curThreshold;
-                    }
-
-                    break;
-                }
-            }
-        }
-        while (i < s.length - 1); // NOTE(review): if the inner loop breaks with i == s.length - 1, the boundary before the last sample is never evaluated; confirm this is intended.
-
-        if (lSize == size || lSize == 0)
-            return null;
-
-        GiniData lData = new GiniData(lg, lSize, lMap, lc2);
-        int rSize = size - lSize;
-        GiniData rData = new GiniData(rg, rSize, rMap, rc2);
-
-        return new ContinuousSplitInfo<>(regionIdx, threshold, lData, rData);
-    }
-
-    /**
-     * Add point to the left interval and remove it from the right interval and calculate necessary statistics on
-     * intervals with new bounds.
-     *
-     * @param x Label of the point being moved.
-     * @param lSize Size of the left interval after the move.
-     * @param rSize Size of the right interval after the move.
-     * @param lMap Label counts of the left interval; updated in place.
-     * @param rMap Label counts of the right interval; updated in place.
-     * @param data Array {left impurity, right impurity, left c2, right c2}; updated in place.
-     */
-    private void moveLeft(double x, int lSize, int rSize, int[] lMap, int[] rMap, double[] data) {
-        double lc2 = data[2];
-        double rc2 = data[3];
-
-        Integer idx = mapping.get(x);
-
-        int cxl = lMap[idx];
-        int cxr = rMap[idx];
-
-        lc2 += 2 * cxl + 1; // (c + 1)^2 - c^2
-        rc2 -= 2 * cxr - 1; // c^2 - (c - 1)^2
-
-        lMap[idx] += 1;
-        rMap[idx] -= 1;
-
-        // Squares are taken in double arithmetic: an int product silently overflows once a size exceeds 46340.
-        data[0] = 1 - lc2 / ((double)lSize * lSize);
-        data[1] = 1 - rc2 / ((double)rSize * rSize);
-
-        data[2] = lc2;
-        data[3] = rc2;
-    }
-
-    /**
-     * Data used for gini impurity calculations.
-     */
-    public static class GiniData extends ContinuousRegionInfo {
-        /** Sum of squares of counts of each label. */
-        private double c2;
-
-        /** Counts of each label. On i-th position there is count of label which is mapped to index i. */
-        private int[] m;
-
-        /**
-         * Create Gini data.
-         *
-         * @param impurity Impurity (i.e. Gini impurity).
-         * @param size Count of samples.
-         * @param m Counts of each label.
-         * @param c2 Sum of squares of counts of each label.
-         */
-        public GiniData(double impurity, int size, int[] m, double c2) {
-            super(impurity, size);
-            this.m = m;
-            this.c2 = c2;
-        }
-
-        /**
-         * No-op constructor for serialization/deserialization.
-         */
-        public GiniData() {
-            // No-op.
-        }
-
-        /**
-         * Get counts of each label.
-         *
-         * @return The internal counts array itself (not a copy).
-         */
-        public int[] counts() {
-            return m;
-        }
-
-        /**
-         * {@inheritDoc}
-         * <p>
-         * After the parent state, writes {@code c2}, the number of counts and then every count.
-         */
-        @Override public void writeExternal(ObjectOutput out) throws IOException {
-            super.writeExternal(out);
-            out.writeDouble(c2);
-            out.writeInt(m.length);
-            for (int i : m)
-                out.writeInt(i);
-        }
-
-        /**
-         * {@inheritDoc}
-         * <p>
-         * Reads the fields in the same order in which {@code writeExternal} writes them.
-         */
-        @Override public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
-            super.readExternal(in);
-
-            c2 = in.readDouble();
-            int size = in.readInt();
-            m = new int[size];
-
-            for (int i = 0; i < size; i++)
-                m[i] = in.readInt();
-        }
-    }
-}