You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by sr...@apache.org on 2013/03/13 00:14:30 UTC

svn commit: r1455749 - in /mahout/trunk/core/src: main/java/org/apache/mahout/classifier/df/split/RegressionSplit.java test/java/org/apache/mahout/classifier/df/split/RegressionSplitTest.java

Author: srowen
Date: Tue Mar 12 23:14:30 2013
New Revision: 1455749

URL: http://svn.apache.org/r1455749
Log:
MAHOUT-945 improve regression calculation for regression DF split

Added:
    mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/split/RegressionSplitTest.java
Modified:
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/split/RegressionSplit.java

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/split/RegressionSplit.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/split/RegressionSplit.java?rev=1455749&r1=1455748&r2=1455749&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/split/RegressionSplit.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/split/RegressionSplit.java Tue Mar 12 23:14:30 2013
@@ -17,6 +17,7 @@
 
 package org.apache.mahout.classifier.df.split;
 
+import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
 import org.apache.mahout.classifier.df.data.Data;
 import org.apache.mahout.classifier.df.data.Instance;
 
@@ -25,8 +26,8 @@ import java.util.Arrays;
 import java.util.Comparator;
 
 /**
- * Regression problem implementation of IgSplit.
- * This class can be used when the criterion variable is the numerical attribute.
+ * Regression problem implementation of IgSplit. This class can be used when the criterion variable is a numerical
+ * attribute.
  */
 public class RegressionSplit extends IgSplit {
   
@@ -59,30 +60,44 @@ public class RegressionSplit extends IgS
    * Computes the split for a CATEGORICAL attribute
    */
   private static Split categoricalSplit(Data data, int attr) {
-    double[] sums = new double[data.getDataset().nbValues(attr)];
-    double[] sumSquared = new double[data.getDataset().nbValues(attr)];
-    double[] counts = new double[data.getDataset().nbValues(attr)];
-    double totalSum = 0;
-    double totalSumSquared = 0;
+    FullRunningAverage[] ra = new FullRunningAverage[data.getDataset().nbValues(attr)];
+    double[] sk = new double[data.getDataset().nbValues(attr)];
+    for (int i = 0; i < ra.length; i++) {
+      ra[i] = new FullRunningAverage();
+    }
+    FullRunningAverage totalRa = new FullRunningAverage();
+    double totalSk = 0.0;
 
-    // sum and sum of squares
     for (int i = 0; i < data.size(); i++) {
+      // updates the running mean and the sum of squared deviations for this attribute value
       Instance instance = data.get(i);
       int value = (int) instance.get(attr);
-      double label = data.getDataset().getLabel(instance);
-      double square = label * label;
+      double xk = data.getDataset().getLabel(instance);
+      if (ra[value].getCount() == 0) {
+        ra[value].addDatum(xk);
+        sk[value] = 0.0;
+      } else {
+        double mk = ra[value].getAverage();
+        ra[value].addDatum(xk);
+        sk[value] += (xk - mk) * (xk - ra[value].getAverage());
+      }
+
+      // total variance
+      if (i == 0) {
+        totalRa.addDatum(xk);
+        totalSk = 0.0;
+      } else {
+        double mk = totalRa.getAverage();
+        totalRa.addDatum(xk);
+        totalSk += (xk - mk) * (xk - totalRa.getAverage());
+      }
+    }
 
-      sums[value] += label;
-      sumSquared[value] += square;
-      counts[value]++;
-      totalSum += label;
-      totalSumSquared += square;
+    // computes the variance gain
+    double ig = totalSk;
+    for (int i = 0; i < sk.length; i++) {
+      ig -= sk[i];
     }
-    
-    // computes the variance
-    double totalVar = totalSumSquared - (totalSum * totalSum) / data.size();
-    double var = variance(sums, sumSquared, counts);
-    double ig = totalVar - var;
 
     return new Split(attr, ig);
   }
@@ -90,7 +105,12 @@ public class RegressionSplit extends IgS
   /**
    * Computes the best split for a NUMERICAL attribute
    */
-  static Split numericalSplit(Data data, int attr) {
+  private static Split numericalSplit(Data data, int attr) {
+    FullRunningAverage[] ra = new FullRunningAverage[2];
+    double[] sk = new double[2];
+    for (int i = 0; i < ra.length; i++) {
+      ra[i] = new FullRunningAverage();
+    }
 
     // Instance sort
     Instance[] instances = new Instance[data.size()];
@@ -99,81 +119,58 @@ public class RegressionSplit extends IgS
     }
     Arrays.sort(instances, new InstanceComparator(attr));
 
-    // sum and sum of squares
-    double totalSum = 0.0;
-    double totalSumSquared = 0.0;
     for (Instance instance : instances) {
-      double label = data.getDataset().getLabel(instance);
-      totalSum += label;
-      totalSumSquared += label * label;
-    }
-    double[] sums = new double[2];
-    double[] curSums = new double[2];
-    sums[1] = curSums[1] = totalSum;
-    double[] sumSquared = new double[2];
-    double[] curSumSquared = new double[2];
-    sumSquared[1] = curSumSquared[1] = totalSumSquared;
-    double[] counts = new double[2];
-    double[] curCounts = new double[2];
-    counts[1] = curCounts[1] = data.size();
+      double xk = data.getDataset().getLabel(instance);
+      if (ra[1].getCount() == 0) {
+        ra[1].addDatum(xk);
+        sk[1] = 0.0;
+      } else {
+        double mk = ra[1].getAverage();
+        ra[1].addDatum(xk);
+        sk[1] += (xk - mk) * (xk - ra[1].getAverage());
+      }
+    }
+    double totalSk = sk[1];
 
     // find the best split point
-    double curSplit = instances[0].get(attr);
-    double bestVal = Double.MAX_VALUE;
     double split = Double.NaN;
+    double preSplit = Double.NaN;
+    double bestVal = Double.MAX_VALUE;
+    double bestSk = 0.0;
+
+    // scans the sorted instances, evaluating a candidate split wherever the attribute value increases
     for (Instance instance : instances) {
-      if (instance.get(attr) > curSplit) {
-        double curVal = variance(curSums, curSumSquared, curCounts);
+      double xk = data.getDataset().getLabel(instance);
+
+      if (instance.get(attr) > preSplit) {
+        double curVal = sk[0] / ra[0].getCount() + sk[1] / ra[1].getCount();
         if (curVal < bestVal) {
           bestVal = curVal;
-          split = (instance.get(attr) + curSplit) / 2.0;
-          for (int j = 0; j < 2; j++) {
-            sums[j] = curSums[j];
-            sumSquared[j] = curSumSquared[j];
-            counts[j] = curCounts[j];
-          }
+          bestSk = sk[0] + sk[1];
+          split = (instance.get(attr) + preSplit) / 2.0;
         }
       }
 
-      curSplit = instance.get(attr);
-
-      double label = data.getDataset().getLabel(instance);
-      double square = label * label;
+      // adds the current label to the running statistics of partition 0
+      if (ra[0].getCount() == 0) {
+        ra[0].addDatum(xk);
+        sk[0] = 0.0;
+      } else {
+        double mk = ra[0].getAverage();
+        ra[0].addDatum(xk);
+        sk[0] += (xk - mk) * (xk - ra[0].getAverage());
+      }
 
-      curSums[0] += label;
-      curSumSquared[0] += square;
-      curCounts[0]++;
+      double mk = ra[1].getAverage();
+      ra[1].removeDatum(xk);
+      sk[1] -= (xk - mk) * (xk - ra[1].getAverage());
 
-      curSums[1] -= label;
-      curSumSquared[1] -= square;
-      curCounts[1]--;
+      preSplit = instance.get(attr);
     }
 
-    // computes the variance
-    double totalVar = totalSumSquared - (totalSum * totalSum) / data.size();
-    double var = variance(sums, sumSquared, counts);
-    double ig = totalVar - var;
+    // computes the variance gain
+    double ig = totalSk - bestSk;
 
     return new Split(attr, ig, split);
   }
-  
-  /**
-   * Computes the variance
-   * 
-   * @param s
-   *          data
-   * @param ss
-   *          squared data
-   * @param dataSize
-   *          numInstances
-   */
-  private static double variance(double[] s, double[] ss, double[] dataSize) {
-    double var = 0;
-    for (int i = 0; i < s.length; i++) {
-      if (dataSize[i] > 0) {
-        var += ss[i] - ((s[i] * s[i]) / dataSize[i]);
-      }
-    }
-    return var;
-  }
 }

