You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hivemall.apache.org by my...@apache.org on 2016/12/02 08:02:06 UTC

[16/50] [abbrv] incubator-hivemall git commit: refine chi2

refine chi2



Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/a16a3fde
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/a16a3fde
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/a16a3fde

Branch: refs/heads/JIRA-22/pr-385
Commit: a16a3fde844ba381dee7eb1e9608ddc2dcfb96fc
Parents: 6dc2344
Author: amaya <gi...@sapphire.in.net>
Authored: Wed Sep 21 13:10:18 2016 +0900
Committer: amaya <gi...@sapphire.in.net>
Committed: Wed Sep 21 13:35:33 2016 +0900

----------------------------------------------------------------------
 .../hivemall/ftvec/selection/ChiSquareUDF.java  | 40 +++++++------
 .../java/hivemall/utils/math/StatsUtils.java    | 62 +++++++++++---------
 2 files changed, 58 insertions(+), 44 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/a16a3fde/core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java b/core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java
index e2b7494..951aeeb 100644
--- a/core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java
+++ b/core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java
@@ -50,6 +50,12 @@ public class ChiSquareUDF extends GenericUDF {
     private ListObjectInspector expectedRowOI;
     private PrimitiveObjectInspector expectedElOI;
 
+    private int nFeatures = -1;
+    private double[] observedRow = null; // to reuse
+    private double[] expectedRow = null; // to reuse
+    private double[][] observed = null; // shape = (#features, #classes)
+    private double[][] expected = null; // shape = (#features, #classes)
+
     @Override
     public ObjectInspector initialize(ObjectInspector[] OIs) throws UDFArgumentException {
         if (OIs.length != 2) {
@@ -75,12 +81,12 @@ public class ChiSquareUDF extends GenericUDF {
         expectedRowOI = HiveUtils.asListOI(expectedOI.getListElementObjectInspector());
         expectedElOI = HiveUtils.asDoubleCompatibleOI(expectedRowOI.getListElementObjectInspector());
 
-        List<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
+        final List<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
         fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
         fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
 
         return ObjectInspectorFactory.getStandardStructObjectInspector(
-            Arrays.asList("chi2_vals", "p_vals"), fieldOIs);
+            Arrays.asList("chi2", "pvalue"), fieldOIs);
     }
 
     @Override
@@ -93,28 +99,28 @@ public class ChiSquareUDF extends GenericUDF {
         final int nClasses = observedObj.size();
         Preconditions.checkArgument(nClasses == expectedObj.size()); // same #rows
 
-        int nFeatures = -1;
-        double[] observedRow = null; // to reuse
-        double[] expectedRow = null; // to reuse
-        double[][] observed = null; // shape = (#features, #classes)
-        double[][] expected = null; // shape = (#features, #classes)
-
         // explode and transpose matrix
         for (int i = 0; i < nClasses; i++) {
-            if (i == 0) {
+            final Object observedObjRow = observedObj.get(i);
+            final Object expectedObjRow = observedObj.get(i);
+
+            Preconditions.checkNotNull(observedObjRow);
+            Preconditions.checkNotNull(expectedObjRow);
+
+            if (observedRow == null) {
                 // init
-                observedRow = HiveUtils.asDoubleArray(observedObj.get(i), observedRowOI,
-                    observedElOI, false);
-                expectedRow = HiveUtils.asDoubleArray(expectedObj.get(i), expectedRowOI,
-                    expectedElOI, false);
+                observedRow = HiveUtils.asDoubleArray(observedObjRow, observedRowOI, observedElOI,
+                    false);
+                expectedRow = HiveUtils.asDoubleArray(expectedObjRow, expectedRowOI, expectedElOI,
+                    false);
                 nFeatures = observedRow.length;
                 observed = new double[nFeatures][nClasses];
                 expected = new double[nFeatures][nClasses];
             } else {
-                HiveUtils.toDoubleArray(observedObj.get(i), observedRowOI, observedElOI,
-                    observedRow, false);
-                HiveUtils.toDoubleArray(expectedObj.get(i), expectedRowOI, expectedElOI,
-                    expectedRow, false);
+                HiveUtils.toDoubleArray(observedObjRow, observedRowOI, observedElOI, observedRow,
+                    false);
+                HiveUtils.toDoubleArray(expectedObjRow, expectedRowOI, expectedElOI, expectedRow,
+                    false);
             }
 
             for (int j = 0; j < nFeatures; j++) {

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/a16a3fde/core/src/main/java/hivemall/utils/math/StatsUtils.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/math/StatsUtils.java b/core/src/main/java/hivemall/utils/math/StatsUtils.java
index d3b25c7..e255b84 100644
--- a/core/src/main/java/hivemall/utils/math/StatsUtils.java
+++ b/core/src/main/java/hivemall/utils/math/StatsUtils.java
@@ -23,11 +23,15 @@ import hivemall.utils.lang.Preconditions;
 import javax.annotation.Nonnull;
 
 import org.apache.commons.math3.distribution.ChiSquaredDistribution;
+import org.apache.commons.math3.exception.DimensionMismatchException;
+import org.apache.commons.math3.exception.NotPositiveException;
 import org.apache.commons.math3.linear.DecompositionSolver;
 import org.apache.commons.math3.linear.LUDecomposition;
 import org.apache.commons.math3.linear.RealMatrix;
 import org.apache.commons.math3.linear.RealVector;
 import org.apache.commons.math3.linear.SingularValueDecomposition;
+import org.apache.commons.math3.util.FastMath;
+import org.apache.commons.math3.util.MathArrays;
 
 import java.util.AbstractMap;
 import java.util.Map;
@@ -194,54 +198,59 @@ public final class StatsUtils {
     }
 
     /**
-     * @param observed mean vector whose value is observed
-     * @param expected mean vector whose value is expected
+     * @param observed observed values; every element must be non-negative
+     * @param expected expected values; every element must be strictly positive
      * @return chi2 value
      */
     public static double chiSquare(@Nonnull final double[] observed,
             @Nonnull final double[] expected) {
-        Preconditions.checkArgument(observed.length == expected.length);
+        if (observed.length < 2) {
+            throw new DimensionMismatchException(observed.length, 2);
+        }
+        if (expected.length != observed.length) {
+            throw new DimensionMismatchException(observed.length, expected.length);
+        }
+        MathArrays.checkPositive(expected);
+        for (double d : observed) {
+            if (d < 0.d) {
+                throw new NotPositiveException(d);
+            }
+        }
 
         double sumObserved = 0.d;
         double sumExpected = 0.d;
-
-        for (int ratio = 0; ratio < observed.length; ++ratio) {
-            sumObserved += observed[ratio];
-            sumExpected += expected[ratio];
+        for (int i = 0; i < observed.length; i++) {
+            sumObserved += observed[i];
+            sumExpected += expected[i];
         }
-
-        double var15 = 1.d;
+        double ratio = 1.d;
         boolean rescale = false;
-        if (Math.abs(sumObserved - sumExpected) > 1.e-5) {
-            var15 = sumObserved / sumExpected;
+        if (FastMath.abs(sumObserved - sumExpected) > 10e-6) {
+            ratio = sumObserved / sumExpected;
             rescale = true;
         }
-
         double sumSq = 0.d;
-
-        for (int i = 0; i < observed.length; ++i) {
-            double dev;
+        for (int i = 0; i < observed.length; i++) {
             if (rescale) {
-                dev = observed[i] - var15 * expected[i];
-                sumSq += dev * dev / (var15 * expected[i]);
+                final double dev = observed[i] - ratio * expected[i];
+                sumSq += dev * dev / (ratio * expected[i]);
             } else {
-                dev = observed[i] - expected[i];
+                final double dev = observed[i] - expected[i];
                 sumSq += dev * dev / expected[i];
             }
         }
-
         return sumSq;
     }
 
     /**
-     * @param observed means vector whose value is observed
-     * @param expected means vector whose value is expected
+     * @param observed observed values; every element must be non-negative
+     * @param expected expected values; every element must be strictly positive
      * @return p value
      */
     public static double chiSquareTest(@Nonnull final double[] observed,
             @Nonnull final double[] expected) {
-        ChiSquaredDistribution distribution = new ChiSquaredDistribution(null,
-            (double) expected.length - 1.d);
+        final ChiSquaredDistribution distribution = new ChiSquaredDistribution(
+            expected.length - 1.d);
         return 1.d - distribution.cumulativeProbability(chiSquare(observed, expected));
     }
 
@@ -249,8 +258,8 @@ public final class StatsUtils {
      * This method offers effective calculation for multiple entries rather than calculation
      * individually
      * 
-     * @param observeds means matrix whose values are observed
-     * @param expecteds means matrix
+     * @param observeds matrix of observed values; every element must be non-negative
+     * @param expecteds matrix of expected values; every element must be strictly positive
      * @return (chi2 value[], p value[])
      */
     public static Map.Entry<double[], double[]> chiSquares(@Nonnull final double[][] observeds,
@@ -260,8 +269,7 @@ public final class StatsUtils {
         final int len = expecteds.length;
         final int lenOfEach = expecteds[0].length;
 
-        final ChiSquaredDistribution distribution = new ChiSquaredDistribution(null,
-            (double) lenOfEach - 1.d);
+        final ChiSquaredDistribution distribution = new ChiSquaredDistribution(lenOfEach - 1.d);
 
         final double[] chi2s = new double[len];
         final double[] ps = new double[len];