You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hivemall.apache.org by my...@apache.org on 2016/12/02 08:02:06 UTC
[16/50] [abbrv] incubator-hivemall git commit: refine chi2
refine chi2
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/a16a3fde
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/a16a3fde
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/a16a3fde
Branch: refs/heads/JIRA-22/pr-385
Commit: a16a3fde844ba381dee7eb1e9608ddc2dcfb96fc
Parents: 6dc2344
Author: amaya <gi...@sapphire.in.net>
Authored: Wed Sep 21 13:10:18 2016 +0900
Committer: amaya <gi...@sapphire.in.net>
Committed: Wed Sep 21 13:35:33 2016 +0900
----------------------------------------------------------------------
.../hivemall/ftvec/selection/ChiSquareUDF.java | 40 +++++++------
.../java/hivemall/utils/math/StatsUtils.java | 62 +++++++++++---------
2 files changed, 58 insertions(+), 44 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/a16a3fde/core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java b/core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java
index e2b7494..951aeeb 100644
--- a/core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java
+++ b/core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java
@@ -50,6 +50,12 @@ public class ChiSquareUDF extends GenericUDF {
private ListObjectInspector expectedRowOI;
private PrimitiveObjectInspector expectedElOI;
+ private int nFeatures = -1;
+ private double[] observedRow = null; // to reuse
+ private double[] expectedRow = null; // to reuse
+ private double[][] observed = null; // shape = (#features, #classes)
+ private double[][] expected = null; // shape = (#features, #classes)
+
@Override
public ObjectInspector initialize(ObjectInspector[] OIs) throws UDFArgumentException {
if (OIs.length != 2) {
@@ -75,12 +81,12 @@ public class ChiSquareUDF extends GenericUDF {
expectedRowOI = HiveUtils.asListOI(expectedOI.getListElementObjectInspector());
expectedElOI = HiveUtils.asDoubleCompatibleOI(expectedRowOI.getListElementObjectInspector());
- List<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
+ final List<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
return ObjectInspectorFactory.getStandardStructObjectInspector(
- Arrays.asList("chi2_vals", "p_vals"), fieldOIs);
+ Arrays.asList("chi2", "pvalue"), fieldOIs);
}
@Override
@@ -93,28 +99,28 @@ public class ChiSquareUDF extends GenericUDF {
final int nClasses = observedObj.size();
Preconditions.checkArgument(nClasses == expectedObj.size()); // same #rows
- int nFeatures = -1;
- double[] observedRow = null; // to reuse
- double[] expectedRow = null; // to reuse
- double[][] observed = null; // shape = (#features, #classes)
- double[][] expected = null; // shape = (#features, #classes)
-
// explode and transpose matrix
for (int i = 0; i < nClasses; i++) {
- if (i == 0) {
+ final Object observedObjRow = observedObj.get(i);
+ final Object expectedObjRow = observedObj.get(i);
+
+ Preconditions.checkNotNull(observedObjRow);
+ Preconditions.checkNotNull(expectedObjRow);
+
+ if (observedRow == null) {
// init
- observedRow = HiveUtils.asDoubleArray(observedObj.get(i), observedRowOI,
- observedElOI, false);
- expectedRow = HiveUtils.asDoubleArray(expectedObj.get(i), expectedRowOI,
- expectedElOI, false);
+ observedRow = HiveUtils.asDoubleArray(observedObjRow, observedRowOI, observedElOI,
+ false);
+ expectedRow = HiveUtils.asDoubleArray(expectedObjRow, expectedRowOI, expectedElOI,
+ false);
nFeatures = observedRow.length;
observed = new double[nFeatures][nClasses];
expected = new double[nFeatures][nClasses];
} else {
- HiveUtils.toDoubleArray(observedObj.get(i), observedRowOI, observedElOI,
- observedRow, false);
- HiveUtils.toDoubleArray(expectedObj.get(i), expectedRowOI, expectedElOI,
- expectedRow, false);
+ HiveUtils.toDoubleArray(observedObjRow, observedRowOI, observedElOI, observedRow,
+ false);
+ HiveUtils.toDoubleArray(expectedObjRow, expectedRowOI, expectedElOI, expectedRow,
+ false);
}
for (int j = 0; j < nFeatures; j++) {
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/a16a3fde/core/src/main/java/hivemall/utils/math/StatsUtils.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/math/StatsUtils.java b/core/src/main/java/hivemall/utils/math/StatsUtils.java
index d3b25c7..e255b84 100644
--- a/core/src/main/java/hivemall/utils/math/StatsUtils.java
+++ b/core/src/main/java/hivemall/utils/math/StatsUtils.java
@@ -23,11 +23,15 @@ import hivemall.utils.lang.Preconditions;
import javax.annotation.Nonnull;
import org.apache.commons.math3.distribution.ChiSquaredDistribution;
+import org.apache.commons.math3.exception.DimensionMismatchException;
+import org.apache.commons.math3.exception.NotPositiveException;
import org.apache.commons.math3.linear.DecompositionSolver;
import org.apache.commons.math3.linear.LUDecomposition;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.linear.SingularValueDecomposition;
+import org.apache.commons.math3.util.FastMath;
+import org.apache.commons.math3.util.MathArrays;
import java.util.AbstractMap;
import java.util.Map;
@@ -194,54 +198,59 @@ public final class StatsUtils {
}
/**
- * @param observed mean vector whose value is observed
- * @param expected mean vector whose value is expected
+ * @param observed means non-negative vector
+ * @param expected means positive vector
* @return chi2 value
*/
public static double chiSquare(@Nonnull final double[] observed,
@Nonnull final double[] expected) {
- Preconditions.checkArgument(observed.length == expected.length);
+ if (observed.length < 2) {
+ throw new DimensionMismatchException(observed.length, 2);
+ }
+ if (expected.length != observed.length) {
+ throw new DimensionMismatchException(observed.length, expected.length);
+ }
+ MathArrays.checkPositive(expected);
+ for (double d : observed) {
+ if (d < 0.d) {
+ throw new NotPositiveException(d);
+ }
+ }
double sumObserved = 0.d;
double sumExpected = 0.d;
-
- for (int ratio = 0; ratio < observed.length; ++ratio) {
- sumObserved += observed[ratio];
- sumExpected += expected[ratio];
+ for (int i = 0; i < observed.length; i++) {
+ sumObserved += observed[i];
+ sumExpected += expected[i];
}
-
- double var15 = 1.d;
+ double ratio = 1.d;
boolean rescale = false;
- if (Math.abs(sumObserved - sumExpected) > 1.e-5) {
- var15 = sumObserved / sumExpected;
+ if (FastMath.abs(sumObserved - sumExpected) > 10e-6) {
+ ratio = sumObserved / sumExpected;
rescale = true;
}
-
double sumSq = 0.d;
-
- for (int i = 0; i < observed.length; ++i) {
- double dev;
+ for (int i = 0; i < observed.length; i++) {
if (rescale) {
- dev = observed[i] - var15 * expected[i];
- sumSq += dev * dev / (var15 * expected[i]);
+ final double dev = observed[i] - ratio * expected[i];
+ sumSq += dev * dev / (ratio * expected[i]);
} else {
- dev = observed[i] - expected[i];
+ final double dev = observed[i] - expected[i];
sumSq += dev * dev / expected[i];
}
}
-
return sumSq;
}
/**
- * @param observed means vector whose value is observed
- * @param expected means vector whose value is expected
+ * @param observed means non-negative vector
+ * @param expected means positive vector
* @return p value
*/
public static double chiSquareTest(@Nonnull final double[] observed,
@Nonnull final double[] expected) {
- ChiSquaredDistribution distribution = new ChiSquaredDistribution(null,
- (double) expected.length - 1.d);
+ final ChiSquaredDistribution distribution = new ChiSquaredDistribution(
+ expected.length - 1.d);
return 1.d - distribution.cumulativeProbability(chiSquare(observed, expected));
}
@@ -249,8 +258,8 @@ public final class StatsUtils {
* This method offers effective calculation for multiple entries rather than calculation
* individually
*
- * @param observeds means matrix whose values are observed
- * @param expecteds means matrix
+ * @param observeds means non-negative matrix
+ * @param expecteds means positive matrix
* @return (chi2 value[], p value[])
*/
public static Map.Entry<double[], double[]> chiSquares(@Nonnull final double[][] observeds,
@@ -260,8 +269,7 @@ public final class StatsUtils {
final int len = expecteds.length;
final int lenOfEach = expecteds[0].length;
- final ChiSquaredDistribution distribution = new ChiSquaredDistribution(null,
- (double) lenOfEach - 1.d);
+ final ChiSquaredDistribution distribution = new ChiSquaredDistribution(lenOfEach - 1.d);
final double[] chi2s = new double[len];
final double[] ps = new double[len];