Added: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/split/RegressionSplitTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/split/RegressionSplitTest.java?rev=1455749&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/split/RegressionSplitTest.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/split/RegressionSplitTest.java Tue Mar 12 23:14:30 2013
@@ -0,0 +1,87 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.split;
+
+import org.apache.mahout.classifier.df.data.Data;
+import org.apache.mahout.classifier.df.data.DataLoader;
+import org.apache.mahout.classifier.df.data.Dataset;
+import org.apache.mahout.classifier.df.data.DescriptorException;
+import org.apache.mahout.classifier.df.data.conditions.Condition;
+import org.apache.mahout.common.MahoutTestCase;
+import org.junit.Test;
+
+public final class RegressionSplitTest extends MahoutTestCase {
+
+  private static Data[] generateTrainingData() throws DescriptorException {
+    // Training data
+    String[] trainData = new String[20];
+    for (int i = 0; i < trainData.length; i++) {
+      if (i % 3 == 0) {
+        trainData[i] = "A," + (40 - i) + ',' + (i + 20);
+      } else if (i % 3 == 1) {
+        trainData[i] = "B," + (i + 20) + ',' + (40 - i);
+      } else {
+        trainData[i] = "C," + (i + 20) + ',' + (i + 20);
+      }
+    }
+    // Dataset
+    Dataset dataset = DataLoader.generateDataset("C N L", true, trainData);
+    Data[] datas = new Data[3];
+    datas[0] = DataLoader.loadData(dataset, trainData);
+
+    // Training data
+    trainData = new String[20];
+    for (int i = 0; i < trainData.length; i++) {
+      if (i % 2 == 0) {
+        trainData[i] = "A," + (50 - i) + ',' + (i + 10);
+      } else {
+        trainData[i] = "B," + (i + 10) + ',' + (50 - i);
+      }
+    }
+    datas[1] = DataLoader.loadData(dataset, trainData);
+
+    // Training data
+    trainData = new String[10];
+    for (int i = 0; i < trainData.length; i++) {
+      trainData[i] = "A," + (40 - i) + ',' + (i + 20);
+    }
+    datas[2] = DataLoader.loadData(dataset, trainData);
+
+    return datas;
+  }
+
+  @Test
+  public void testComputeSplit() throws DescriptorException {
+    Data[] datas = generateTrainingData();
+
+    RegressionSplit igSplit = new RegressionSplit();
+    Split split = igSplit.computeSplit(datas[0], 1);
+    assertEquals(180.0, split.getIg(), EPSILON);
+    assertEquals(38.0, split.getSplit(), EPSILON);
+    split = igSplit.computeSplit(datas[0].subset(Condition.lesser(1, 38.0)), 1);
+    assertEquals(76.5, split.getIg(), EPSILON);
+    assertEquals(21.5, split.getSplit(), EPSILON);
+
+    split = igSplit.computeSplit(datas[1], 0);
+    assertEquals(2205.0, split.getIg(), EPSILON);
+    assertEquals(Double.NaN, split.getSplit(), EPSILON);
+    split = igSplit.computeSplit(datas[1].subset(Condition.equals(0, 0.0)), 1);
+    assertEquals(250.0, split.getIg(), EPSILON);
+    assertEquals(41.0, split.getSplit(), EPSILON);
+  }
+}