Posted to commits@ignite.apache.org by sb...@apache.org on 2017/12/04 07:52:40 UTC

[2/9] ignite git commit: IGNITE-7007: Decision tree code cleanup

IGNITE-7007: Decision tree code cleanup

This closes #3084


Project: http://git-wip-us.apache.org/repos/asf/ignite/repo
Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/a29fe352
Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/a29fe352
Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/a29fe352

Branch: refs/heads/ignite-zk
Commit: a29fe352de4fa3f66f471a4315fff097fe06c786
Parents: 3979e6a
Author: artemmalykh <am...@gridgain.com>
Authored: Fri Dec 1 20:54:59 2017 +0300
Committer: Yury Babak <yb...@gridgain.com>
Committed: Fri Dec 1 20:54:59 2017 +0300

----------------------------------------------------------------------
 .../ignite/ml/math/distributed/CacheUtils.java  |  2 --
 .../columnbased/ColumnDecisionTreeTrainer.java  | 33 +++++++++++++-------
 .../org/apache/ignite/ml/util/MnistUtils.java   | 17 +++++-----
 .../java/org/apache/ignite/ml/util/Utils.java   |  6 ++--
 .../ml/trees/ColumnDecisionTreeTrainerTest.java |  3 +-
 .../ColumnDecisionTreeTrainerBenchmark.java     | 31 +++++++++---------
 6 files changed, 51 insertions(+), 41 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/ignite/blob/a29fe352/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/CacheUtils.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/CacheUtils.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/CacheUtils.java
index 6baa865..9ca167c 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/CacheUtils.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/CacheUtils.java
@@ -484,9 +484,7 @@ public class CacheUtils {
                 m.put(k, v);
             }
 
-            long before = System.currentTimeMillis();
             cache.putAll(m);
-            System.out.println("PutAll took: " + (System.currentTimeMillis() - before));
         });
     }
 

http://git-wip-us.apache.org/repos/asf/ignite/blob/a29fe352/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/ColumnDecisionTreeTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/ColumnDecisionTreeTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/ColumnDecisionTreeTrainer.java
index 32e33f3..fec0a83 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/ColumnDecisionTreeTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/ColumnDecisionTreeTrainer.java
@@ -26,6 +26,7 @@ import java.util.HashMap;
 import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
+import java.util.Optional;
 import java.util.Set;
 import java.util.UUID;
 import java.util.concurrent.ConcurrentHashMap;
@@ -37,12 +38,12 @@ import java.util.stream.Stream;
 import javax.cache.Cache;
 import org.apache.ignite.Ignite;
 import org.apache.ignite.IgniteCache;
+import org.apache.ignite.IgniteLogger;
 import org.apache.ignite.Ignition;
 import org.apache.ignite.cache.CachePeekMode;
 import org.apache.ignite.cache.affinity.Affinity;
 import org.apache.ignite.cluster.ClusterNode;
 import org.apache.ignite.internal.processors.cache.CacheEntryImpl;
-import org.apache.ignite.internal.util.typedef.X;
 import org.apache.ignite.lang.IgniteBiTuple;
 import org.apache.ignite.ml.Trainer;
 import org.apache.ignite.ml.math.Vector;
@@ -115,6 +116,9 @@ public class ColumnDecisionTreeTrainer<D extends ContinuousRegionInfo> implement
     /** Ignite instance. */
     private final Ignite ignite;
 
+    /** Logger */
+    private final IgniteLogger log;
+
     /**
      * Construct {@link ColumnDecisionTreeTrainer}.
      *
@@ -135,6 +139,7 @@ public class ColumnDecisionTreeTrainer<D extends ContinuousRegionInfo> implement
         this.categoricalCalculatorProvider = categoricalCalculatorProvider;
         this.regCalc = regCalc;
         this.ignite = ignite;
+        this.log = ignite.log();
     }
 
     /**
@@ -329,7 +334,8 @@ public class ColumnDecisionTreeTrainer<D extends ContinuousRegionInfo> implement
 
                 regsCnt++;
 
-                X.println(">>> Globally best: " + best.info + " idx time: " + findBestRegIdx + ", calculate best: " + findBestSplit + " fi: " + best.featureIdx + ", regs: " + regsCnt);
+                if (log.isDebugEnabled())
+                    log.debug("Globally best: " + best.info + " idx time: " + findBestRegIdx + ", calculate best: " + findBestSplit + " fi: " + best.featureIdx + ", regs: " + regsCnt);
                 // Request bitset for split region.
                 int ind = best.info.regionIndex();
 
@@ -361,8 +367,10 @@ public class ColumnDecisionTreeTrainer<D extends ContinuousRegionInfo> implement
 
                 if (d > curDepth) {
                     curDepth = d;
-                    X.println(">>> Depth: " + curDepth);
-                    X.println(">>> Cache size: " + prjsCache.size(CachePeekMode.PRIMARY));
+                    if (log.isDebugEnabled()) {
+                        log.debug("Depth: " + curDepth);
+                        log.debug("Cache size: " + prjsCache.size(CachePeekMode.PRIMARY));
+                    }
                 }
 
                 before = System.currentTimeMillis();
@@ -415,16 +423,19 @@ public class ColumnDecisionTreeTrainer<D extends ContinuousRegionInfo> implement
                     },
                     bestRegsKeys);
 
-                X.println(">>> Update of projs cache took " + (System.currentTimeMillis() - before));
+                if (log.isDebugEnabled())
+                    log.debug("Update of projections cache time: " + (System.currentTimeMillis() - before));
 
                 before = System.currentTimeMillis();
 
                 updateSplitCache(ind, rc, featuresCnt, ig -> i -> input.affinityKey(i, ig), uuid);
 
-                X.println(">>> Update of split cache took " + (System.currentTimeMillis() - before));
+                if (log.isDebugEnabled())
+                    log.debug("Update of split cache time: " + (System.currentTimeMillis() - before));
             }
             else {
-                X.println(">>> Best feature index: " + bestFeatureIdx + ", best infoGain " + bestInfoGain);
+                if (log.isDebugEnabled())
+                    log.debug("Best split [bestFeatureIdx=" + bestFeatureIdx + ", bestInfoGain=" + bestInfoGain + "]");
                 break;
             }
         }
@@ -541,15 +552,15 @@ public class ColumnDecisionTreeTrainer<D extends ContinuousRegionInfo> implement
                 double[] values = ctx.values(fIdx, ign);
                 double[] labels = ctx.labels();
 
-                IgniteBiTuple<Integer, Double> max = toCompare.entrySet().stream().
+                Optional<IgniteBiTuple<Integer, Double>> max = toCompare.entrySet().stream().
                     map(ent -> {
                         SplitInfo bestSplit = ctx.featureProcessor(fIdx).findBestSplit(ent.getValue(), values, labels, ent.getKey());
                         return new IgniteBiTuple<>(ent.getKey(), bestSplit != null ? bestSplit.infoGain() : Double.NEGATIVE_INFINITY);
                     }).
-                    max(Comparator.comparingDouble(IgniteBiTuple::get2)).
-                    get();
+                    max(Comparator.comparingDouble(IgniteBiTuple::get2));
 
-                return Stream.of(new CacheEntryImpl<>(e.getKey(), max));
+                return max.<Stream<Cache.Entry<SplitKey, IgniteBiTuple<Integer, Double>>>>
+                    map(objects -> Stream.of(new CacheEntryImpl<>(e.getKey(), objects))).orElseGet(Stream::empty);
             },
             () -> IntStream.range(0, featuresCnt).mapToObj(fIdx -> SplitCache.key(fIdx, affinity.apply(ignite).apply(fIdx), trainingUUID)).collect(Collectors.toSet())
         );
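
The changes above swap console output (X.println) for debug-level logging guarded by IgniteLogger.isDebugEnabled(), and replace a bare get() on the Optional returned by Stream.max() with Optional-aware mapping. Below is a minimal standalone sketch of both patterns; the class name, method name and the gains map are illustrative only and not part of this commit.

    import java.util.Comparator;
    import java.util.Map;
    import java.util.Optional;
    import org.apache.ignite.Ignite;
    import org.apache.ignite.IgniteLogger;
    import org.apache.ignite.lang.IgniteBiTuple;

    /** Illustrative sketch of the logging and Optional patterns applied in ColumnDecisionTreeTrainer. */
    public class CleanupPatternsSketch {
        /** Logger obtained from the Ignite instance, as the trainer now does. */
        private final IgniteLogger log;

        /** @param ignite Ignite instance. */
        public CleanupPatternsSketch(Ignite ignite) {
            this.log = ignite.log();
        }

        /**
         * Picks the region with the largest info gain, if any.
         *
         * @param gains Map from region index to info gain (hypothetical input).
         * @return Best (region index, info gain) pair, or an empty Optional for an empty map.
         */
        public Optional<IgniteBiTuple<Integer, Double>> bestRegion(Map<Integer, Double> gains) {
            // Stream.max(...) returns an Optional; mapping over it instead of calling get()
            // avoids a NoSuchElementException when there is nothing to compare.
            Optional<IgniteBiTuple<Integer, Double>> best = gains.entrySet().stream()
                .map(e -> new IgniteBiTuple<>(e.getKey(), e.getValue()))
                .max(Comparator.comparingDouble(IgniteBiTuple::get2));

            // The isDebugEnabled() guard skips message construction when debug logging is off.
            if (log.isDebugEnabled())
                best.ifPresent(b -> log.debug("Best region [regionIdx=" + b.get1() + ", infoGain=" + b.get2() + "]"));

            return best;
        }
    }

The same guard is why each log.debug(...) call in the trainer is wrapped in an isDebugEnabled() check: the string concatenation is only paid for when debug output is actually enabled.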

http://git-wip-us.apache.org/repos/asf/ignite/blob/a29fe352/modules/ml/src/main/java/org/apache/ignite/ml/util/MnistUtils.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/util/MnistUtils.java b/modules/ml/src/main/java/org/apache/ignite/ml/util/MnistUtils.java
index d69781e..a3f1d21 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/util/MnistUtils.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/util/MnistUtils.java
@@ -25,6 +25,7 @@ import java.util.Collections;
 import java.util.List;
 import java.util.Random;
 import java.util.stream.Stream;
+import org.apache.ignite.IgniteException;
 import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
 
 /**
@@ -40,14 +41,14 @@ public class MnistUtils {
      * @param rnd Random numbers generator.
      * @param cnt Count of samples to read.
      * @return Stream of MNIST samples.
-     * @throws IOException
+     * @throws IOException In case of exception.
      */
     public static Stream<DenseLocalOnHeapVector> mnist(String imagesPath, String labelsPath, Random rnd, int cnt)
         throws IOException {
         FileInputStream isImages = new FileInputStream(imagesPath);
         FileInputStream isLabels = new FileInputStream(labelsPath);
 
-        int magic = read4Bytes(isImages); // Skip magic number.
+        read4Bytes(isImages); // Skip magic number.
         int numOfImages = read4Bytes(isImages);
         int imgHeight = read4Bytes(isImages);
         int imgWidth = read4Bytes(isImages);
@@ -57,10 +58,6 @@ public class MnistUtils {
 
         int numOfPixels = imgHeight * imgWidth;
 
-        System.out.println("Magic: " + magic);
-        System.out.println("Num of images: " + numOfImages);
-        System.out.println("Num of pixels: " + numOfPixels);
-
         double[][] vecs = new double[numOfImages][numOfPixels + 1];
 
         for (int imgNum = 0; imgNum < numOfImages; imgNum++) {
@@ -88,7 +85,7 @@ public class MnistUtils {
      * @param outPath Path to output path.
      * @param rnd Random numbers generator.
      * @param cnt Count of samples to read.
-     * @throws IOException
+     * @throws IOException In case of exception.
      */
     public static void asLIBSVM(String imagesPath, String labelsPath, String outPath, Random rnd, int cnt)
         throws IOException {
@@ -109,7 +106,7 @@ public class MnistUtils {
 
                 }
                 catch (IOException e) {
-                    e.printStackTrace();
+                    throw new IgniteException("Error while converting to LIBSVM.");
                 }
             });
         }
@@ -119,9 +116,9 @@ public class MnistUtils {
      * Utility method for reading 4 bytes from input stream.
      *
      * @param is Input stream.
-     * @throws IOException
+     * @throws IOException In case of exception.
      */
     private static int read4Bytes(FileInputStream is) throws IOException {
         return (is.read() << 24) | (is.read() << 16) | (is.read() << 8) | (is.read());
     }
-}
+}
\ No newline at end of file
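
The catch block above rethrows the checked IOException as the unchecked org.apache.ignite.IgniteException, since a checked exception cannot escape the stream lambda it sits in (note the closing "});"). A minimal sketch of the same pattern follows, with hypothetical names (LambdaIoSketch, writeAll); passing the original exception as the cause, which the committed line does not do, would also preserve the stack trace that the removed printStackTrace() call used to show.

    import java.io.IOException;
    import java.io.Writer;
    import java.util.stream.Stream;
    import org.apache.ignite.IgniteException;

    /** Illustrative sketch: surfacing a checked IOException thrown inside a lambda. */
    public class LambdaIoSketch {
        /**
         * Writes each line to the given writer.
         *
         * @param lines Lines to write.
         * @param out Target writer.
         */
        public static void writeAll(Stream<String> lines, Writer out) {
            lines.forEach(line -> {
                try {
                    out.write(line);
                    out.write('\n');
                }
                catch (IOException e) {
                    // Consumer.accept declares no checked exceptions, so rethrow as the
                    // unchecked IgniteException; passing 'e' keeps the original stack trace.
                    throw new IgniteException("Error while writing line.", e);
                }
            });
        }
    }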

http://git-wip-us.apache.org/repos/asf/ignite/blob/a29fe352/modules/ml/src/main/java/org/apache/ignite/ml/util/Utils.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/util/Utils.java b/modules/ml/src/main/java/org/apache/ignite/ml/util/Utils.java
index bb779e3..847b1f1 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/util/Utils.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/util/Utils.java
@@ -22,6 +22,7 @@ import java.io.ByteArrayOutputStream;
 import java.io.IOException;
 import java.io.ObjectInputStream;
 import java.io.ObjectOutputStream;
+import org.apache.ignite.IgniteException;
 
 /**
  * Class with various utility methods.
@@ -34,8 +35,9 @@ public class Utils {
      * @param <T> Class of original object;
      * @return Deep copy of original object.
      */
+    @SuppressWarnings({"unchecked"})
     public static <T> T copy(T orig) {
-        Object obj = null;
+        Object obj;
 
         try {
             ByteArrayOutputStream baos = new ByteArrayOutputStream();
@@ -50,7 +52,7 @@ public class Utils {
             obj = in.readObject();
         }
         catch (IOException | ClassNotFoundException e) {
-            e.printStackTrace();
+            throw new IgniteException("Couldn't copy the object.");
         }
 
         return (T)obj;
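
Utils.copy performs a deep copy through a serialize/deserialize round trip; with the change above it now fails fast with an IgniteException instead of printing the stack trace and returning null. An illustrative standalone variant is sketched below (the class name is hypothetical); it additionally passes the original exception as the cause, which the committed code does not.

    import java.io.ByteArrayInputStream;
    import java.io.ByteArrayOutputStream;
    import java.io.IOException;
    import java.io.ObjectInputStream;
    import java.io.ObjectOutputStream;
    import org.apache.ignite.IgniteException;

    /** Illustrative variant of the serialization-based deep copy in Utils. */
    public class DeepCopySketch {
        /**
         * @param orig Original (serializable) object.
         * @param <T> Type of the original object.
         * @return Deep copy produced by a serialize/deserialize round trip.
         */
        @SuppressWarnings("unchecked")
        public static <T> T copy(T orig) {
            try {
                ByteArrayOutputStream baos = new ByteArrayOutputStream();
                ObjectOutputStream out = new ObjectOutputStream(baos);

                out.writeObject(orig);
                out.flush();
                out.close();

                ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(baos.toByteArray()));

                return (T)in.readObject();
            }
            catch (IOException | ClassNotFoundException e) {
                // The cause keeps the detail the removed printStackTrace() call used to print.
                throw new IgniteException("Couldn't copy the object.", e);
            }
        }
    }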

http://git-wip-us.apache.org/repos/asf/ignite/blob/a29fe352/modules/ml/src/test/java/org/apache/ignite/ml/trees/ColumnDecisionTreeTrainerTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/trees/ColumnDecisionTreeTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/trees/ColumnDecisionTreeTrainerTest.java
index 2b03b47..929ded9 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/trees/ColumnDecisionTreeTrainerTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/trees/ColumnDecisionTreeTrainerTest.java
@@ -26,6 +26,7 @@ import java.util.Random;
 import java.util.stream.Collectors;
 import java.util.stream.DoubleStream;
 import org.apache.ignite.internal.util.IgniteUtils;
+import org.apache.ignite.internal.util.typedef.X;
 import org.apache.ignite.lang.IgniteBiTuple;
 import org.apache.ignite.ml.math.StorageConstants;
 import org.apache.ignite.ml.math.Tracer;
@@ -183,7 +184,7 @@ public class ColumnDecisionTreeTrainerTest extends BaseDecisionTreeTest {
         byRegion.keySet().forEach(k -> {
             LabeledVectorDouble sp = byRegion.get(k).get(0);
             Tracer.showAscii(sp.vector());
-            System.out.println("Act: " + sp.label() + " " + " pred: " + mdl.predict(sp.vector()));
+            X.println("Actual and predicted vectors [act=" + sp.label() + " " + ", pred=" + mdl.predict(sp.vector()) + "]");
             assert mdl.predict(sp.vector()) == sp.doubleLabel();
         });
     }

http://git-wip-us.apache.org/repos/asf/ignite/blob/a29fe352/modules/ml/src/test/java/org/apache/ignite/ml/trees/performance/ColumnDecisionTreeTrainerBenchmark.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/trees/performance/ColumnDecisionTreeTrainerBenchmark.java b/modules/ml/src/test/java/org/apache/ignite/ml/trees/performance/ColumnDecisionTreeTrainerBenchmark.java
index 4e7cc24..7ca5d38 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/trees/performance/ColumnDecisionTreeTrainerBenchmark.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/trees/performance/ColumnDecisionTreeTrainerBenchmark.java
@@ -45,6 +45,7 @@ import org.apache.ignite.configuration.CacheConfiguration;
 import org.apache.ignite.configuration.IgniteConfiguration;
 import org.apache.ignite.internal.processors.cache.GridCacheProcessor;
 import org.apache.ignite.internal.util.IgniteUtils;
+import org.apache.ignite.internal.util.typedef.X;
 import org.apache.ignite.lang.IgniteBiTuple;
 import org.apache.ignite.ml.Model;
 import org.apache.ignite.ml.estimators.Estimators;
@@ -163,14 +164,14 @@ public class ColumnDecisionTreeTrainerBenchmark extends BaseDecisionTreeTest {
         ColumnDecisionTreeTrainer<GiniSplitCalculator.GiniData> trainer =
             new ColumnDecisionTreeTrainer<>(10, ContinuousSplitCalculators.GINI.apply(ignite), RegionCalculators.GINI, RegionCalculators.MOST_COMMON, ignite);
 
-        System.out.println(">>> Training started");
+        X.println("Training started.");
         long before = System.currentTimeMillis();
         DecisionTreeModel mdl = trainer.train(new BiIndexedCacheColumnDecisionTreeTrainerInput(cache, new HashMap<>(), ptsCnt, featCnt));
-        System.out.println(">>> Training finished in " + (System.currentTimeMillis() - before));
+        X.println("Training finished in " + (System.currentTimeMillis() - before));
 
         IgniteTriFunction<Model<Vector, Double>, Stream<IgniteBiTuple<Vector, Double>>, Function<Double, Double>, Double> mse = Estimators.errorsPercentage();
         Double accuracy = mse.apply(mdl, testMnistStream.map(v -> new IgniteBiTuple<>(v.viewPart(0, featCnt), v.getX(featCnt))), Function.identity());
-        System.out.println(">>> Errs percentage: " + accuracy);
+        X.println("Errors percentage: " + accuracy);
 
         Assert.assertEquals(0, SplitCache.getOrCreate(ignite).size());
         Assert.assertEquals(0, FeaturesCache.getOrCreate(ignite).size());
@@ -204,14 +205,14 @@ public class ColumnDecisionTreeTrainerBenchmark extends BaseDecisionTreeTest {
         ColumnDecisionTreeTrainer<GiniSplitCalculator.GiniData> trainer =
             new ColumnDecisionTreeTrainer<>(10, ContinuousSplitCalculators.GINI.apply(ignite), RegionCalculators.GINI, RegionCalculators.MOST_COMMON, ignite);
 
-        System.out.println(">>> Training started");
+        X.println("Training started");
         long before = System.currentTimeMillis();
         DecisionTreeModel mdl = trainer.train(new MatrixColumnDecisionTreeTrainerInput(m, new HashMap<>()));
-        System.out.println(">>> Training finished in " + (System.currentTimeMillis() - before));
+        X.println("Training finished in " + (System.currentTimeMillis() - before));
 
         IgniteTriFunction<Model<Vector, Double>, Stream<IgniteBiTuple<Vector, Double>>, Function<Double, Double>, Double> mse = Estimators.errorsPercentage();
         Double accuracy = mse.apply(mdl, testMnistStream.map(v -> new IgniteBiTuple<>(v.viewPart(0, featCnt), v.getX(featCnt))), Function.identity());
-        System.out.println(">>> Errs percentage: " + accuracy);
+        X.println("Errors percentage: " + accuracy);
 
         Assert.assertEquals(0, SplitCache.getOrCreate(ignite).size());
         Assert.assertEquals(0, FeaturesCache.getOrCreate(ignite).size());
@@ -252,10 +253,10 @@ public class ColumnDecisionTreeTrainerBenchmark extends BaseDecisionTreeTest {
 
         SparseDistributedMatrixStorage sto = (SparseDistributedMatrixStorage)m.getStorage();
         long before = System.currentTimeMillis();
-        System.out.println(">>> Batch loading started...");
+        X.println("Batch loading started...");
         loadVectorsIntoSparseDistributedMatrixCache(sto.cache().getName(), sto.getUUID(), gen.
             points(ptsPerReg, (i, rn) -> i).map(IgniteBiTuple::get2).iterator(), featCnt + 1);
-        System.out.println(">>> Batch loading took " + (System.currentTimeMillis() - before) + " ms.");
+        X.println("Batch loading took " + (System.currentTimeMillis() - before) + " ms.");
 
         for (IgniteBiTuple<Integer, DenseLocalOnHeapVector> bt : lst) {
             byRegion.putIfAbsent(bt.get1(), new LinkedList<>());
@@ -268,12 +269,12 @@ public class ColumnDecisionTreeTrainerBenchmark extends BaseDecisionTreeTest {
         before = System.currentTimeMillis();
         DecisionTreeModel mdl = trainer.train(new MatrixColumnDecisionTreeTrainerInput(m, catsInfo));
 
-        System.out.println(">>> Took time(ms): " + (System.currentTimeMillis() - before));
+        X.println("Training took: " + (System.currentTimeMillis() - before) + " ms.");
 
         byRegion.keySet().forEach(k -> {
             LabeledVectorDouble sp = byRegion.get(k).get(0);
             Tracer.showAscii(sp.vector());
-            System.out.println("Prediction: " + mdl.predict(sp.vector()) + "label: " + sp.doubleLabel());
+            X.println("Predicted value and label [pred=" + mdl.predict(sp.vector()) + ", label=" + sp.doubleLabel() + "]");
             assert mdl.predict(sp.vector()) == sp.doubleLabel();
         });
     }
@@ -307,16 +308,16 @@ public class ColumnDecisionTreeTrainerBenchmark extends BaseDecisionTreeTest {
         ColumnDecisionTreeTrainer<VarianceSplitCalculator.VarianceData> trainer =
             new ColumnDecisionTreeTrainer<>(10, ContinuousSplitCalculators.VARIANCE, RegionCalculators.VARIANCE, regCalc, ignite);
 
-        System.out.println(">>> Training started");
+        X.println("Training started.");
         long before = System.currentTimeMillis();
         DecisionTreeModel mdl = trainer.train(new MatrixColumnDecisionTreeTrainerInput(m, new HashMap<>()));
-        System.out.println(">>> Training finished in " + (System.currentTimeMillis() - before));
+        X.println("Training finished in: " + (System.currentTimeMillis() - before) + " ms.");
 
         Vector[] testVectors = vecsFromRanges(ranges, featCnt, defRng, new Random(123L), 20, f1);
 
         IgniteTriFunction<Model<Vector, Double>, Stream<IgniteBiTuple<Vector, Double>>, Function<Double, Double>, Double> mse = Estimators.MSE();
         Double accuracy = mse.apply(mdl, Arrays.stream(testVectors).map(v -> new IgniteBiTuple<>(v.viewPart(0, featCnt), v.getX(featCnt))), Function.identity());
-        System.out.println(">>> MSE: " + accuracy);
+        X.println("MSE: " + accuracy);
     }
 
     /**
@@ -358,7 +359,7 @@ public class ColumnDecisionTreeTrainerBenchmark extends BaseDecisionTreeTest {
                 for (int i = 0; i < vectorSize; i++)
                     batch.get(i).put(sampleIdx, next.getX(i));
 
-                System.out.println(sampleIdx);
+                X.println("Sample index: " + sampleIdx);
                 if (sampleIdx % batchSize == 0) {
                     batch.keySet().forEach(fi -> streamer.addData(new SparseMatrixKey(fi, uuid, fi), batch.get(fi)));
                     IntStream.range(0, vectorSize).forEach(i -> batch.put(i, new HashMap<>()));
@@ -396,7 +397,7 @@ public class ColumnDecisionTreeTrainerBenchmark extends BaseDecisionTreeTest {
                 sampleIdx++;
 
                 if (sampleIdx % 1000 == 0)
-                    System.out.println(">>> Loaded " + sampleIdx + " vectors.");
+                    System.out.println("Loaded: " + sampleIdx + " vectors.");
             }
         }
     }
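
The benchmark methods above repeat the same elapsed-time pattern: capture System.currentTimeMillis() before an operation, then print the difference through X.println. A small helper like the one below (hypothetical, not part of this commit) captures that pattern once.

    import java.util.function.Supplier;
    import org.apache.ignite.internal.util.typedef.X;

    /** Illustrative helper for the elapsed-time reporting repeated in the benchmark. */
    public class TimedSketch {
        /**
         * Runs the given operation, prints how long it took and returns its result.
         *
         * @param name Label used in the printed message.
         * @param op Operation to time.
         * @param <T> Result type.
         * @return Result of the operation.
         */
        public static <T> T timed(String name, Supplier<T> op) {
            long before = System.currentTimeMillis();

            T res = op.get();

            X.println(name + " took " + (System.currentTimeMillis() - before) + " ms.");

            return res;
        }
    }

Usage, assuming the trainer and input from the benchmark above: DecisionTreeModel mdl = timed("Training", () -> trainer.train(input));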