You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hivemall.apache.org by my...@apache.org on 2016/12/13 07:47:05 UTC
[1/2] incubator-hivemall git commit: Close #15: Implement Feature
Selection functions (chi2, snr)
Repository: incubator-hivemall
Updated Branches:
refs/heads/master 518e232d8 -> fad2941fd
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fad2941f/core/src/test/java/hivemall/utils/lang/PreconditionsTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/utils/lang/PreconditionsTest.java b/core/src/test/java/hivemall/utils/lang/PreconditionsTest.java
new file mode 100644
index 0000000..b0cfbd0
--- /dev/null
+++ b/core/src/test/java/hivemall/utils/lang/PreconditionsTest.java
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.lang;
+
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.junit.Test;
+
+public class PreconditionsTest { // tests for the Preconditions overloads that take an exception class
+
+ @Test(expected = UDFArgumentException.class) // a null reference must raise the given exception type
+ public void testCheckNotNullTClassOfE() throws UDFArgumentException {
+ Preconditions.checkNotNull(null, UDFArgumentException.class);
+ }
+
+ @Test(expected = HiveException.class) // a false condition must raise the given exception type
+ public void testCheckArgumentBooleanClassOfE() throws HiveException {
+ Preconditions.checkArgument(false, HiveException.class);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fad2941f/docs/gitbook/SUMMARY.md
----------------------------------------------------------------------
diff --git a/docs/gitbook/SUMMARY.md b/docs/gitbook/SUMMARY.md
index c333c98..33bb46c 100644
--- a/docs/gitbook/SUMMARY.md
+++ b/docs/gitbook/SUMMARY.md
@@ -61,6 +61,8 @@
* [Vectorize Features](ft_engineering/vectorizer.md)
* [Quantify non-number features](ft_engineering/quantify.md)
+* [Feature selection](ft_engineering/feature_selection.md)
+
## Part IV - Evaluation
* [Statistical evaluation of a prediction model](eval/stat_eval.md)
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fad2941f/docs/gitbook/ft_engineering/feature_selection.md
----------------------------------------------------------------------
diff --git a/docs/gitbook/ft_engineering/feature_selection.md b/docs/gitbook/ft_engineering/feature_selection.md
new file mode 100644
index 0000000..5a2a92b
--- /dev/null
+++ b/docs/gitbook/ft_engineering/feature_selection.md
@@ -0,0 +1,155 @@
+<!--
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+-->
+
+[Feature Selection](https://en.wikipedia.org/wiki/Feature_selection) is the process of selecting a subset of relevant features for use in model construction.
+
+It is a useful technique to 1) improve prediction results by omitting redundant features, 2) shorten training time, and 3) identify the features that are important for prediction.
+
+*Note: This feature is supported in Hivemall v0.5-rc.1 or later.*
+
+<!-- toc -->
+
+# Supported Feature Selection algorithms
+
+* Chi-square (Chi2)
+ * In statistics, the $$\chi^2$$ test is applied to test the independence of two events. The chi-square statistic between each feature variable and the target variable can be used for feature selection. Refer to [this article](http://nlp.stanford.edu/IR-book/html/htmledition/feature-selectionchi2-feature-selection-1.html) for mathematical details.
+* Signal Noise Ratio (SNR)
+ * The Signal Noise Ratio (SNR) is a univariate feature ranking metric, which can be used as a feature selection criterion for binary classification problems. SNR is defined as $$|\mu_{1} - \mu_{2}| / (\sigma_{1} + \sigma_{2})$$, where $$\mu_{k}$$ is the mean value of the variable in class $$k$$, and $$\sigma_{k}$$ is the standard deviation of the variable in class $$k$$. Features with a larger SNR are more useful for classification.
+
+# Usage
+
+## Feature Selection based on Chi-square test
+
+``` sql
+CREATE TABLE input (
+ X array<double>, -- features
+ Y array<int> -- binarized (one-hot) label
+);
+
+set hivevar:k=2;
+
+WITH stats AS (
+ SELECT
+ transpose_and_dot(Y, X) AS observed, -- array<array<double>>, shape = (n_classes, n_features)
+ array_sum(X) AS feature_count, -- array<double>, shape = (1, n_features)
+ array_avg(Y) AS class_prob -- array<double>, shape = (1, n_classes)
+ FROM
+ input
+),
+test AS (
+ SELECT
+ transpose_and_dot(class_prob, feature_count) AS expected -- array<array<double>>, shape = (n_classes, n_features)
+ FROM
+ stats
+),
+chi2 AS (
+ SELECT
+ chi2(r.observed, l.expected) AS v -- struct<array<double>, array<double>>, each shape = (1, n_features)
+ FROM
+ test l
+ CROSS JOIN stats r
+)
+SELECT
+ select_k_best(l.X, r.v.chi2, ${k}) as features -- top-k feature selection based on chi2 score
+FROM
+ input l
+ CROSS JOIN chi2 r;
+```
+
+## Feature Selection based on Signal Noise Ratio (SNR)
+
+``` sql
+CREATE TABLE input (
+ X array<double>, -- features
+ Y array<int> -- binarized (one-hot) label
+);
+
+set hivevar:k=2;
+
+WITH snr AS (
+ SELECT snr(X, Y) AS snr -- aggregated SNR as array<double>, shape = (1, #features)
+ FROM input
+)
+SELECT
+ select_k_best(X, snr, ${k}) as features -- top-k feature selection based on SNR score
+FROM
+ input
+ CROSS JOIN snr;
+```
+
+# Function signatures
+
+### [UDAF] `transpose_and_dot(X::array<number>, Y::array<number>)::array<array<double>>`
+
+##### Input
+
+| `array<number>` X | `array<number>` Y |
+| :-: | :-: |
+| a row of matrix | a row of matrix |
+
+##### Output
+
+| `array<array<double>>` dot product |
+| :-: |
+| `dot(X.T, Y)` of shape = (X.#cols, Y.#cols) |
+
+### [UDF] `select_k_best(X::array<number>, importance_list::array<number>, k::int)::array<double>`
+
+##### Input
+
+| `array<number>` X | `array<number>` importance_list | `int` k |
+| :-: | :-: | :-: |
+| feature vector | importance of each feature | the number of features to be selected |
+
+##### Output
+
+| `array<double>` k-best features |
+| :-: |
+| top-k elements from feature vector `X` based on importance list |
+
+### [UDF] `chi2(observed::array<array<number>>, expected::array<array<number>>)::struct<array<double>, array<double>>`
+
+##### Input
+
+| `array<array<number>>` observed | `array<array<number>>` expected |
+| :-: | :-: |
+| observed features | expected features `dot(class_prob.T, feature_count)` |
+
+Both `observed` and `expected` have the shape `(#classes, #features)`.
+
+##### Output
+
+| `struct<array<double>, array<double>>` importance_list |
+| :-: |
+| chi2-value and p-value for each feature |
+
+### [UDAF] `snr(X::array<number>, Y::array<int>)::array<double>`
+
+##### Input
+
+| `array<number>` X | `array<int>` Y |
+| :-: | :-: |
+| feature vector | one-hot label |
+
+##### Output
+
+| `array<double>` importance_list |
+| :-: |
+| Signal Noise Ratio for each feature |
+
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fad2941f/docs/gitbook/ft_engineering/quantify.md
----------------------------------------------------------------------
diff --git a/docs/gitbook/ft_engineering/quantify.md b/docs/gitbook/ft_engineering/quantify.md
index 952db53..1bfaa73 100644
--- a/docs/gitbook/ft_engineering/quantify.md
+++ b/docs/gitbook/ft_engineering/quantify.md
@@ -19,7 +19,7 @@
`quantified_features` is useful for transforming values of non-number columns to indexed numbers.
-*Note: The feature is supported Hivemall v0.4 or later.*
+*Note: This feature is supported in Hivemall v0.4 or later.*
```sql
desc train;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fad2941f/resources/ddl/define-all-as-permanent.hive
----------------------------------------------------------------------
diff --git a/resources/ddl/define-all-as-permanent.hive b/resources/ddl/define-all-as-permanent.hive
index 1906de1..72835b1 100644
--- a/resources/ddl/define-all-as-permanent.hive
+++ b/resources/ddl/define-all-as-permanent.hive
@@ -199,6 +199,16 @@ CREATE FUNCTION zscore as 'hivemall.ftvec.scaling.ZScoreUDF' USING JAR '${hivema
DROP FUNCTION IF EXISTS l2_normalize;
CREATE FUNCTION l2_normalize as 'hivemall.ftvec.scaling.L2NormalizationUDF' USING JAR '${hivemall_jar}';
+---------------------------------
+-- Feature Selection functions --
+---------------------------------
+
+DROP FUNCTION IF EXISTS chi2;
+CREATE FUNCTION chi2 as 'hivemall.ftvec.selection.ChiSquareUDF' USING JAR '${hivemall_jar}';
+
+DROP FUNCTION IF EXISTS snr;
+CREATE FUNCTION snr as 'hivemall.ftvec.selection.SignalNoiseRatioUDAF' USING JAR '${hivemall_jar}';
+
--------------------
-- misc functions --
--------------------
@@ -386,6 +396,9 @@ CREATE FUNCTION to_string_array as 'hivemall.tools.array.ToStringArrayUDF' USING
DROP FUNCTION IF EXISTS array_intersect;
CREATE FUNCTION array_intersect as 'hivemall.tools.array.ArrayIntersectUDF' USING JAR '${hivemall_jar}';
+DROP FUNCTION IF EXISTS select_k_best;
+CREATE FUNCTION select_k_best as 'hivemall.tools.array.SelectKBestUDF' USING JAR '${hivemall_jar}';
+
-----------------------------
-- bit operation functions --
-----------------------------
@@ -436,6 +449,13 @@ DROP FUNCTION IF EXISTS sigmoid;
CREATE FUNCTION sigmoid as 'hivemall.tools.math.SigmoidGenericUDF' USING JAR '${hivemall_jar}';
----------------------
+-- Matrix functions --
+----------------------
+
+DROP FUNCTION IF EXISTS transpose_and_dot;
+CREATE FUNCTION transpose_and_dot as 'hivemall.tools.matrix.TransposeAndDotUDAF' USING JAR '${hivemall_jar}';
+
+----------------------
-- mapred functions --
----------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fad2941f/resources/ddl/define-all.hive
----------------------------------------------------------------------
diff --git a/resources/ddl/define-all.hive b/resources/ddl/define-all.hive
index 37d262a..351303e 100644
--- a/resources/ddl/define-all.hive
+++ b/resources/ddl/define-all.hive
@@ -195,6 +195,16 @@ create temporary function zscore as 'hivemall.ftvec.scaling.ZScoreUDF';
drop temporary function l2_normalize;
create temporary function l2_normalize as 'hivemall.ftvec.scaling.L2NormalizationUDF';
+---------------------------------
+-- Feature Selection functions --
+---------------------------------
+
+drop temporary function chi2;
+create temporary function chi2 as 'hivemall.ftvec.selection.ChiSquareUDF';
+
+drop temporary function snr;
+create temporary function snr as 'hivemall.ftvec.selection.SignalNoiseRatioUDAF';
+
-----------------------------------
-- Feature engineering functions --
-----------------------------------
@@ -382,6 +392,9 @@ create temporary function to_string_array as 'hivemall.tools.array.ToStringArray
drop temporary function array_intersect;
create temporary function array_intersect as 'hivemall.tools.array.ArrayIntersectUDF';
+drop temporary function select_k_best;
+create temporary function select_k_best as 'hivemall.tools.array.SelectKBestUDF';
+
-----------------------------
-- bit operation functions --
-----------------------------
@@ -432,6 +445,13 @@ drop temporary function sigmoid;
create temporary function sigmoid as 'hivemall.tools.math.SigmoidGenericUDF';
----------------------
+-- Matrix functions --
+----------------------
+
+drop temporary function transpose_and_dot;
+create temporary function transpose_and_dot as 'hivemall.tools.matrix.TransposeAndDotUDAF';
+
+----------------------
-- mapred functions --
----------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fad2941f/resources/ddl/define-all.spark
----------------------------------------------------------------------
diff --git a/resources/ddl/define-all.spark b/resources/ddl/define-all.spark
index 5de6106..838f7bb 100644
--- a/resources/ddl/define-all.spark
+++ b/resources/ddl/define-all.spark
@@ -6,7 +6,7 @@ sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS hivemall_version")
sqlContext.sql("CREATE TEMPORARY FUNCTION hivemall_version AS 'hivemall.HivemallVersionUDF'")
/**
- * binary classification
+ * Binary classification
*/
sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS train_perceptron")
@@ -59,7 +59,7 @@ sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS train_multiclass_scw2")
sqlContext.sql("CREATE TEMPORARY FUNCTION train_multiclass_scw2 AS 'hivemall.classifier.multiclass.MulticlassSoftConfidenceWeightedUDTF$SCW2'")
/**
- * similarity functions
+ * Similarity functions
*/
sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS cosine_sim")
@@ -78,7 +78,7 @@ sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS distance2similarity")
sqlContext.sql("CREATE TEMPORARY FUNCTION distance2similarity AS 'hivemall.knn.similarity.Distance2SimilarityUDF'")
/**
- * distance functions
+ * Distance functions
*/
sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS homming_distance")
@@ -122,7 +122,7 @@ sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS bbit_minhash")
sqlContext.sql("CREATE TEMPORARY FUNCTION bbit_minhash AS 'hivemall.knn.lsh.bBitMinHashUDF'")
/**
- * voting functions
+ * Voting functions
*/
sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS voted_avg")
@@ -132,7 +132,7 @@ sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS weight_voted_avg")
sqlContext.sql("CREATE TEMPORARY FUNCTION weight_voted_avg AS 'hivemall.ensemble.bagging.WeightVotedAvgUDAF'")
/**
- * misc functions
+ * MISC functions
*/
sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS max_label")
@@ -145,7 +145,7 @@ sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS argmin_kld")
sqlContext.sql("CREATE TEMPORARY FUNCTION argmin_kld AS 'hivemall.ensemble.ArgminKLDistanceUDAF'")
/**
- * hashing functions
+ * Feature hashing functions
*/
sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS mhash")
@@ -161,7 +161,7 @@ sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS prefixed_hash_values")
sqlContext.sql("CREATE TEMPORARY FUNCTION prefixed_hash_values AS 'hivemall.ftvec.hashing.ArrayPrefixedHashValuesUDF'")
/**
- * pairing functions
+ * Feature pairing functions
*/
sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS polynomial_features")
@@ -171,7 +171,7 @@ sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS powered_features")
sqlContext.sql("CREATE TEMPORARY FUNCTION powered_features AS 'hivemall.ftvec.pairing.PoweredFeaturesUDF'")
/**
- * scaling functions
+ * Feature scaling functions
*/
sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS rescale")
@@ -184,7 +184,17 @@ sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS normalize")
sqlContext.sql("CREATE TEMPORARY FUNCTION normalize AS 'hivemall.ftvec.scaling.L2NormalizationUDF'")
/**
- * misc functions
+ * Feature selection functions
+ */
+
+sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS chi2")
+sqlContext.sql("CREATE TEMPORARY FUNCTION chi2 AS 'hivemall.ftvec.selection.ChiSquareUDF'")
+
+sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS snr")
+sqlContext.sql("CREATE TEMPORARY FUNCTION snr AS 'hivemall.ftvec.selection.SignalNoiseRatioUDAF'")
+
+/**
+ * MISC feature engineering functions
*/
sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS amplify")
@@ -257,7 +267,7 @@ sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS tf")
sqlContext.sql("CREATE TEMPORARY FUNCTION tf AS 'hivemall.ftvec.text.TermFrequencyUDAF'")
/**
- * fegression functions
+ * Regression functions
*/
sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS train_logregr")
@@ -291,7 +301,7 @@ sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS train_arowe2_regr")
sqlContext.sql("CREATE TEMPORARY FUNCTION train_arow_regr AS 'hivemall.regression.AROWRegressionUDTF$AROWe2'")
/**
- * array functions
+ * Array functions
*/
sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS float_array")
@@ -321,8 +331,11 @@ sqlContext.sql("CREATE TEMPORARY FUNCTION subarray AS 'hivemall.tools.array.Suba
sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS array_avg")
sqlContext.sql("CREATE TEMPORARY FUNCTION array_avg AS 'hivemall.tools.array.ArrayAvgGenericUDAF'")
+sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS select_k_best")
+sqlContext.sql("CREATE TEMPORARY FUNCTION select_k_best AS 'hivemall.tools.array.SelectKBestUDF'")
+
/**
- * compression functions
+ * Compression functions
*/
sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS inflate")
@@ -332,7 +345,7 @@ sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS deflate")
sqlContext.sql("CREATE TEMPORARY FUNCTION deflate AS 'hivemall.tools.compress.DeflateUDF'")
/**
- * map functions
+ * Map functions
*/
sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS map_get_sum")
@@ -355,14 +368,21 @@ sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS sigmoid")
sqlContext.sql("CREATE TEMPORARY FUNCTION sigmoid AS 'hivemall.tools.math.SigmoidGenericUDF'")
/**
- * mapred functions
+ * Matrix functions
+ */
+
+sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS transpose_and_dot")
+sqlContext.sql("CREATE TEMPORARY FUNCTION transpose_and_dot AS 'hivemall.tools.matrix.TransposeAndDotUDAF'")
+
+/**
+ * MAPRED functions
*/
sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS rowid")
sqlContext.sql("CREATE TEMPORARY FUNCTION rowid AS 'hivemall.tools.mapred.RowIdUDFWrapper'")
/**
- * misc functions
+ * MISC functions
*/
sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS generate_series")
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fad2941f/resources/ddl/define-udfs.td.hql
----------------------------------------------------------------------
diff --git a/resources/ddl/define-udfs.td.hql b/resources/ddl/define-udfs.td.hql
index a0bea45..47dbd1d 100644
--- a/resources/ddl/define-udfs.td.hql
+++ b/resources/ddl/define-udfs.td.hql
@@ -50,6 +50,8 @@ create temporary function powered_features as 'hivemall.ftvec.pairing.PoweredFea
create temporary function rescale as 'hivemall.ftvec.scaling.RescaleUDF';
create temporary function zscore as 'hivemall.ftvec.scaling.ZScoreUDF';
create temporary function l2_normalize as 'hivemall.ftvec.scaling.L2NormalizationUDF';
+create temporary function chi2 as 'hivemall.ftvec.selection.ChiSquareUDF';
+create temporary function snr as 'hivemall.ftvec.selection.SignalNoiseRatioUDAF';
create temporary function amplify as 'hivemall.ftvec.amplify.AmplifierUDTF';
create temporary function rand_amplify as 'hivemall.ftvec.amplify.RandomAmplifierUDTF';
create temporary function add_bias as 'hivemall.ftvec.AddBiasUDF';
@@ -101,6 +103,7 @@ create temporary function array_avg as 'hivemall.tools.array.ArrayAvgGenericUDAF
create temporary function array_sum as 'hivemall.tools.array.ArraySumUDAF';
create temporary function to_string_array as 'hivemall.tools.array.ToStringArrayUDF';
create temporary function array_intersect as 'hivemall.tools.array.ArrayIntersectUDF';
+create temporary function select_k_best as 'hivemall.tools.array.SelectKBestUDF';
create temporary function bits_collect as 'hivemall.tools.bits.BitsCollectUDAF';
create temporary function to_bits as 'hivemall.tools.bits.ToBitsUDF';
create temporary function unbits as 'hivemall.tools.bits.UnBitsUDF';
@@ -112,6 +115,7 @@ create temporary function map_tail_n as 'hivemall.tools.map.MapTailNUDF';
create temporary function to_map as 'hivemall.tools.map.UDAFToMap';
create temporary function to_ordered_map as 'hivemall.tools.map.UDAFToOrderedMap';
create temporary function sigmoid as 'hivemall.tools.math.SigmoidGenericUDF';
+create temporary function transpose_and_dot as 'hivemall.tools.matrix.TransposeAndDotUDAF';
create temporary function taskid as 'hivemall.tools.mapred.TaskIdUDF';
create temporary function jobid as 'hivemall.tools.mapred.JobIdUDF';
create temporary function rowid as 'hivemall.tools.mapred.RowIdUDF';
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fad2941f/spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/GroupedDataEx.scala
----------------------------------------------------------------------
diff --git a/spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/GroupedDataEx.scala b/spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/GroupedDataEx.scala
index fd4da64..dd6db6c 100644
--- a/spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/GroupedDataEx.scala
+++ b/spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/GroupedDataEx.scala
@@ -271,9 +271,32 @@ final class GroupedDataEx protected[sql](
*/
def onehot_encoding(features: String*): DataFrame = {
val udaf = HiveUDAFFunction(
- new HiveFunctionWrapper("hivemall.ftvec.trans.OnehotEncodingUDAF"),
- features.map(df.col(_).expr),
- isUDAFBridgeRequired = false)
+ new HiveFunctionWrapper("hivemall.ftvec.trans.OnehotEncodingUDAF"),
+ features.map(df.col(_).expr),
+ isUDAFBridgeRequired = false).toAggregateExpression() // wrap as an aggregate, like snr/transpose_and_dot
+ toDF(Seq(Alias(udaf, udaf.prettyString)()))
+ }
+
+ /**
+ * @see hivemall.ftvec.selection.SignalNoiseRatioUDAF
+ */
+ def snr(X: String, Y: String): DataFrame = { // per-feature SNR of feature column X vs. one-hot label column Y
+ val udaf = HiveUDAFFunction(
+ new HiveFunctionWrapper("hivemall.ftvec.selection.SignalNoiseRatioUDAF"),
+ Seq(X, Y).map(df.col(_).expr),
+ isUDAFBridgeRequired = false)
+ .toAggregateExpression()
+ toDF(Seq(Alias(udaf, udaf.prettyString)()))
+ }
+
+ /**
+ * @see hivemall.tools.matrix.TransposeAndDotUDAF
+ */
+ def transpose_and_dot(X: String, Y: String): DataFrame = { // aggregates dot(X.T, Y) over the rows of X and Y
+ val udaf = HiveUDAFFunction(
+ new HiveFunctionWrapper("hivemall.tools.matrix.TransposeAndDotUDAF"),
+ Seq(X, Y).map(df.col(_).expr),
+ isUDAFBridgeRequired = false)
.toAggregateExpression()
toDF(Seq(Alias(udaf, udaf.prettyString)()))
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fad2941f/spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
----------------------------------------------------------------------
diff --git a/spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala b/spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
index 4c750dd..8583e1c 100644
--- a/spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
+++ b/spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
@@ -1010,6 +1010,15 @@ object HivemallOps {
}
/**
+ * @see hivemall.ftvec.selection.ChiSquareUDF
+ * @group ftvec.selection
+ */
+ def chi2(observed: Column, expected: Column): Column = { // chi2 values and p-values per feature
+ HiveGenericUDF(new HiveFunctionWrapper(
+ "hivemall.ftvec.selection.ChiSquareUDF"), Seq(observed.expr, expected.expr))
+ }
+
+ /**
* @see hivemall.ftvec.conv.ToDenseFeaturesUDF
* @group ftvec.conv
*/
@@ -1082,6 +1091,15 @@ object HivemallOps {
}
/**
+ * @see hivemall.tools.array.SelectKBestUDF
+ * @group tools.array
+ */
+ def select_k_best(X: Column, importanceList: Column, k: Column): Column = { // top-k elements of X by importanceList
+ HiveGenericUDF(new HiveFunctionWrapper(
+ "hivemall.tools.array.SelectKBestUDF"), Seq(X.expr, importanceList.expr, k.expr))
+ }
+
+ /**
* @see hivemall.tools.math.SigmoidUDF
* @group misc
*/
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fad2941f/spark/spark-1.6/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
----------------------------------------------------------------------
diff --git a/spark/spark-1.6/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala b/spark/spark-1.6/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
index 901056d..4c77f18 100644
--- a/spark/spark-1.6/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
+++ b/spark/spark-1.6/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
@@ -187,6 +187,35 @@ final class HivemallOpsSuite extends HivemallQueryTest {
Row(Seq("1:1.0"))))
}
+ test("ftvec.selection - chi2") {
+ import hiveContext.implicits._
+
+ // see also hivemall.ftvec.selection.ChiSquareUDFTest
+ val df = Seq( // a single row pairing two 3 x 4 matrices (#classes x #features)
+ Seq( // observed
+ Seq(250.29999999999998, 170.90000000000003, 73.2, 12.199999999999996),
+ Seq(296.8, 138.50000000000003, 212.99999999999997, 66.3),
+ Seq(329.3999999999999, 148.7, 277.59999999999997, 101.29999999999998)
+ ) -> Seq( // expected
+ Seq(292.1666753739119, 152.70000455081467, 187.93333893418327, 59.93333511948589),
+ Seq(292.1666753739119, 152.70000455081467, 187.93333893418327, 59.93333511948589),
+ Seq(292.1666753739119, 152.70000455081467, 187.93333893418327, 59.93333511948589)))
+ .toDF("arg0", "arg1")
+
+ val result = df.select(chi2(df("arg0"), df("arg1"))).collect
+ assert(result.length == 1)
+ val chi2Val = result.head.getAs[Row](0).getAs[Seq[Double]](0)
+ val pVal = result.head.getAs[Row](0).getAs[Seq[Double]](1)
+
+ (chi2Val, Seq(10.81782088, 3.59449902, 116.16984746, 67.24482759))
+ .zipped // NOTE(review): zipped truncates to the shorter Seq; a short result would go unnoticed
+ .foreach((actual, expected) => assert(actual ~== expected))
+
+ (pVal, Seq(4.47651499e-03, 1.65754167e-01, 5.94344354e-26, 2.50017968e-15))
+ .zipped // NOTE(review): same truncation caveat as above
+ .foreach((actual, expected) => assert(actual ~== expected))
+ }
+
test("ftvec.conv - quantify") {
import hiveContext.implicits._
val testDf = Seq((1, "aaa", true), (2, "bbb", false), (3, "aaa", false)).toDF
@@ -336,6 +365,17 @@ final class HivemallOpsSuite extends HivemallQueryTest {
checkAnswer(predicted, Seq(Row(0), Row(1)))
}
+ test("tools.array - select_k_best") {
+ import hiveContext.implicits._
+
+ val data = Seq(Seq(0, 1, 3), Seq(2, 4, 1), Seq(5, 4, 9))
+ val df = data.map(d => (d, Seq(3, 1, 2), 2)).toDF("features", "importance_list", "k") // k=2 keeps indices 0 and 2
+
+ // checkAnswer fails here on Spark 1.6 (cause unknown, possibly a type mismatch) but works on Spark 2.0
+ assert(df.select(select_k_best(df("features"), df("importance_list"), df("k"))).collect ===
+ data.map(s => Row(Seq(s(0).toDouble, s(2).toDouble))))
+ }
+
test("misc - sigmoid") {
import hiveContext.implicits._
/**
@@ -534,14 +574,13 @@ final class HivemallOpsSuite extends HivemallQueryTest {
assert(row4(0).getDouble(1) ~== 0.25)
}
- test("user-defined aggregators for ftvec.trans") {
+ ignore("user-defined aggregators for ftvec.trans") {
import hiveContext.implicits._
val df0 = Seq((1, "cat", "mammal", 9), (1, "dog", "mammal", 10), (1, "human", "mammal", 10),
(1, "seahawk", "bird", 101), (1, "wasp", "insect", 3), (1, "wasp", "insect", 9),
(1, "cat", "mammal", 101), (1, "dog", "mammal", 1), (1, "human", "mammal", 9))
- .toDF("col0", "cat1", "cat2", "cat3")
-
+ .toDF("col0", "cat1", "cat2", "cat3")
val row00 = df0.groupby($"col0").onehot_encoding("cat1")
val row01 = df0.groupby($"col0").onehot_encoding("cat1", "cat2", "cat3")
@@ -560,4 +599,64 @@ final class HivemallOpsSuite extends HivemallQueryTest {
assert(result012.keySet === Set(1, 3, 9, 10, 101))
assert(result012.values.toSet === Set(9, 10, 11, 12, 13))
}
+
+ test("user-defined aggregators for ftvec.selection") {
+ import hiveContext.implicits._
+
+ // see also hivemall.ftvec.selection.SignalNoiseRatioUDAFTest
+ // binary classification (labels one-hot encoded as Seq(1, 0) / Seq(0, 1))
+ // +-----------------+-------+
+ // | features | class |
+ // +-----------------+-------+
+ // | 5.1,3.5,1.4,0.2 | 0 |
+ // | 4.9,3.0,1.4,0.2 | 0 |
+ // | 4.7,3.2,1.3,0.2 | 0 |
+ // | 7.0,3.2,4.7,1.4 | 1 |
+ // | 6.4,3.2,4.5,1.5 | 1 |
+ // | 6.9,3.1,4.9,1.5 | 1 |
+ // +-----------------+-------+
+ val df0 = Seq(
+ (1, Seq(5.1, 3.5, 1.4, 0.2), Seq(1, 0)), (1, Seq(4.9, 3.0, 1.4, 0.2), Seq(1, 0)),
+ (1, Seq(4.7, 3.2, 1.3, 0.2), Seq(1, 0)), (1, Seq(7.0, 3.2, 4.7, 1.4), Seq(0, 1)),
+ (1, Seq(6.4, 3.2, 4.5, 1.5), Seq(0, 1)), (1, Seq(6.9, 3.1, 4.9, 1.5), Seq(0, 1)))
+ .toDF("c0", "arg0", "arg1")
+ val row0 = df0.groupby($"c0").snr("arg0", "arg1").collect // each row is (c0, snr)
+ (row0(0).getAs[Seq[Double]](1), Seq(4.38425236, 0.26390002, 15.83984511, 26.87005769))
+ .zipped
+ .foreach((actual, expected) => assert(actual ~== expected))
+
+ // multi-class classification (3 classes, labels one-hot encoded)
+ // +-----------------+-------+
+ // | features | class |
+ // +-----------------+-------+
+ // | 5.1,3.5,1.4,0.2 | 0 |
+ // | 4.9,3.0,1.4,0.2 | 0 |
+ // | 7.0,3.2,4.7,1.4 | 1 |
+ // | 6.4,3.2,4.5,1.5 | 1 |
+ // | 6.3,3.3,6.0,2.5 | 2 |
+ // | 5.8,2.7,5.1,1.9 | 2 |
+ // +-----------------+-------+
+ val df1 = Seq(
+ (1, Seq(5.1, 3.5, 1.4, 0.2), Seq(1, 0, 0)), (1, Seq(4.9, 3.0, 1.4, 0.2), Seq(1, 0, 0)),
+ (1, Seq(7.0, 3.2, 4.7, 1.4), Seq(0, 1, 0)), (1, Seq(6.4, 3.2, 4.5, 1.5), Seq(0, 1, 0)),
+ (1, Seq(6.3, 3.3, 6.0, 2.5), Seq(0, 0, 1)), (1, Seq(5.8, 2.7, 5.1, 1.9), Seq(0, 0, 1)))
+ .toDF("c0", "arg0", "arg1")
+ val row1 = df1.groupby($"c0").snr("arg0", "arg1").collect // each row is (c0, snr)
+ (row1(0).getAs[Seq[Double]](1), Seq(8.43181818, 1.32121212, 42.94949495, 33.80952381))
+ .zipped
+ .foreach((actual, expected) => assert(actual ~== expected))
+ }
+
+ test("user-defined aggregators for tools.matrix") {
+ import hiveContext.implicits._
+
+ // | 1 2 3 |T | 5 6 7 |
+ // | 3 4 5 | * | 7 8 9 |
+ val df0 = Seq((1, Seq(1, 2, 3), Seq(5, 6, 7)), (1, Seq(3, 4, 5), Seq(7, 8, 9)))
+ .toDF("c0", "arg0", "arg1")
+
+ // checkAnswer fails here on Spark 1.6 (cause unknown, possibly a type mismatch) but works on Spark 2.0
+ assert(df0.groupby($"c0").transpose_and_dot("arg0", "arg1").collect() ===
+ Seq(Row(1, Seq(Seq(26.0, 30.0, 34.0), Seq(38.0, 44.0, 50.0), Seq(50.0, 58.0, 66.0)))))
+ }
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fad2941f/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallGroupedDataset.scala
----------------------------------------------------------------------
diff --git a/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallGroupedDataset.scala b/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallGroupedDataset.scala
index 8ac7185..bdeff98 100644
--- a/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallGroupedDataset.scala
+++ b/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallGroupedDataset.scala
@@ -133,6 +133,19 @@ final class HivemallGroupedDataset(groupBy: RelationalGroupedDataset) {
}
/**
+ * @see hivemall.tools.matrix.TransposeAndDotUDAF
+ */
+ def transpose_and_dot(X: String, Y: String): DataFrame = { // aggregates dot(X.T, Y) over the rows of X and Y
+ val udaf = HiveUDAFFunction(
+ "transpose_and_dot",
+ new HiveFunctionWrapper("hivemall.tools.matrix.TransposeAndDotUDAF"),
+ Seq(X, Y).map(df.col(_).expr),
+ isUDAFBridgeRequired = false)
+ .toAggregateExpression()
+ toDF(Seq(Alias(udaf, udaf.prettyName)()))
+ }
+
+ /**
* @see hivemall.ftvec.trans.OnehotEncodingUDAF
* @group ftvec.trans
*/
@@ -147,6 +160,19 @@ final class HivemallGroupedDataset(groupBy: RelationalGroupedDataset) {
}
/**
+ * @see hivemall.ftvec.selection.SignalNoiseRatioUDAF
+ */
+ def snr(X: String, Y: String): DataFrame = { // per-feature SNR of feature column X vs. one-hot label column Y
+ val udaf = HiveUDAFFunction(
+ "snr",
+ new HiveFunctionWrapper("hivemall.ftvec.selection.SignalNoiseRatioUDAF"),
+ Seq(X, Y).map(df.col(_).expr),
+ isUDAFBridgeRequired = false)
+ .toAggregateExpression()
+ toDF(Seq(Alias(udaf, udaf.prettyName)()))
+ }
+
+ /**
* @see hivemall.evaluation.MeanAbsoluteErrorUDAF
* @group evaluation
*/
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fad2941f/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
----------------------------------------------------------------------
diff --git a/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala b/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
index ba58039..9bde84f 100644
--- a/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
+++ b/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
@@ -1216,6 +1216,16 @@ object HivemallOps {
}
/**
+   * Computes the chi-squared statistic and p-value of each feature from an observed and an
+   * expected frequency matrix; the result is a struct of two arrays (chi2, pvalue).
+   *
+   * @see hivemall.ftvec.selection.ChiSquareUDF
+   * @group ftvec.selection
+   */
+  def chi2(observed: Column, expected: Column): Column = withExpr {
+    val wrapper = new HiveFunctionWrapper("hivemall.ftvec.selection.ChiSquareUDF")
+    val children = Seq(observed.expr, expected.expr)
+    HiveGenericUDF("chi2", wrapper, children)
+  }
+
+ /**
* @see hivemall.ftvec.conv.ToDenseFeaturesUDF
* @group ftvec.conv
*/
@@ -1295,6 +1305,16 @@ object HivemallOps {
}
/**
+   * Keeps the `k` elements of `X` whose positions have the highest scores in `importanceList`.
+   *
+   * @see hivemall.tools.array.SelectKBestUDF
+   * @group tools.array
+   */
+  def select_k_best(X: Column, importanceList: Column, k: Column): Column = withExpr {
+    val wrapper = new HiveFunctionWrapper("hivemall.tools.array.SelectKBestUDF")
+    val children = Seq(X.expr, importanceList.expr, k.expr)
+    HiveGenericUDF("select_k_best", wrapper, children)
+  }
+
+ /**
* @see hivemall.tools.math.SigmoidUDF
* @group misc
*/
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fad2941f/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
----------------------------------------------------------------------
diff --git a/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala b/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
index a093e07..6f2f016 100644
--- a/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
+++ b/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
@@ -188,6 +188,35 @@ final class HivemallOpsWithFeatureSuite extends HivemallFeatureQueryTest {
Row(Seq("1:1.0"))))
}
+  test("ftvec.selection - chi2") {
+    import hiveContext.implicits._
+
+    // see also hivemall.ftvec.selection.ChiSquareUDFTest
+    // arg0 = observed and arg1 = expected frequencies, each shaped (#classes, #features)
+    val df = Seq(
+      Seq(
+        Seq(250.29999999999998, 170.90000000000003, 73.2, 12.199999999999996),
+        Seq(296.8, 138.50000000000003, 212.99999999999997, 66.3),
+        Seq(329.3999999999999, 148.7, 277.59999999999997, 101.29999999999998)
+      ) -> Seq(
+        Seq(292.1666753739119, 152.70000455081467, 187.93333893418327, 59.93333511948589),
+        Seq(292.1666753739119, 152.70000455081467, 187.93333893418327, 59.93333511948589),
+        Seq(292.1666753739119, 152.70000455081467, 187.93333893418327, 59.93333511948589)))
+      .toDF("arg0", "arg1")
+
+    val result = df.select(chi2(df("arg0"), df("arg1"))).collect
+    assert(result.length == 1)
+    val chi2Val = result.head.getAs[Row](0).getAs[Seq[Double]](0)
+    val pVal = result.head.getAs[Row](0).getAs[Seq[Double]](1)
+
+    val expectedChi2 = Seq(10.81782088, 3.59449902, 116.16984746, 67.24482759)
+    val expectedPVal = Seq(4.47651499e-03, 1.65754167e-01, 5.94344354e-26, 2.50017968e-15)
+
+    // `zipped` stops at the shorter sequence, so pin the lengths first; otherwise an empty
+    // result would make the element-wise checks below pass vacuously.
+    assert(chi2Val.length == expectedChi2.length)
+    assert(pVal.length == expectedPVal.length)
+
+    (chi2Val, expectedChi2)
+      .zipped
+      .foreach((actual, expected) => assert(actual ~== expected))
+
+    (pVal, expectedPVal)
+      .zipped
+      .foreach((actual, expected) => assert(actual ~== expected))
+  }
+
test("ftvec.conv - quantify") {
import hiveContext.implicits._
val testDf = Seq((1, "aaa", true), (2, "bbb", false), (3, "aaa", false)).toDF
@@ -361,6 +390,18 @@ final class HivemallOpsWithFeatureSuite extends HivemallFeatureQueryTest {
checkAnswer(predicted, Seq(Row(0), Row(1)))
}
+  test("tools.array - select_k_best") {
+    import hiveContext.implicits._
+    import org.apache.spark.sql.functions._
+
+    // Scores per feature position: positions 0 (score 3) and 2 (score 2) rank highest.
+    val importance = Seq(3, 1, 2)
+    val rows = Seq(Seq(0, 1, 3), Seq(2, 4, 1), Seq(5, 4, 9))
+    val df = rows.map(features => (features, importance)).toDF("features", "importance_list")
+    val topK = 2
+
+    val expected = rows.map(features => Row(Seq(features(0).toDouble, features(2).toDouble)))
+    checkAnswer(
+      df.select(select_k_best(df("features"), df("importance_list"), lit(topK))),
+      expected)
+  }
+
test("misc - sigmoid") {
import hiveContext.implicits._
/**
@@ -661,6 +702,65 @@ final class HivemallOpsWithFeatureSuite extends HivemallFeatureQueryTest {
assert(result012.keySet === Set(1, 3, 9, 10, 101))
assert(result012.values.toSet === Set(9, 10, 11, 12, 13))
}
+
+  test("user-defined aggregators for ftvec.selection") {
+    import hiveContext.implicits._
+
+    // see also hivemall.ftvec.selection.SignalNoiseRatioUDAFTest
+    // binary class
+    // +-----------------+-------+
+    // | features        | class |
+    // +-----------------+-------+
+    // | 5.1,3.5,1.4,0.2 |     0 |
+    // | 4.9,3.0,1.4,0.2 |     0 |
+    // | 4.7,3.2,1.3,0.2 |     0 |
+    // | 7.0,3.2,4.7,1.4 |     1 |
+    // | 6.4,3.2,4.5,1.5 |     1 |
+    // | 6.9,3.1,4.9,1.5 |     1 |
+    // +-----------------+-------+
+    val df0 = Seq(
+      (1, Seq(5.1, 3.5, 1.4, 0.2), Seq(1, 0)), (1, Seq(4.9, 3.0, 1.4, 0.2), Seq(1, 0)),
+      (1, Seq(4.7, 3.2, 1.3, 0.2), Seq(1, 0)), (1, Seq(7.0, 3.2, 4.7, 1.4), Seq(0, 1)),
+      (1, Seq(6.4, 3.2, 4.5, 1.5), Seq(0, 1)), (1, Seq(6.9, 3.1, 4.9, 1.5), Seq(0, 1)))
+      .toDF("c0", "arg0", "arg1")
+    val row0 = df0.groupBy($"c0").snr("arg0", "arg1").collect
+    assert(row0.length == 1)
+    val snr0 = row0(0).getAs[Seq[Double]](1)
+    val expected0 = Seq(4.38425236, 0.26390002, 15.83984511, 26.87005769)
+    // `zipped` stops at the shorter sequence; pin the length so an empty result cannot pass
+    assert(snr0.length == expected0.length)
+    (snr0, expected0)
+      .zipped
+      .foreach((actual, expected) => assert(actual ~== expected))
+
+    // multiple class
+    // +-----------------+-------+
+    // | features        | class |
+    // +-----------------+-------+
+    // | 5.1,3.5,1.4,0.2 |     0 |
+    // | 4.9,3.0,1.4,0.2 |     0 |
+    // | 7.0,3.2,4.7,1.4 |     1 |
+    // | 6.4,3.2,4.5,1.5 |     1 |
+    // | 6.3,3.3,6.0,2.5 |     2 |
+    // | 5.8,2.7,5.1,1.9 |     2 |
+    // +-----------------+-------+
+    val df1 = Seq(
+      (1, Seq(5.1, 3.5, 1.4, 0.2), Seq(1, 0, 0)), (1, Seq(4.9, 3.0, 1.4, 0.2), Seq(1, 0, 0)),
+      (1, Seq(7.0, 3.2, 4.7, 1.4), Seq(0, 1, 0)), (1, Seq(6.4, 3.2, 4.5, 1.5), Seq(0, 1, 0)),
+      (1, Seq(6.3, 3.3, 6.0, 2.5), Seq(0, 0, 1)), (1, Seq(5.8, 2.7, 5.1, 1.9), Seq(0, 0, 1)))
+      .toDF("c0", "arg0", "arg1")
+    val row1 = df1.groupBy($"c0").snr("arg0", "arg1").collect
+    assert(row1.length == 1)
+    val snr1 = row1(0).getAs[Seq[Double]](1)
+    val expected1 = Seq(8.43181818, 1.32121212, 42.94949495, 33.80952381)
+    assert(snr1.length == expected1.length)
+    (snr1, expected1)
+      .zipped
+      .foreach((actual, expected) => assert(actual ~== expected))
+  }
+
+  test("user-defined aggregators for tools.matrix") {
+    import hiveContext.implicits._
+
+    // Each row supplies one row of X (arg0) and of Y (arg1); the aggregate is X^T * Y:
+    //   X = | 1 2 3 |    Y = | 5 6 7 |
+    //       | 3 4 5 |        | 7 8 9 |
+    val df0 = Seq(
+      (1, Seq(1, 2, 3), Seq(5, 6, 7)),
+      (1, Seq(3, 4, 5), Seq(7, 8, 9))
+    ).toDF("c0", "arg0", "arg1")
+
+    val expectedProduct = Seq(
+      Seq(26.0, 30.0, 34.0),
+      Seq(38.0, 44.0, 50.0),
+      Seq(50.0, 58.0, 66.0))
+    checkAnswer(
+      df0.groupBy($"c0").transpose_and_dot("arg0", "arg1"),
+      Seq(Row(1, expectedProduct)))
+  }
}
final class HivemallOpsWithVectorSuite extends VectorQueryTest {
[2/2] incubator-hivemall git commit: Close #15: Implement Feature
Selection functions (chi2, snr)
Posted by my...@apache.org.
Close #15: Implement Feature Selection functions (chi2, snr)
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/fad2941f
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/fad2941f
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/fad2941f
Branch: refs/heads/master
Commit: fad2941fdb0309dd6fbf22f2f936cbf0003b1c4a
Parents: 518e232
Author: amaya <am...@users.noreply.github.com>
Authored: Tue Dec 13 16:46:07 2016 +0900
Committer: myui <yu...@gmail.com>
Committed: Tue Dec 13 16:46:07 2016 +0900
----------------------------------------------------------------------
.../hivemall/ftvec/selection/ChiSquareUDF.java | 173 +++++++++
.../ftvec/selection/SignalNoiseRatioUDAF.java | 370 +++++++++++++++++++
.../hivemall/tools/array/SelectKBestUDF.java | 163 ++++++++
.../tools/matrix/TransposeAndDotUDAF.java | 222 +++++++++++
.../java/hivemall/utils/hadoop/HiveUtils.java | 22 +-
.../hivemall/utils/hadoop/WritableUtils.java | 16 +
.../java/hivemall/utils/lang/Preconditions.java | 30 ++
.../java/hivemall/utils/math/StatsUtils.java | 91 +++++
.../ftvec/selection/ChiSquareUDFTest.java | 82 ++++
.../selection/SignalNoiseRatioUDAFTest.java | 342 +++++++++++++++++
.../tools/array/SelectKBeatUDFTest.java | 69 ++++
.../tools/matrix/TransposeAndDotUDAFTest.java | 59 +++
.../hivemall/utils/lang/PreconditionsTest.java | 37 ++
docs/gitbook/SUMMARY.md | 2 +
.../gitbook/ft_engineering/feature_selection.md | 155 ++++++++
docs/gitbook/ft_engineering/quantify.md | 2 +-
resources/ddl/define-all-as-permanent.hive | 20 +
resources/ddl/define-all.hive | 20 +
resources/ddl/define-all.spark | 50 ++-
resources/ddl/define-udfs.td.hql | 4 +
.../apache/spark/sql/hive/GroupedDataEx.scala | 29 +-
.../org/apache/spark/sql/hive/HivemallOps.scala | 18 +
.../spark/sql/hive/HivemallOpsSuite.scala | 105 +++++-
.../spark/sql/hive/HivemallGroupedDataset.scala | 26 ++
.../org/apache/spark/sql/hive/HivemallOps.scala | 20 +
.../spark/sql/hive/HivemallOpsSuite.scala | 100 +++++
26 files changed, 2203 insertions(+), 24 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fad2941f/core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java b/core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java
new file mode 100644
index 0000000..9ada4e5
--- /dev/null
+++ b/core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java
@@ -0,0 +1,173 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.ftvec.selection;
+
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.hadoop.WritableUtils;
+import hivemall.utils.lang.Preconditions;
+import hivemall.utils.math.StatsUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.udf.UDFType;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
+import org.apache.hadoop.hive.serde2.io.DoubleWritable;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+
+@Description(name = "chi2",
+        value = "_FUNC_(array<array<number>> observed, array<array<number>> expected)"
+                + " - Returns chi2_val and p_val of each columns as <array<double>, array<double>>")
+@UDFType(deterministic = true, stateful = false)
+public final class ChiSquareUDF extends GenericUDF {
+
+    // Inspectors for argument 0, `observed` (array<array<number>>)
+    private ListObjectInspector observedOI;
+    private ListObjectInspector observedRowOI;
+    private PrimitiveObjectInspector observedElOI;
+    // Inspectors for argument 1, `expected` (array<array<number>>)
+    private ListObjectInspector expectedOI;
+    private ListObjectInspector expectedRowOI;
+    private PrimitiveObjectInspector expectedElOI;
+
+    // Buffers reused across rows; (re)allocated by ensureBuffers() whenever the input shape changes
+    private double[] observedRow = null;
+    private double[] expectedRow = null;
+    private double[][] observed = null; // shape = (#features, #classes)
+    private double[][] expected = null; // shape = (#features, #classes)
+
+    // Output struct fields: result[0] = chi2 values, result[1] = p-values
+    private List<DoubleWritable>[] result;
+
+    @SuppressWarnings("unchecked")
+    @Override
+    public ObjectInspector initialize(ObjectInspector[] OIs) throws UDFArgumentException {
+        if (OIs.length != 2) {
+            throw new UDFArgumentLengthException("Specify two arguments: " + OIs.length);
+        }
+        if (!HiveUtils.isNumberListListOI(OIs[0])) {
+            throw new UDFArgumentTypeException(0,
+                "Only array<array<number>> type argument is acceptable but " + OIs[0].getTypeName()
+                        + " was passed as `observed`");
+        }
+        if (!HiveUtils.isNumberListListOI(OIs[1])) {
+            throw new UDFArgumentTypeException(1,
+                "Only array<array<number>> type argument is acceptable but " + OIs[1].getTypeName()
+                        + " was passed as `expected`");
+        }
+
+        // evaluate() reads argument 0 as `observed` and argument 1 as `expected`, so each
+        // inspector must come from the matching argument position (they used to be swapped).
+        this.observedOI = HiveUtils.asListOI(OIs[0]);
+        this.observedRowOI = HiveUtils.asListOI(observedOI.getListElementObjectInspector());
+        this.observedElOI = HiveUtils.asDoubleCompatibleOI(observedRowOI.getListElementObjectInspector());
+        this.expectedOI = HiveUtils.asListOI(OIs[1]);
+        this.expectedRowOI = HiveUtils.asListOI(expectedOI.getListElementObjectInspector());
+        this.expectedElOI = HiveUtils.asDoubleCompatibleOI(expectedRowOI.getListElementObjectInspector());
+        this.result = new List[2];
+
+        List<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
+        fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
+        fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
+
+        return ObjectInspectorFactory.getStandardStructObjectInspector(
+            Arrays.asList("chi2", "pvalue"), fieldOIs);
+    }
+
+    /**
+     * Computes the chi-squared statistic and p-value of each feature.
+     *
+     * @param dObj [observed, expected], each an array of #classes rows holding #features numbers
+     * @return a two-element struct (chi2 values, p-values), or null when either argument is null
+     * @throws HiveException (a UDFArgumentException) when the matrices are empty, differ in
+     *         shape, are ragged, or contain a null row
+     */
+    @Override
+    public List<DoubleWritable>[] evaluate(DeferredObject[] dObj) throws HiveException {
+        final List<?> observedObj = observedOI.getList(dObj[0].get()); // shape = (#classes, #features)
+        final List<?> expectedObj = expectedOI.getList(dObj[1].get()); // shape = (#classes, #features)
+
+        if (observedObj == null || expectedObj == null) {
+            return null;
+        }
+
+        final int nClasses = observedObj.size();
+        Preconditions.checkArgument(nClasses == expectedObj.size(), UDFArgumentException.class);
+        // an empty matrix would otherwise reuse the previous row's buffers (or NPE on the first row)
+        Preconditions.checkArgument(nClasses > 0, UDFArgumentException.class);
+
+        // explode and transpose matrix
+        int nFeatures = -1;
+        for (int i = 0; i < nClasses; i++) {
+            final Object observedObjRow = observedObj.get(i);
+            final Object expectedObjRow = expectedObj.get(i);
+
+            Preconditions.checkNotNull(observedObjRow, UDFArgumentException.class);
+            Preconditions.checkNotNull(expectedObjRow, UDFArgumentException.class);
+
+            // every row of both matrices must have the same number of features
+            final int rowLength = observedRowOI.getListLength(observedObjRow);
+            Preconditions.checkArgument(rowLength == expectedRowOI.getListLength(expectedObjRow),
+                UDFArgumentException.class);
+            if (i == 0) {
+                nFeatures = rowLength;
+                ensureBuffers(nClasses, nFeatures);
+            } else {
+                Preconditions.checkArgument(rowLength == nFeatures, UDFArgumentException.class);
+            }
+
+            HiveUtils.toDoubleArray(observedObjRow, observedRowOI, observedElOI, observedRow,
+                false);
+            HiveUtils.toDoubleArray(expectedObjRow, expectedRowOI, expectedElOI, expectedRow,
+                false);
+
+            for (int j = 0; j < nFeatures; j++) {
+                observed[j][i] = observedRow[j];
+                expected[j][i] = expectedRow[j];
+            }
+        }
+
+        final Map.Entry<double[], double[]> chi2 = StatsUtils.chiSquare(observed, expected);
+        final double[] chi2Values = chi2.getKey();
+        final double[] pValues = chi2.getValue();
+
+        // Only hand a cached list back for reuse when its size still matches; a size change
+        // (different #features than the previous row) gets a freshly built list instead.
+        result[0] = WritableUtils.toWritableList(chi2Values,
+            reusableOrNull(result[0], chi2Values.length));
+        result[1] = WritableUtils.toWritableList(pValues, reusableOrNull(result[1], pValues.length));
+        return result;
+    }
+
+    /**
+     * (Re)allocates the per-row and transposed-matrix buffers when the incoming shape differs from
+     * the cached one, so that rows of different shapes never share stale storage.
+     */
+    private void ensureBuffers(final int nClasses, final int nFeatures) {
+        if (observedRow == null || observedRow.length != nFeatures) {
+            this.observedRow = new double[nFeatures];
+            this.expectedRow = new double[nFeatures];
+        }
+        if (observed == null || observed.length != nFeatures
+                || (nFeatures > 0 && observed[0].length != nClasses)) {
+            this.observed = new double[nFeatures][nClasses];
+            this.expected = new double[nFeatures][nClasses];
+        }
+    }
+
+    /** Returns {@code cached} when it can be reused for {@code size} elements, otherwise null. */
+    private static List<DoubleWritable> reusableOrNull(final List<DoubleWritable> cached,
+            final int size) {
+        return (cached != null && cached.size() == size) ? cached : null;
+    }
+
+    @Override
+    public void close() throws IOException {
+        // help GC
+        this.observedRow = null;
+        this.expectedRow = null;
+        this.observed = null;
+        this.expected = null;
+        this.result = null;
+    }
+
+    @Override
+    public String getDisplayString(String[] children) {
+        final StringBuilder sb = new StringBuilder();
+        sb.append("chi2");
+        sb.append("(");
+        if (children.length > 0) {
+            sb.append(children[0]);
+            for (int i = 1; i < children.length; i++) {
+                sb.append(", ");
+                sb.append(children[i]);
+            }
+        }
+        sb.append(")");
+        return sb.toString();
+    }
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fad2941f/core/src/main/java/hivemall/ftvec/selection/SignalNoiseRatioUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/ftvec/selection/SignalNoiseRatioUDAF.java b/core/src/main/java/hivemall/ftvec/selection/SignalNoiseRatioUDAF.java
new file mode 100644
index 0000000..da0de59
--- /dev/null
+++ b/core/src/main/java/hivemall/ftvec/selection/SignalNoiseRatioUDAF.java
@@ -0,0 +1,370 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.ftvec.selection;
+
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.hadoop.WritableUtils;
+import hivemall.utils.lang.Preconditions;
+import hivemall.utils.lang.SizeOf;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import javax.annotation.Nonnull;
+
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.parse.SemanticException;
+import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFParameterInfo;
+import org.apache.hadoop.hive.serde2.io.DoubleWritable;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StructField;
+import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.LongObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
+
+@Description(name = "snr", value = "_FUNC_(array<number> features, array<int> one-hot class label)"
+        + " - Returns Signal Noise Ratio for each feature as array<double>")
+public class SignalNoiseRatioUDAF extends AbstractGenericUDAFResolver {
+
+    @Override
+    public GenericUDAFEvaluator getEvaluator(GenericUDAFParameterInfo info)
+            throws SemanticException {
+        final ObjectInspector[] OIs = info.getParameterObjectInspectors();
+
+        if (OIs.length != 2) {
+            throw new UDFArgumentLengthException("Specify two arguments: " + OIs.length);
+        }
+        if (!HiveUtils.isNumberListOI(OIs[0])) {
+            throw new UDFArgumentTypeException(0,
+                "Only array<number> type argument is acceptable but " + OIs[0].getTypeName()
+                        + " was passed as `features`");
+        }
+        if (!HiveUtils.isListOI(OIs[1])
+                || !HiveUtils.isIntegerOI(((ListObjectInspector) OIs[1]).getListElementObjectInspector())) {
+            throw new UDFArgumentTypeException(1,
+                "Only array<int> type argument is acceptable but " + OIs[1].getTypeName()
+                        + " was passed as `labels`");
+        }
+
+        return new SignalNoiseRatioUDAFEvaluator();
+    }
+
+    static class SignalNoiseRatioUDAFEvaluator extends GenericUDAFEvaluator {
+        // PARTIAL1 and COMPLETE
+        private ListObjectInspector featuresOI;
+        private PrimitiveObjectInspector featureOI;
+        private ListObjectInspector labelsOI;
+        private PrimitiveObjectInspector labelOI;
+
+        // PARTIAL2 and FINAL
+        private StructObjectInspector structOI;
+        private StructField countsField, meansField, variancesField;
+        private ListObjectInspector countsOI;
+        private LongObjectInspector countOI;
+        private ListObjectInspector meansOI;
+        private ListObjectInspector meanListOI;
+        private DoubleObjectInspector meanElemOI;
+        private ListObjectInspector variancesOI;
+        private ListObjectInspector varianceListOI;
+        private DoubleObjectInspector varianceElemOI;
+
+        /**
+         * Per-class running statistics: counts[c] rows seen for class c, and for each feature f
+         * the running mean means[c][f] and population variance variances[c][f] (sum of squared
+         * deviations divided by counts[c]). All arrays stay null until the first row or partial.
+         */
+        @AggregationType(estimable = true)
+        static class SignalNoiseRatioAggregationBuffer extends AbstractAggregationBuffer {
+            long[] counts;
+            double[][] means;
+            double[][] variances;
+
+            @Override
+            public int estimate() {
+                return counts == null ? 0 : SizeOf.LONG * counts.length + SizeOf.DOUBLE
+                        * means.length * means[0].length + SizeOf.DOUBLE * variances.length
+                        * variances[0].length;
+            }
+
+            public void init(int nClasses, int nFeatures) {
+                this.counts = new long[nClasses];
+                this.means = new double[nClasses][nFeatures];
+                this.variances = new double[nClasses][nFeatures];
+            }
+
+            public void reset() {
+                if (counts != null) {
+                    Arrays.fill(counts, 0);
+                    for (double[] mean : means) {
+                        Arrays.fill(mean, 0.d);
+                    }
+                    for (double[] variance : variances) {
+                        Arrays.fill(variance, 0.d);
+                    }
+                }
+            }
+        }
+
+        @Override
+        public ObjectInspector init(Mode mode, ObjectInspector[] OIs) throws HiveException {
+            super.init(mode, OIs);
+
+            // initialize input
+            if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {// from original data
+                this.featuresOI = HiveUtils.asListOI(OIs[0]);
+                this.featureOI = HiveUtils.asDoubleCompatibleOI(featuresOI.getListElementObjectInspector());
+                this.labelsOI = HiveUtils.asListOI(OIs[1]);
+                this.labelOI = HiveUtils.asIntegerOI(labelsOI.getListElementObjectInspector());
+            } else {// from partial aggregation
+                this.structOI = (StructObjectInspector) OIs[0];
+                this.countsField = structOI.getStructFieldRef("counts");
+                this.countsOI = HiveUtils.asListOI(countsField.getFieldObjectInspector());
+                this.countOI = HiveUtils.asLongOI(countsOI.getListElementObjectInspector());
+                this.meansField = structOI.getStructFieldRef("means");
+                this.meansOI = HiveUtils.asListOI(meansField.getFieldObjectInspector());
+                this.meanListOI = HiveUtils.asListOI(meansOI.getListElementObjectInspector());
+                this.meanElemOI = HiveUtils.asDoubleOI(meanListOI.getListElementObjectInspector());
+                this.variancesField = structOI.getStructFieldRef("variances");
+                this.variancesOI = HiveUtils.asListOI(variancesField.getFieldObjectInspector());
+                this.varianceListOI = HiveUtils.asListOI(variancesOI.getListElementObjectInspector());
+                this.varianceElemOI = HiveUtils.asDoubleOI(varianceListOI.getListElementObjectInspector());
+            }
+
+            // initialize output
+            if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) {// terminatePartial
+                List<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
+                fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableLongObjectInspector));
+                fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector)));
+                fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector)));
+                return ObjectInspectorFactory.getStandardStructObjectInspector(
+                    Arrays.asList("counts", "means", "variances"), fieldOIs);
+            } else {// terminate
+                return ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
+            }
+        }
+
+        @Override
+        public AbstractAggregationBuffer getNewAggregationBuffer() throws HiveException {
+            SignalNoiseRatioAggregationBuffer myAgg = new SignalNoiseRatioAggregationBuffer();
+            reset(myAgg);
+            return myAgg;
+        }
+
+        @Override
+        public void reset(@SuppressWarnings("deprecation") AggregationBuffer agg)
+                throws HiveException {
+            SignalNoiseRatioAggregationBuffer myAgg = (SignalNoiseRatioAggregationBuffer) agg;
+            myAgg.reset();
+        }
+
+        @Override
+        public void iterate(@SuppressWarnings("deprecation") AggregationBuffer agg,
+                Object[] parameters) throws HiveException {
+            final Object featuresObj = parameters[0];
+            final Object labelsObj = parameters[1];
+
+            Preconditions.checkNotNull(featuresObj);
+            Preconditions.checkNotNull(labelsObj);
+
+            final SignalNoiseRatioAggregationBuffer myAgg = (SignalNoiseRatioAggregationBuffer) agg;
+
+            final List<?> labels = labelsOI.getList(labelsObj);
+            final int nClasses = labels.size();
+            Preconditions.checkArgument(nClasses >= 2, UDFArgumentException.class);
+
+            final List<?> features = featuresOI.getList(featuresObj);
+            final int nFeatures = features.size();
+            Preconditions.checkArgument(nFeatures >= 1, UDFArgumentException.class);
+
+            if (myAgg.counts == null) {
+                myAgg.init(nClasses, nFeatures);
+            } else {
+                Preconditions.checkArgument(nClasses == myAgg.counts.length,
+                    UDFArgumentException.class);
+                Preconditions.checkArgument(nFeatures == myAgg.means[0].length,
+                    UDFArgumentException.class);
+            }
+
+            // incrementally calculates means and (population) variance
+            // http://i.stanford.edu/pub/cstr/reports/cs/tr/79/773/CS-TR-79-773.pdf
+            final int clazz = hotIndex(labels, labelOI);
+            final long n = myAgg.counts[clazz];
+            myAgg.counts[clazz]++;
+            for (int i = 0; i < nFeatures; i++) {
+                final double x = PrimitiveObjectInspectorUtils.getDouble(features.get(i), featureOI);
+                final double meanN = myAgg.means[clazz][i];
+                final double varianceN = myAgg.variances[clazz][i];
+                myAgg.means[clazz][i] = (n * meanN + x) / (n + 1.d);
+                myAgg.variances[clazz][i] = (n * varianceN + (x - meanN)
+                        * (x - myAgg.means[clazz][i]))
+                        / (n + 1.d);
+            }
+        }
+
+        /** Returns the index of the single `1` in a one-hot (0/1) label vector. */
+        private static int hotIndex(@Nonnull List<?> labels, PrimitiveObjectInspector labelOI)
+                throws UDFArgumentException {
+            final int nClasses = labels.size();
+
+            int clazz = -1;
+            for (int i = 0; i < nClasses; i++) {
+                final int label = PrimitiveObjectInspectorUtils.getInt(labels.get(i), labelOI);
+                if (label == 1) {// assumes one hot encoding
+                    if (clazz != -1) {
+                        throw new UDFArgumentException(
+                            "Specify one-hot vectorized array. Multiple hot elements found.");
+                    }
+                    clazz = i;
+                } else {
+                    if (label != 0) {
+                        throw new UDFArgumentException(
+                            "Assumed one-hot encoding (0/1) but found an invalid label: " + label);
+                    }
+                }
+            }
+            if (clazz == -1) {
+                throw new UDFArgumentException(
+                    "Specify one-hot vectorized array for label. Hot element not found.");
+            }
+            return clazz;
+        }
+
+        @Override
+        public void merge(@SuppressWarnings("deprecation") AggregationBuffer agg, Object other)
+                throws HiveException {
+            if (other == null) {
+                return;
+            }
+
+            final SignalNoiseRatioAggregationBuffer myAgg = (SignalNoiseRatioAggregationBuffer) agg;
+
+            final List<?> counts = countsOI.getList(structOI.getStructFieldData(other, countsField));
+            if (counts == null || counts.isEmpty()) {
+                return; // nothing was aggregated in this partial
+            }
+            final List<?> means = meansOI.getList(structOI.getStructFieldData(other, meansField));
+            final List<?> variances = variancesOI.getList(structOI.getStructFieldData(other,
+                variancesField));
+
+            final int nClasses = counts.size();
+            final int nFeatures = meanListOI.getListLength(means.get(0));
+            if (myAgg.counts == null) {
+                myAgg.init(nClasses, nFeatures);
+            } else {
+                Preconditions.checkArgument(nClasses == myAgg.counts.length,
+                    UDFArgumentException.class);
+                Preconditions.checkArgument(nFeatures == myAgg.means[0].length,
+                    UDFArgumentException.class);
+            }
+
+            for (int i = 0; i < nClasses; i++) {
+                final long n = myAgg.counts[i];
+                final long cnt = PrimitiveObjectInspectorUtils.getLong(counts.get(i), countOI);
+
+                // no need to merge class `i`
+                if (cnt == 0) {
+                    continue;
+                }
+
+                final List<?> mean = meanListOI.getList(means.get(i));
+                final List<?> variance = varianceListOI.getList(variances.get(i));
+                final double total = (double) (n + cnt);
+
+                myAgg.counts[i] += cnt;
+                for (int j = 0; j < nFeatures; j++) {
+                    final double meanN = myAgg.means[i][j];
+                    final double meanM = PrimitiveObjectInspectorUtils.getDouble(mean.get(j),
+                        meanElemOI);
+                    final double varianceN = myAgg.variances[i][j];
+                    final double varianceM = PrimitiveObjectInspectorUtils.getDouble(
+                        variance.get(j), varianceElemOI);
+
+                    if (n == 0) {// only assign `other` into `myAgg`
+                        myAgg.means[i][j] = meanM;
+                        myAgg.variances[i][j] = varianceM;
+                    } else {
+                        // Merge by Chan's method, population form, to stay consistent with
+                        // iterate(), which keeps variance = (sum of squared deviations) / count:
+                        //   M2 = n*varN + cnt*varM + delta^2 * n * cnt / (n + cnt); var = M2 / (n + cnt)
+                        // http://i.stanford.edu/pub/cstr/reports/cs/tr/79/773/CS-TR-79-773.pdf
+                        final double delta = meanN - meanM;
+                        myAgg.means[i][j] = (n * meanN + cnt * meanM) / total;
+                        myAgg.variances[i][j] = (n * varianceN + cnt * varianceM + delta * delta
+                                * n * cnt / total)
+                                / total;
+                    }
+                }
+            }
+        }
+
+        @Override
+        public Object terminatePartial(@SuppressWarnings("deprecation") AggregationBuffer agg)
+                throws HiveException {
+            final SignalNoiseRatioAggregationBuffer myAgg = (SignalNoiseRatioAggregationBuffer) agg;
+            if (myAgg.counts == null) {
+                // no row was aggregated; emit nothing (merge() skips null partials)
+                return null;
+            }
+
+            final Object[] partialResult = new Object[3];
+            partialResult[0] = WritableUtils.toWritableList(myAgg.counts);
+            final List<List<DoubleWritable>> means = new ArrayList<List<DoubleWritable>>();
+            for (double[] mean : myAgg.means) {
+                means.add(WritableUtils.toWritableList(mean));
+            }
+            partialResult[1] = means;
+            final List<List<DoubleWritable>> variances = new ArrayList<List<DoubleWritable>>();
+            for (double[] variance : myAgg.variances) {
+                variances.add(WritableUtils.toWritableList(variance));
+            }
+            partialResult[2] = variances;
+            return partialResult;
+        }
+
+        @Override
+        public Object terminate(@SuppressWarnings("deprecation") AggregationBuffer agg)
+                throws HiveException {
+            final SignalNoiseRatioAggregationBuffer myAgg = (SignalNoiseRatioAggregationBuffer) agg;
+            if (myAgg.counts == null) {
+                return null; // no input rows
+            }
+
+            final int nClasses = myAgg.counts.length;
+            final int nFeatures = myAgg.means[0].length;
+
+            // compute SNR among classes for each feature
+            final double[] result = new double[nFeatures];
+            final double[] sds = new double[nClasses]; // memoized standard deviations per class
+            for (int i = 0; i < nFeatures; i++) {
+                sds[0] = Math.sqrt(myAgg.variances[0][i]);
+                for (int j = 1; j < nClasses; j++) {
+                    sds[j] = Math.sqrt(myAgg.variances[j][i]);
+                    // `counts[j] == 0` means no entry belongs to class `j`. Then, skip the entry.
+                    if (myAgg.counts[j] == 0) {
+                        continue;
+                    }
+                    for (int k = 0; k < j; k++) {
+                        // avoid comparing between classes having only single entry
+                        if (myAgg.counts[k] == 0 || (myAgg.counts[j] == 1 && myAgg.counts[k] == 1)) {
+                            continue;
+                        }
+
+                        // SUM(snr) GROUP BY feature
+                        final double snr = Math.abs(myAgg.means[j][i] - myAgg.means[k][i])
+                                / (sds[j] + sds[k]);
+                        // if `NaN`(when diff between means and both sds are zero, IOW, all related values are equal),
+                        // regard feature `i` as meaningless between class `j` and `k`. So, skip the entry.
+                        if (!Double.isNaN(snr)) {
+                            result[i] += snr; // accept `Infinity`
+                        }
+                    }
+                }
+            }
+
+            return WritableUtils.toWritableList(result);
+        }
+    }
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fad2941f/core/src/main/java/hivemall/tools/array/SelectKBestUDF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/tools/array/SelectKBestUDF.java b/core/src/main/java/hivemall/tools/array/SelectKBestUDF.java
new file mode 100644
index 0000000..b363166
--- /dev/null
+++ b/core/src/main/java/hivemall/tools/array/SelectKBestUDF.java
@@ -0,0 +1,163 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.tools.array;
+
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.lang.Preconditions;
+
+import java.io.IOException;
+import java.util.AbstractMap;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.udf.UDFType;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
+import org.apache.hadoop.hive.serde2.io.DoubleWritable;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+
+@Description(name = "select_k_best",
+        value = "_FUNC_(array<number> array, const array<number> importance, const int k)"
+                + " - Returns selected top-k elements as array<double>")
+@UDFType(deterministic = true, stateful = false)
+public final class SelectKBestUDF extends GenericUDF {
+
+    private ListObjectInspector featuresOI;
+    private PrimitiveObjectInspector featureOI;
+    private ListObjectInspector importanceListOI;
+    private PrimitiveObjectInspector importanceElemOI;
+
+    /** number of elements to select; checked to be at least 1 in initialize() */
+    private int _k;
+    /** output buffer holding {@code _k} writables, reused for every row */
+    private List<DoubleWritable> _result;
+    /** indices of the {@code _k} most important elements; computed from the first row only */
+    private int[] _topKIndices;
+
+    @Override
+    public ObjectInspector initialize(ObjectInspector[] OIs) throws UDFArgumentException {
+        if (OIs.length != 3) {
+            throw new UDFArgumentLengthException("Specify three arguments: " + OIs.length);
+        }
+
+        if (!HiveUtils.isNumberListOI(OIs[0])) {
+            throw new UDFArgumentTypeException(0,
+                "Only array<number> type argument is acceptable but " + OIs[0].getTypeName()
+                        + " was passed as `features`");
+        }
+        if (!HiveUtils.isNumberListOI(OIs[1])) {
+            throw new UDFArgumentTypeException(1,
+                "Only array<number> type argument is acceptable but " + OIs[1].getTypeName()
+                        + " was passed as `importance_list`");
+        }
+        if (!HiveUtils.isIntegerOI(OIs[2])) {
+            throw new UDFArgumentTypeException(2, "Only int type argument is acceptable but "
+                    + OIs[2].getTypeName() + " was passed as `k`");
+        }
+
+        this.featuresOI = HiveUtils.asListOI(OIs[0]);
+        this.featureOI = HiveUtils.asDoubleCompatibleOI(featuresOI.getListElementObjectInspector());
+        this.importanceListOI = HiveUtils.asListOI(OIs[1]);
+        this.importanceElemOI = HiveUtils.asDoubleCompatibleOI(importanceListOI.getListElementObjectInspector());
+
+        this._k = HiveUtils.getConstInt(OIs[2]);
+        Preconditions.checkArgument(_k >= 1, UDFArgumentException.class);
+        final DoubleWritable[] array = new DoubleWritable[_k];
+        for (int i = 0; i < array.length; i++) {
+            array[i] = new DoubleWritable();
+        }
+        this._result = Arrays.asList(array);
+
+        return ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
+    }
+
+    /**
+     * Returns the features at the positions of the {@code k} largest importance values, ordered
+     * from most to least important. The returned list is reused across calls.
+     */
+    @Override
+    public List<DoubleWritable> evaluate(DeferredObject[] dObj) throws HiveException {
+        final double[] features = HiveUtils.asDoubleArray(dObj[0].get(), featuresOI, featureOI);
+        final double[] importanceList = HiveUtils.asDoubleArray(dObj[1].get(), importanceListOI,
+            importanceElemOI);
+
+        Preconditions.checkNotNull(features, UDFArgumentException.class);
+        Preconditions.checkNotNull(importanceList, UDFArgumentException.class);
+        Preconditions.checkArgument(features.length == importanceList.length,
+            UDFArgumentException.class);
+        Preconditions.checkArgument(features.length >= _k, UDFArgumentException.class);
+
+        int[] topKIndices = _topKIndices;
+        if (topKIndices == null) {
+            // `importance` is declared constant, so the ranking is computed once and cached
+            topKIndices = computeTopKIndices(importanceList, _k);
+            this._topKIndices = topKIndices;
+        }
+
+        final List<DoubleWritable> result = _result;
+        for (int i = 0; i < topKIndices.length; i++) {
+            result.get(i).set(features[topKIndices[i]]);
+        }
+        return result;
+    }
+
+    /**
+     * Returns the indices of the {@code k} largest values of {@code importanceList} in descending
+     * order of value. NaN values rank last; ties keep ascending index order (the sort is stable).
+     */
+    private static int[] computeTopKIndices(final double[] importanceList, final int k) {
+        final List<Map.Entry<Integer, Double>> list = new ArrayList<Map.Entry<Integer, Double>>(
+            importanceList.length);
+        for (int i = 0; i < importanceList.length; i++) {
+            list.add(new AbstractMap.SimpleEntry<Integer, Double>(i, importanceList[i]));
+        }
+        Collections.sort(list, new Comparator<Map.Entry<Integer, Double>>() {
+            @Override
+            public int compare(Map.Entry<Integer, Double> o1, Map.Entry<Integer, Double> o2) {
+                // Must be a total order (compare(a, a) == 0, consistent for NaN); otherwise
+                // TimSort may throw "Comparison method violates its general contract!".
+                final double v1 = o1.getValue().doubleValue();
+                final double v2 = o2.getValue().doubleValue();
+                final boolean nan1 = Double.isNaN(v1);
+                final boolean nan2 = Double.isNaN(v2);
+                if (nan1 || nan2) {
+                    return (nan1 == nan2) ? 0 : (nan1 ? 1 : -1);
+                }
+                return Double.compare(v2, v1); // descending
+            }
+        });
+
+        final int[] topK = new int[k];
+        for (int i = 0; i < topK.length; i++) {
+            topK[i] = list.get(i).getKey();
+        }
+        return topK;
+    }
+
+    @Override
+    public void close() throws IOException {
+        // help GC
+        this._result = null;
+        this._topKIndices = null;
+    }
+
+    @Override
+    public String getDisplayString(String[] children) {
+        final StringBuilder sb = new StringBuilder();
+        sb.append("select_k_best");
+        sb.append("(");
+        if (children.length > 0) {
+            sb.append(children[0]);
+            for (int i = 1; i < children.length; i++) {
+                sb.append(", ");
+                sb.append(children[i]);
+            }
+        }
+        sb.append(")");
+        return sb.toString();
+    }
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fad2941f/core/src/main/java/hivemall/tools/matrix/TransposeAndDotUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/tools/matrix/TransposeAndDotUDAF.java b/core/src/main/java/hivemall/tools/matrix/TransposeAndDotUDAF.java
new file mode 100644
index 0000000..440bbe6
--- /dev/null
+++ b/core/src/main/java/hivemall/tools/matrix/TransposeAndDotUDAF.java
@@ -0,0 +1,222 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.tools.matrix;
+
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.hadoop.WritableUtils;
+import hivemall.utils.lang.Preconditions;
+import hivemall.utils.lang.SizeOf;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.parse.SemanticException;
+import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFParameterInfo;
+import org.apache.hadoop.hive.serde2.io.DoubleWritable;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+
+@Description(
+        name = "transpose_and_dot",
+        value = "_FUNC_(array<number> matrix0_row, array<number> matrix1_row)"
+                + " - Returns dot(matrix0.T, matrix1) as array<array<double>>, shape = (matrix0.#cols, matrix1.#cols)")
+public final class TransposeAndDotUDAF extends AbstractGenericUDAFResolver {
+
+    @Override
+    public GenericUDAFEvaluator getEvaluator(GenericUDAFParameterInfo info)
+            throws SemanticException {
+        ObjectInspector[] OIs = info.getParameterObjectInspectors();
+
+        if (OIs.length != 2) {
+            throw new UDFArgumentLengthException("Specify two arguments.");
+        }
+
+        if (!HiveUtils.isNumberListOI(OIs[0])) {
+            throw new UDFArgumentTypeException(0,
+                "Only array<number> type argument is acceptable but " + OIs[0].getTypeName()
+                        + " was passed as `matrix0_row`");
+        }
+
+        if (!HiveUtils.isNumberListOI(OIs[1])) {
+            throw new UDFArgumentTypeException(1,
+                "Only array<number> type argument is acceptable but " + OIs[1].getTypeName()
+                        + " was passed as `matrix1_row`");
+        }
+
+        return new TransposeAndDotUDAFEvaluator();
+    }
+
+    /**
+     * Accumulates the outer products {@code matrix0_row^T x matrix1_row} of paired rows.
+     */
+    static final class TransposeAndDotUDAFEvaluator extends GenericUDAFEvaluator {
+        // PARTIAL1 and COMPLETE
+        private ListObjectInspector matrix0RowOI;
+        private PrimitiveObjectInspector matrix0ElOI;
+        private ListObjectInspector matrix1RowOI;
+        private PrimitiveObjectInspector matrix1ElOI;
+
+        // PARTIAL2 and FINAL
+        private ListObjectInspector aggMatrixOI;
+        private ListObjectInspector aggMatrixRowOI;
+        private DoubleObjectInspector aggMatrixElOI;
+
+        // scratch buffers for the current input rows, sized on the first iterate() call
+        private double[] matrix0Row;
+        private double[] matrix1Row;
+
+        @AggregationType(estimable = true)
+        static class TransposeAndDotAggregationBuffer extends AbstractAggregationBuffer {
+            /** running sum, shape = (matrix0.#cols, matrix1.#cols); null until first populated */
+            double[][] aggMatrix;
+
+            @Override
+            public int estimate() {
+                if (aggMatrix == null || aggMatrix.length == 0) {
+                    return 0; // also avoids aggMatrix[0] on a matrix without rows
+                }
+                return aggMatrix.length * aggMatrix[0].length * SizeOf.DOUBLE;
+            }
+
+            public void init(int n, int m) {
+                this.aggMatrix = new double[n][m];
+            }
+
+            public void reset() {
+                if (aggMatrix != null) {
+                    for (double[] row : aggMatrix) {
+                        Arrays.fill(row, 0.d);
+                    }
+                }
+            }
+        }
+
+        @Override
+        public ObjectInspector init(Mode mode, ObjectInspector[] OIs) throws HiveException {
+            super.init(mode, OIs);
+
+            if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {
+                this.matrix0RowOI = HiveUtils.asListOI(OIs[0]);
+                this.matrix0ElOI = HiveUtils.asDoubleCompatibleOI(matrix0RowOI.getListElementObjectInspector());
+                this.matrix1RowOI = HiveUtils.asListOI(OIs[1]);
+                this.matrix1ElOI = HiveUtils.asDoubleCompatibleOI(matrix1RowOI.getListElementObjectInspector());
+            } else {
+                this.aggMatrixOI = HiveUtils.asListOI(OIs[0]);
+                this.aggMatrixRowOI = HiveUtils.asListOI(aggMatrixOI.getListElementObjectInspector());
+                this.aggMatrixElOI = HiveUtils.asDoubleOI(aggMatrixRowOI.getListElementObjectInspector());
+            }
+
+            return ObjectInspectorFactory.getStandardListObjectInspector(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
+        }
+
+        @Override
+        public AbstractAggregationBuffer getNewAggregationBuffer() throws HiveException {
+            TransposeAndDotAggregationBuffer myAgg = new TransposeAndDotAggregationBuffer();
+            reset(myAgg);
+            return myAgg;
+        }
+
+        @Override
+        public void reset(@SuppressWarnings("deprecation") AggregationBuffer agg)
+                throws HiveException {
+            TransposeAndDotAggregationBuffer myAgg = (TransposeAndDotAggregationBuffer) agg;
+            myAgg.reset();
+        }
+
+        @Override
+        public void iterate(@SuppressWarnings("deprecation") AggregationBuffer agg,
+                Object[] parameters) throws HiveException {
+            final Object matrix0RowObj = parameters[0];
+            final Object matrix1RowObj = parameters[1];
+
+            Preconditions.checkNotNull(matrix0RowObj, UDFArgumentException.class);
+            Preconditions.checkNotNull(matrix1RowObj, UDFArgumentException.class);
+
+            final TransposeAndDotAggregationBuffer myAgg = (TransposeAndDotAggregationBuffer) agg;
+
+            if (matrix0Row == null) {
+                matrix0Row = new double[matrix0RowOI.getListLength(matrix0RowObj)];
+            }
+            if (matrix1Row == null) {
+                matrix1Row = new double[matrix1RowOI.getListLength(matrix1RowObj)];
+            }
+
+            HiveUtils.toDoubleArray(matrix0RowObj, matrix0RowOI, matrix0ElOI, matrix0Row, false);
+            HiveUtils.toDoubleArray(matrix1RowObj, matrix1RowOI, matrix1ElOI, matrix1Row, false);
+
+            if (myAgg.aggMatrix == null) {
+                myAgg.init(matrix0Row.length, matrix1Row.length);
+            }
+
+            // aggMatrix += outer(matrix0Row, matrix1Row)
+            for (int i = 0; i < matrix0Row.length; i++) {
+                for (int j = 0; j < matrix1Row.length; j++) {
+                    myAgg.aggMatrix[i][j] += matrix0Row[i] * matrix1Row[j];
+                }
+            }
+        }
+
+        @Override
+        public void merge(@SuppressWarnings("deprecation") AggregationBuffer agg, Object other)
+                throws HiveException {
+            if (other == null) {
+                return;
+            }
+
+            final TransposeAndDotAggregationBuffer myAgg = (TransposeAndDotAggregationBuffer) agg;
+
+            final List<?> matrix = aggMatrixOI.getList(other);
+            final int n = matrix.size();
+            if (n == 0) {
+                return; // empty partial: nothing to add (and no first row to size from)
+            }
+            final double[] row = new double[aggMatrixRowOI.getListLength(matrix.get(0))];
+
+            if (myAgg.aggMatrix == null) {
+                myAgg.init(n, row.length);
+            } else if (myAgg.aggMatrix.length != n || myAgg.aggMatrix[0].length != row.length) {
+                // adding a partial of a different shape would silently corrupt the sum
+                throw new HiveException("Shape mismatch in partial aggregation: expected ("
+                        + myAgg.aggMatrix.length + ", " + myAgg.aggMatrix[0].length
+                        + ") but was (" + n + ", " + row.length + ")");
+            }
+
+            for (int i = 0; i < n; i++) {
+                HiveUtils.toDoubleArray(matrix.get(i), aggMatrixRowOI, aggMatrixElOI, row, false);
+                for (int j = 0; j < row.length; j++) {
+                    myAgg.aggMatrix[i][j] += row[j];
+                }
+            }
+        }
+
+        @Override
+        public Object terminatePartial(@SuppressWarnings("deprecation") AggregationBuffer agg)
+                throws HiveException {
+            return terminate(agg);
+        }
+
+        @Override
+        public Object terminate(@SuppressWarnings("deprecation") AggregationBuffer agg)
+                throws HiveException {
+            final TransposeAndDotAggregationBuffer myAgg = (TransposeAndDotAggregationBuffer) agg;
+            if (myAgg.aggMatrix == null) {
+                return null; // nothing aggregated; merge() treats a null partial as empty
+            }
+
+            final List<List<DoubleWritable>> result = new ArrayList<List<DoubleWritable>>(
+                myAgg.aggMatrix.length);
+            for (double[] row : myAgg.aggMatrix) {
+                result.add(WritableUtils.toWritableList(row));
+            }
+            return result;
+        }
+    }
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fad2941f/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
index d8b1aef..8188b7a 100644
--- a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
+++ b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
@@ -59,6 +59,7 @@ import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.BinaryObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.BooleanObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.IntObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.LongObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
@@ -200,8 +201,7 @@ public final class HiveUtils {
return BOOLEAN_TYPE_NAME.equals(typeName);
}
- public static boolean isNumberOI(@Nonnull final ObjectInspector argOI)
- throws UDFArgumentTypeException {
+ public static boolean isNumberOI(@Nonnull final ObjectInspector argOI) {
if (argOI.getCategory() != Category.PRIMITIVE) {
return false;
}
@@ -246,6 +246,16 @@ public final class HiveUtils {
return oi.getCategory() == Category.MAP;
}
+    /**
+     * Tells whether {@code oi} inspects a list whose element inspector satisfies isNumberOI.
+     */
+    public static boolean isNumberListOI(@Nonnull final ObjectInspector oi) {
+        if (!isListOI(oi)) {
+            return false;
+        }
+        final ObjectInspector elemOI = ((ListObjectInspector) oi).getListElementObjectInspector();
+        return isNumberOI(elemOI);
+    }
+
+    /**
+     * Tells whether {@code oi} inspects a list whose elements are themselves number lists
+     * (see isNumberListOI).
+     */
+    public static boolean isNumberListListOI(@Nonnull final ObjectInspector oi) {
+        if (!isListOI(oi)) {
+            return false;
+        }
+        final ObjectInspector elemOI = ((ListObjectInspector) oi).getListElementObjectInspector();
+        return isNumberListOI(elemOI);
+    }
+
public static boolean isPrimitiveTypeInfo(@Nonnull TypeInfo typeInfo) {
return typeInfo.getCategory() == ObjectInspector.Category.PRIMITIVE;
}
@@ -687,6 +697,14 @@ public final class HiveUtils {
return (LongObjectInspector) argOI;
}
+    /**
+     * Casts {@code argOI} to {@link DoubleObjectInspector} after checking its type name.
+     *
+     * @throws UDFArgumentException if the type name of {@code argOI} is not DOUBLE_TYPE_NAME
+     */
+    public static DoubleObjectInspector asDoubleOI(@Nonnull final ObjectInspector argOI)
+            throws UDFArgumentException {
+        final String typeName = argOI.getTypeName();
+        if (DOUBLE_TYPE_NAME.equals(typeName)) {
+            return (DoubleObjectInspector) argOI;
+        }
+        throw new UDFArgumentException("Argument type must be DOUBLE: " + typeName);
+    }
+
public static PrimitiveObjectInspector asIntCompatibleOI(@Nonnull final ObjectInspector argOI)
throws UDFArgumentTypeException {
if (argOI.getCategory() != Category.PRIMITIVE) {
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fad2941f/core/src/main/java/hivemall/utils/hadoop/WritableUtils.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/hadoop/WritableUtils.java b/core/src/main/java/hivemall/utils/hadoop/WritableUtils.java
index a4f2691..a9c7390 100644
--- a/core/src/main/java/hivemall/utils/hadoop/WritableUtils.java
+++ b/core/src/main/java/hivemall/utils/hadoop/WritableUtils.java
@@ -25,7 +25,9 @@ import java.util.List;
import javax.annotation.CheckForNull;
import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption;
@@ -142,6 +144,20 @@ public final class WritableUtils {
return list;
}
+    /**
+     * Copies {@code src} into {@code list}, reusing the {@link DoubleWritable} instances the list
+     * already holds; a null slot gets a fresh instance. When {@code list} is null, a new list is
+     * built instead.
+     *
+     * NOTE: the elements of {@code list} are mutated in place, so they should be distinct
+     * instances that are not shared elsewhere.
+     *
+     * @param src values to copy
+     * @param list list to fill in place, or null to allocate a new one; when non-null its size
+     *        must equal {@code src.length}
+     * @return {@code list}, or a newly allocated list when {@code list} is null
+     * @throws UDFArgumentException if {@code list} is non-null and its size differs from
+     *         {@code src.length}
+     */
+    @Nonnull
+    public static List<DoubleWritable> toWritableList(@Nonnull final double[] src,
+            @Nullable List<DoubleWritable> list) throws UDFArgumentException {
+        if (list == null) {
+            return toWritableList(src);
+        }
+
+        Preconditions.checkArgument(src.length == list.size(), UDFArgumentException.class);
+        for (int i = 0; i < src.length; i++) {
+            final DoubleWritable w = list.get(i);
+            if (w == null) {
+                list.set(i, new DoubleWritable(src[i]));
+            } else {
+                // reuse the existing Writable rather than allocating one per element per call
+                w.set(src[i]);
+            }
+        }
+        return list;
+    }
+
public static Text val(final String v) {
return new Text(v);
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fad2941f/core/src/main/java/hivemall/utils/lang/Preconditions.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/lang/Preconditions.java b/core/src/main/java/hivemall/utils/lang/Preconditions.java
index 4fa2bdd..9f76bd6 100644
--- a/core/src/main/java/hivemall/utils/lang/Preconditions.java
+++ b/core/src/main/java/hivemall/utils/lang/Preconditions.java
@@ -18,6 +18,7 @@
*/
package hivemall.utils.lang;
+import javax.annotation.Nonnull;
import javax.annotation.Nullable;
public final class Preconditions {
@@ -31,6 +32,21 @@ public final class Preconditions {
return reference;
}
+    /**
+     * Ensures that {@code reference} is non-null; otherwise throws a new instance of
+     * {@code clazz}, created through its no-argument constructor.
+     *
+     * @param reference the object to check
+     * @param clazz the exception type to throw; needs a no-argument constructor accessible from
+     *        this class
+     * @return {@code reference} when it is non-null
+     * @throws E when {@code reference} is null
+     * @throws IllegalStateException if {@code clazz} cannot be instantiated reflectively
+     */
+    public static <T, E extends Throwable> T checkNotNull(@Nullable T reference,
+            @Nonnull Class<E> clazz) throws E {
+        if (reference == null) {
+            final E throwable;
+            try {
+                // Constructor#newInstance wraps constructor failures in a checked exception,
+                // whereas the deprecated Class#newInstance rethrows them unchecked
+                throwable = clazz.getDeclaredConstructor().newInstance();
+            } catch (ReflectiveOperationException e) {
+                throw new IllegalStateException(
+                    "Failed to instantiate a class: " + clazz.getName(), e);
+            }
+            throw throwable;
+        }
+        return reference;
+    }
+
public static <T> T checkNotNull(T reference, @Nullable Object errorMessage) {
if (reference == null) {
throw new NullPointerException(String.valueOf(errorMessage));
@@ -44,6 +60,20 @@ public final class Preconditions {
}
}
+    /**
+     * Ensures that {@code expression} holds; otherwise throws a new instance of {@code clazz},
+     * created through its no-argument constructor.
+     *
+     * @param expression the condition to check
+     * @param clazz the exception type to throw; needs a no-argument constructor accessible from
+     *        this class
+     * @throws E when {@code expression} is false
+     * @throws IllegalStateException if {@code clazz} cannot be instantiated reflectively
+     */
+    public static <E extends Throwable> void checkArgument(boolean expression,
+            @Nonnull Class<E> clazz) throws E {
+        if (!expression) {
+            final E throwable;
+            try {
+                // Constructor#newInstance wraps constructor failures in a checked exception,
+                // whereas the deprecated Class#newInstance rethrows them unchecked
+                throwable = clazz.getDeclaredConstructor().newInstance();
+            } catch (ReflectiveOperationException e) {
+                throw new IllegalStateException(
+                    "Failed to instantiate a class: " + clazz.getName(), e);
+            }
+            throw throwable;
+        }
+    }
+
public static void checkArgument(boolean expression, @Nullable Object errorMessage) {
if (!expression) {
throw new IllegalArgumentException(String.valueOf(errorMessage));
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fad2941f/core/src/main/java/hivemall/utils/math/StatsUtils.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/math/StatsUtils.java b/core/src/main/java/hivemall/utils/math/StatsUtils.java
index 812f619..599bf51 100644
--- a/core/src/main/java/hivemall/utils/math/StatsUtils.java
+++ b/core/src/main/java/hivemall/utils/math/StatsUtils.java
@@ -22,11 +22,19 @@ import hivemall.utils.lang.Preconditions;
import javax.annotation.Nonnull;
+import org.apache.commons.math3.distribution.ChiSquaredDistribution;
+import org.apache.commons.math3.exception.DimensionMismatchException;
+import org.apache.commons.math3.exception.NotPositiveException;
import org.apache.commons.math3.linear.DecompositionSolver;
import org.apache.commons.math3.linear.LUDecomposition;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.linear.SingularValueDecomposition;
+import org.apache.commons.math3.special.Gamma;
+import org.apache.commons.math3.util.FastMath;
+import org.apache.commons.math3.util.MathArrays;
+
+import java.util.AbstractMap;
+import java.util.Map;
public final class StatsUtils {
@@ -189,4 +197,87 @@ public final class StatsUtils {
return 1.d - numerator / denominator;
}
+ /**
+ * @param observed means non-negative vector
+ * @param expected means positive vector
+ * @return chi2 value
+ */
+ public static double chiSquare(@Nonnull final double[] observed,
+ @Nonnull final double[] expected) {
+ if (observed.length < 2) {
+ throw new DimensionMismatchException(observed.length, 2);
+ }
+ if (expected.length != observed.length) {
+ throw new DimensionMismatchException(observed.length, expected.length);
+ }
+ MathArrays.checkPositive(expected);
+ for (double d : observed) {
+ if (d < 0.d) {
+ throw new NotPositiveException(d);
+ }
+ }
+
+ double sumObserved = 0.d;
+ double sumExpected = 0.d;
+ for (int i = 0; i < observed.length; i++) {
+ sumObserved += observed[i];
+ sumExpected += expected[i];
+ }
+ double ratio = 1.d;
+ boolean rescale = false;
+ if (FastMath.abs(sumObserved - sumExpected) > 10e-6) {
+ ratio = sumObserved / sumExpected;
+ rescale = true;
+ }
+ double sumSq = 0.d;
+ for (int i = 0; i < observed.length; i++) {
+ if (rescale) {
+ final double dev = observed[i] - ratio * expected[i];
+ sumSq += dev * dev / (ratio * expected[i]);
+ } else {
+ final double dev = observed[i] - expected[i];
+ sumSq += dev * dev / expected[i];
+ }
+ }
+ return sumSq;
+ }
+
+    /**
+     * Computes the p-value of Pearson's chi-square test: the probability that a chi-square
+     * variable with {@code expected.length - 1} degrees of freedom is at least the statistic
+     * returned by chiSquare(observed, expected).
+     *
+     * @param observed non-negative vector
+     * @param expected positive vector, same length as {@code observed}
+     * @return p value
+     */
+    public static double chiSquareTest(@Nonnull final double[] observed,
+            @Nonnull final double[] expected) {
+        // chiSquare validates both vectors before the degrees of freedom are derived from them
+        final double chi2 = chiSquare(observed, expected);
+        final double dof = expected.length - 1.d;
+        // Upper tail Q(dof/2, chi2/2) computed directly: `1 - CDF(chi2)` cancels to 0 for
+        // p-values below ~1e-16. Same parameters as ChiSquaredDistribution (Gamma(dof/2, 2)).
+        return Gamma.regularizedGammaQ(dof / 2.d, chi2 / 2.d);
+    }
+
+    /**
+     * Computes the chi-square statistic and p-value for every row pair
+     * {@code (observeds[i], expecteds[i])}. Degrees of freedom are taken from each row's own
+     * length, and an empty input yields empty result arrays.
+     *
+     * @param observeds non-negative matrix
+     * @param expecteds positive matrix with as many rows as {@code observeds}
+     * @return (chi2 value[], p value[]), one entry per row
+     */
+    public static Map.Entry<double[], double[]> chiSquare(@Nonnull final double[][] observeds,
+            @Nonnull final double[][] expecteds) {
+        Preconditions.checkArgument(observeds.length == expecteds.length);
+
+        final int len = expecteds.length;
+        final double[] chi2s = new double[len];
+        final double[] ps = new double[len];
+        for (int i = 0; i < len; i++) {
+            chi2s[i] = chiSquare(observeds[i], expecteds[i]);
+            final double dof = expecteds[i].length - 1.d;
+            // upper tail computed directly; `1 - CDF` would collapse tiny p-values to 0
+            ps[i] = Gamma.regularizedGammaQ(dof / 2.d, chi2s[i] / 2.d);
+        }
+
+        return new AbstractMap.SimpleEntry<double[], double[]>(chi2s, ps);
+    }
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fad2941f/core/src/test/java/hivemall/ftvec/selection/ChiSquareUDFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/ftvec/selection/ChiSquareUDFTest.java b/core/src/test/java/hivemall/ftvec/selection/ChiSquareUDFTest.java
new file mode 100644
index 0000000..fd742bb
--- /dev/null
+++ b/core/src/test/java/hivemall/ftvec/selection/ChiSquareUDFTest.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.ftvec.selection;
+
+import hivemall.utils.hadoop.WritableUtils;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
+import org.apache.hadoop.hive.serde2.io.DoubleWritable;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.junit.Assert;
+import org.junit.Test;
+
+public class ChiSquareUDFTest {
+
+    /**
+     * Checks chi2 statistics and p-values of {@link ChiSquareUDF} against scikit-learn's output
+     * for the same observed/expected matrices (rows = classes, columns = features).
+     */
+    @Test
+    public void testIris() throws Exception {
+        final double[][] observedMatrix = new double[][] {
+                {250.29999999999998, 170.90000000000003, 73.2, 12.199999999999996},
+                {296.8, 138.50000000000003, 212.99999999999997, 66.3},
+                {329.3999999999999, 148.7, 277.59999999999997, 101.29999999999998}};
+        final double[][] expectedMatrix = new double[][] {
+                {292.1666753739119, 152.70000455081467, 187.93333893418327, 59.93333511948589},
+                {292.1666753739119, 152.70000455081467, 187.93333893418327, 59.93333511948589},
+                {292.1666753739119, 152.70000455081467, 187.93333893418327, 59.93333511948589}};
+
+        final GenericUDF.DeferredObject[] dObjs = new GenericUDF.DeferredObject[] {
+                new GenericUDF.DeferredJavaObject(toWritableMatrix(observedMatrix)),
+                new GenericUDF.DeferredJavaObject(toWritableMatrix(expectedMatrix))};
+
+        final ObjectInspector matrixOI = ObjectInspectorFactory.getStandardListObjectInspector(
+            ObjectInspectorFactory.getStandardListObjectInspector(
+                PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
+
+        final ChiSquareUDF chi2 = new ChiSquareUDF();
+        chi2.initialize(new ObjectInspector[] {matrixOI, matrixOI});
+        final List<DoubleWritable>[] result = chi2.evaluate(dObjs);
+
+        final int nFeatures = observedMatrix[0].length;
+        final double[] chi2Values = toDoubleArray(result[0], nFeatures);
+        final double[] pValues = toDoubleArray(result[1], nFeatures);
+
+        // reference values produced by scikit-learn on the same data
+        final double[] expectedChi2 = new double[] {10.81782088, 3.59449902, 116.16984746,
+                67.24482759};
+        final double[] expectedP = new double[] {4.47651499e-03, 1.65754167e-01, 5.94344354e-26,
+                2.50017968e-15};
+
+        Assert.assertArrayEquals(expectedChi2, chi2Values, 1e-5);
+        Assert.assertArrayEquals(expectedP, pValues, 1e-5);
+        chi2.close();
+    }
+
+    /** Converts each row of {@code matrix} into a list of writables. */
+    private static List<List<DoubleWritable>> toWritableMatrix(final double[][] matrix)
+            throws Exception {
+        final List<List<DoubleWritable>> rows = new ArrayList<List<DoubleWritable>>();
+        for (final double[] row : matrix) {
+            rows.add(WritableUtils.toWritableList(row));
+        }
+        return rows;
+    }
+
+    /** Reads the first {@code size} values of {@code list} into a primitive array. */
+    private static double[] toDoubleArray(final List<DoubleWritable> list, final int size) {
+        final double[] values = new double[size];
+        for (int i = 0; i < size; i++) {
+            values[i] = list.get(i).get();
+        }
+        return values;
+    }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fad2941f/core/src/test/java/hivemall/ftvec/selection/SignalNoiseRatioUDAFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/ftvec/selection/SignalNoiseRatioUDAFTest.java b/core/src/test/java/hivemall/ftvec/selection/SignalNoiseRatioUDAFTest.java
new file mode 100644
index 0000000..79570e3
--- /dev/null
+++ b/core/src/test/java/hivemall/ftvec/selection/SignalNoiseRatioUDAFTest.java
@@ -0,0 +1,342 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.ftvec.selection;
+
+import hivemall.utils.hadoop.WritableUtils;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
+import org.apache.hadoop.hive.ql.udf.generic.SimpleGenericUDAFParameterInfo;
+import org.apache.hadoop.hive.serde2.io.DoubleWritable;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.io.IntWritable;
+import org.junit.Assert;
+import org.junit.Test;
+
+public class SignalNoiseRatioUDAFTest {
+
+ @Test
+ public void snrBinaryClass() throws Exception {
+ // this test is based on *subset* of iris data set
+ final SignalNoiseRatioUDAF snr = new SignalNoiseRatioUDAF();
+ final ObjectInspector[] OIs = new ObjectInspector[] {
+ ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector),
+ ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableIntObjectInspector)};
+ final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator evaluator = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator) snr.getEvaluator(new SimpleGenericUDAFParameterInfo(
+ OIs, false, false));
+ evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, OIs);
+ final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer agg = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer) evaluator.getNewAggregationBuffer();
+ evaluator.reset(agg);
+
+ final double[][] features = new double[][] { {5.1, 3.5, 1.4, 0.2}, {4.9, 3.d, 1.4, 0.2},
+ {4.7, 3.2, 1.3, 0.2}, {7.d, 3.2, 4.7, 1.4}, {6.4, 3.2, 4.5, 1.5},
+ {6.9, 3.1, 4.9, 1.5}};
+
+ final int[][] labels = new int[][] { {1, 0}, {1, 0}, {1, 0}, {0, 1}, {0, 1}, {0, 1}};
+
+ for (int i = 0; i < features.length; i++) {
+ final List<IntWritable> labelList = new ArrayList<IntWritable>();
+ for (int label : labels[i]) {
+ labelList.add(new IntWritable(label));
+ }
+ evaluator.iterate(agg, new Object[] {WritableUtils.toWritableList(features[i]),
+ labelList});
+ }
+
+ @SuppressWarnings("unchecked")
+ final List<DoubleWritable> resultObj = (List<DoubleWritable>) evaluator.terminate(agg);
+ final int size = resultObj.size();
+ final double[] result = new double[size];
+ for (int i = 0; i < size; i++) {
+ result[i] = resultObj.get(i).get();
+ }
+
+ // compare with result by numpy
+ final double[] answer = new double[] {4.38425236, 0.26390002, 15.83984511, 26.87005769};
+
+ Assert.assertArrayEquals(answer, result, 1e-5);
+ }
+
+ @Test
+ public void snrMultipleClassNormalCase() throws Exception {
+ // this test is based on *subset* of iris data set
+ final SignalNoiseRatioUDAF snr = new SignalNoiseRatioUDAF();
+ final ObjectInspector[] OIs = new ObjectInspector[] {
+ ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector),
+ ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableIntObjectInspector)};
+ final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator evaluator = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator) snr.getEvaluator(new SimpleGenericUDAFParameterInfo(
+ OIs, false, false));
+ evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, OIs);
+ final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer agg = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer) evaluator.getNewAggregationBuffer();
+ evaluator.reset(agg);
+
+ final double[][] features = new double[][] { {5.1, 3.5, 1.4, 0.2}, {4.9, 3.d, 1.4, 0.2},
+ {7.d, 3.2, 4.7, 1.4}, {6.4, 3.2, 4.5, 1.5}, {6.3, 3.3, 6.d, 2.5},
+ {5.8, 2.7, 5.1, 1.9}};
+
+ final int[][] labels = new int[][] { {1, 0, 0}, {1, 0, 0}, {0, 1, 0}, {0, 1, 0}, {0, 0, 1},
+ {0, 0, 1}};
+
+ for (int i = 0; i < features.length; i++) {
+ final List<IntWritable> labelList = new ArrayList<IntWritable>();
+ for (int label : labels[i]) {
+ labelList.add(new IntWritable(label));
+ }
+ evaluator.iterate(agg, new Object[] {WritableUtils.toWritableList(features[i]),
+ labelList});
+ }
+
+ @SuppressWarnings("unchecked")
+ final List<DoubleWritable> resultObj = (List<DoubleWritable>) evaluator.terminate(agg);
+ final int size = resultObj.size();
+ final double[] result = new double[size];
+ for (int i = 0; i < size; i++) {
+ result[i] = resultObj.get(i).get();
+ }
+
+ // compare with result by scikit-learn
+ final double[] answer = new double[] {8.43181818, 1.32121212, 42.94949495, 33.80952381};
+
+ Assert.assertArrayEquals(answer, result, 1e-5);
+ }
+
+ @Test
+ public void snrMultipleClassCornerCase0() throws Exception {
+ final SignalNoiseRatioUDAF snr = new SignalNoiseRatioUDAF();
+ final ObjectInspector[] OIs = new ObjectInspector[] {
+ ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector),
+ ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableIntObjectInspector)};
+ final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator evaluator = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator) snr.getEvaluator(new SimpleGenericUDAFParameterInfo(
+ OIs, false, false));
+ evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, OIs);
+ final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer agg = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer) evaluator.getNewAggregationBuffer();
+ evaluator.reset(agg);
+
+ // all c0[0] and c1[0] are equal
+ // all c1[1] and c2[1] are equal
+ // all c*[2] are equal
+ // all c*[3] are different
+ final double[][] features = new double[][] { {3.5, 1.4, 0.3, 5.1}, {3.5, 1.5, 0.3, 5.2},
+ {3.5, 4.5, 0.3, 7.d}, {3.5, 4.5, 0.3, 6.4}, {3.3, 4.5, 0.3, 6.3}};
+
+ final int[][] labels = new int[][] { {1, 0, 0}, {1, 0, 0}, // class `0`
+ {0, 1, 0}, {0, 1, 0}, // class `1`
+ {0, 0, 1}}; // class `2`, only single entry
+
+ for (int i = 0; i < features.length; i++) {
+ final List<IntWritable> labelList = new ArrayList<IntWritable>();
+ for (int label : labels[i]) {
+ labelList.add(new IntWritable(label));
+ }
+ evaluator.iterate(agg, new Object[] {WritableUtils.toWritableList(features[i]),
+ labelList});
+ }
+
+ @SuppressWarnings("unchecked")
+ final List<DoubleWritable> resultObj = (List<DoubleWritable>) evaluator.terminate(agg);
+ final int size = resultObj.size();
+ final double[] result = new double[size];
+ for (int i = 0; i < size; i++) {
+ result[i] = resultObj.get(i).get();
+ }
+
+ final double[] answer = new double[] {Double.POSITIVE_INFINITY, 121.99999999999989, 0.d,
+ 28.761904761904734};
+
+ Assert.assertArrayEquals(answer, result, 1e-5);
+ }
+
+ @Test
+ public void snrMultipleClassCornerCase1() throws Exception {
+ final SignalNoiseRatioUDAF snr = new SignalNoiseRatioUDAF();
+ final ObjectInspector[] OIs = new ObjectInspector[] {
+ ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector),
+ ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableIntObjectInspector)};
+ final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator evaluator = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator) snr.getEvaluator(new SimpleGenericUDAFParameterInfo(
+ OIs, false, false));
+ evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, OIs);
+ final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer agg = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer) evaluator.getNewAggregationBuffer();
+ evaluator.reset(agg);
+
+ final double[][] features = new double[][] { {5.1, 3.5, 1.4, 0.2}, {4.9, 3.d, 1.4, 0.2},
+ {7.d, 3.2, 4.7, 1.4}, {6.3, 3.3, 6.d, 2.5}, {6.4, 3.2, 4.5, 1.5}};
+
+ // has multiple single entries
+ final int[][] labels = new int[][] { {1, 0, 0}, {1, 0, 0}, {1, 0, 0}, // class `0`
+ {0, 1, 0}, // class `1`, only single entry
+ {0, 0, 1}}; // class `2`, only single entry
+
+ for (int i = 0; i < features.length; i++) {
+ final List<IntWritable> labelList = new ArrayList<IntWritable>();
+ for (int label : labels[i]) {
+ labelList.add(new IntWritable(label));
+ }
+ evaluator.iterate(agg, new Object[] {WritableUtils.toWritableList(features[i]),
+ labelList});
+ }
+
+ @SuppressWarnings("unchecked")
+ final List<DoubleWritable> resultObj = (List<DoubleWritable>) evaluator.terminate(agg);
+ final List<Double> result = new ArrayList<Double>();
+ for (DoubleWritable dw : resultObj) {
+ result.add(dw.get());
+ }
+
+ Assert.assertFalse(result.contains(Double.POSITIVE_INFINITY));
+ }
+
+ @Test
+ public void snrMultipleClassCornerCase2() throws Exception {
+ final SignalNoiseRatioUDAF snr = new SignalNoiseRatioUDAF();
+ final ObjectInspector[] OIs = new ObjectInspector[] {
+ ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector),
+ ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableIntObjectInspector)};
+ final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator evaluator = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator) snr.getEvaluator(new SimpleGenericUDAFParameterInfo(
+ OIs, false, false));
+ evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, OIs);
+ final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer agg = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer) evaluator.getNewAggregationBuffer();
+ evaluator.reset(agg);
+
+ // all [0] are equal
+ // all [1] are equal *each class*
+ final double[][] features = new double[][] { {1.d, 1.d, 1.4, 0.2}, {1.d, 1.d, 1.4, 0.2},
+ {1.d, 2.d, 4.7, 1.4}, {1.d, 2.d, 4.5, 1.5}, {1.d, 3.d, 6.d, 2.5},
+ {1.d, 3.d, 5.1, 1.9}};
+
+ final int[][] labels = new int[][] { {1, 0, 0}, {1, 0, 0}, {0, 1, 0}, {0, 1, 0}, {0, 0, 1},
+ {0, 0, 1}};
+
+ for (int i = 0; i < features.length; i++) {
+ final List<IntWritable> labelList = new ArrayList<IntWritable>();
+ for (int label : labels[i]) {
+ labelList.add(new IntWritable(label));
+ }
+ evaluator.iterate(agg, new Object[] {WritableUtils.toWritableList(features[i]),
+ labelList});
+ }
+
+ @SuppressWarnings("unchecked")
+ final List<DoubleWritable> resultObj = (List<DoubleWritable>) evaluator.terminate(agg);
+ final int size = resultObj.size();
+ final double[] result = new double[size];
+ for (int i = 0; i < size; i++) {
+ result[i] = resultObj.get(i).get();
+ }
+
+ final double[] answer = new double[] {0.d, Double.POSITIVE_INFINITY, 42.94949495,
+ 33.80952381};
+
+ Assert.assertArrayEquals(answer, result, 1e-5);
+ }
+
+ @Test(expected = UDFArgumentException.class)
+ public void shouldFail0() throws Exception {
+ final SignalNoiseRatioUDAF snr = new SignalNoiseRatioUDAF();
+
+ final ObjectInspector[] OIs = new ObjectInspector[] {
+ ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector),
+ ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableIntObjectInspector)};
+ final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator evaluator = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator) snr.getEvaluator(new SimpleGenericUDAFParameterInfo(
+ OIs, false, false));
+ evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, OIs);
+ final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer agg = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer) evaluator.getNewAggregationBuffer();
+ evaluator.reset(agg);
+
+ final double[][] featuress = new double[][] { {5.1, 3.5, 1.4, 0.2}, {4.9, 3.d, 1.4, 0.2},
+ {7.d, 3.2, 4.7, 1.4}, {6.4, 3.2, 4.5, 1.5}, {6.3, 3.3, 6.d, 2.5},
+ {5.8, 2.7, 5.1, 1.9}};
+
+ final int[][] labelss = new int[][] { {0, 0, 0}, // cause UDFArgumentException
+ {1, 0, 0}, {0, 1, 0}, {0, 1, 0}, {0, 0, 1}, {0, 0, 1}};
+
+ for (int i = 0; i < featuress.length; i++) {
+ final List<IntWritable> labels = new ArrayList<IntWritable>();
+ for (int label : labelss[i]) {
+ labels.add(new IntWritable(label));
+ }
+ evaluator.iterate(agg,
+ new Object[] {WritableUtils.toWritableList(featuress[i]), labels});
+ }
+ }
+
+ @Test(expected = UDFArgumentException.class)
+ public void shouldFail1() throws Exception {
+ final SignalNoiseRatioUDAF snr = new SignalNoiseRatioUDAF();
+
+ final ObjectInspector[] OIs = new ObjectInspector[] {
+ ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector),
+ ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableIntObjectInspector)};
+ final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator evaluator = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator) snr.getEvaluator(new SimpleGenericUDAFParameterInfo(
+ OIs, false, false));
+ evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, OIs);
+ final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer agg = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer) evaluator.getNewAggregationBuffer();
+ evaluator.reset(agg);
+
+ final double[][] featuress = new double[][] { {5.1, 3.5, 1.4, 0.2},
+ {4.9, 3.d, 1.4}, // cause IllegalArgumentException
+ {7.d, 3.2, 4.7, 1.4}, {6.4, 3.2, 4.5, 1.5}, {6.3, 3.3, 6.d, 2.5},
+ {5.8, 2.7, 5.1, 1.9}};
+
+ final int[][] labelss = new int[][] { {1, 0, 0}, {1, 0, 0}, {0, 1, 0}, {0, 1, 0},
+ {0, 0, 1}, {0, 0, 1}};
+
+ for (int i = 0; i < featuress.length; i++) {
+ final List<IntWritable> labels = new ArrayList<IntWritable>();
+ for (int label : labelss[i]) {
+ labels.add(new IntWritable(label));
+ }
+ evaluator.iterate(agg,
+ new Object[] {WritableUtils.toWritableList(featuress[i]), labels});
+ }
+ }
+
+ @Test(expected = UDFArgumentException.class)
+ public void shouldFail2() throws Exception {
+ final SignalNoiseRatioUDAF snr = new SignalNoiseRatioUDAF();
+
+ final ObjectInspector[] OIs = new ObjectInspector[] {
+ ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector),
+ ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableIntObjectInspector)};
+ final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator evaluator = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator) snr.getEvaluator(new SimpleGenericUDAFParameterInfo(
+ OIs, false, false));
+ evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, OIs);
+ final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer agg = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer) evaluator.getNewAggregationBuffer();
+ evaluator.reset(agg);
+
+ final double[][] featuress = new double[][] { {5.1, 3.5, 1.4, 0.2}, {4.9, 3.d, 1.4, 0.2},
+ {7.d, 3.2, 4.7, 1.4}, {6.4, 3.2, 4.5, 1.5}, {6.3, 3.3, 6.d, 2.5},
+ {5.8, 2.7, 5.1, 1.9}};
+
+ final int[][] labelss = new int[][] { {1}, {1}, {1}, {1}, {1}, {1}}; // cause IllegalArgumentException
+
+ for (int i = 0; i < featuress.length; i++) {
+ final List<IntWritable> labels = new ArrayList<IntWritable>();
+ for (int label : labelss[i]) {
+ labels.add(new IntWritable(label));
+ }
+ evaluator.iterate(agg,
+ new Object[] {WritableUtils.toWritableList(featuress[i]), labels});
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fad2941f/core/src/test/java/hivemall/tools/array/SelectKBeatUDFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/tools/array/SelectKBeatUDFTest.java b/core/src/test/java/hivemall/tools/array/SelectKBeatUDFTest.java
new file mode 100644
index 0000000..3e3fc12
--- /dev/null
+++ b/core/src/test/java/hivemall/tools/array/SelectKBeatUDFTest.java
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.tools.array;
+
+import hivemall.utils.hadoop.WritableUtils;
+
+import java.util.List;
+
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
+import org.apache.hadoop.hive.serde2.io.DoubleWritable;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Tests for {@link SelectKBestUDF}. NOTE(review): the class/file name spells "Beat"; renaming it
+ * requires renaming the file as well.
+ */
+public class SelectKBeatUDFTest {
+
+    /**
+     * With k = 2, the two largest importances are at indices 0 (292.17) and 2 (187.93), so the
+     * expected output is data[0] and data[2].
+     */
+    @Test
+    public void test() throws Exception {
+        final SelectKBestUDF selectKBest = new SelectKBestUDF();
+        final int k = 2;
+        final double[] data = new double[] {250.29999999999998, 170.90000000000003, 73.2,
+                12.199999999999996};
+        final double[] importanceList = new double[] {292.1666753739119, 152.70000455081467,
+                187.93333893418327, 59.93333511948589};
+
+        final GenericUDF.DeferredObject[] dObjs = new GenericUDF.DeferredObject[] {
+                new GenericUDF.DeferredJavaObject(WritableUtils.toWritableList(data)),
+                new GenericUDF.DeferredJavaObject(WritableUtils.toWritableList(importanceList)),
+                new GenericUDF.DeferredJavaObject(k)};
+
+        try {
+            selectKBest.initialize(new ObjectInspector[] {
+                    ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector),
+                    ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector),
+                    ObjectInspectorUtils.getConstantObjectInspector(
+                        PrimitiveObjectInspectorFactory.javaIntObjectInspector, k)});
+            final List<DoubleWritable> resultObj = selectKBest.evaluate(dObjs);
+
+            // JUnit convention is assertEquals(expected, actual) so failure messages read correctly
+            Assert.assertEquals(k, resultObj.size());
+
+            final double[] result = new double[k];
+            for (int i = 0; i < k; i++) {
+                result[i] = resultObj.get(i).get();
+            }
+
+            final double[] answer = new double[] {250.29999999999998, 73.2};
+
+            Assert.assertArrayEquals(answer, result, 0.d);
+        } finally {
+            // release the UDF even when an assertion above fails
+            selectKBest.close();
+        }
+    }
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/fad2941f/core/src/test/java/hivemall/tools/matrix/TransposeAndDotUDAFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/tools/matrix/TransposeAndDotUDAFTest.java b/core/src/test/java/hivemall/tools/matrix/TransposeAndDotUDAFTest.java
new file mode 100644
index 0000000..f705a89
--- /dev/null
+++ b/core/src/test/java/hivemall/tools/matrix/TransposeAndDotUDAFTest.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.tools.matrix;
+
+import hivemall.utils.hadoop.WritableUtils;
+
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
+import org.apache.hadoop.hive.ql.udf.generic.SimpleGenericUDAFParameterInfo;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Tests for {@link TransposeAndDotUDAF}: iterating over row pairs (matrix0[i], matrix1[i]) should
+ * accumulate matrix0^T . matrix1, i.e. the sum over i of the outer products matrix0[i]^T matrix1[i].
+ */
+public class TransposeAndDotUDAFTest {
+
+    @Test
+    public void test() throws Exception {
+        final TransposeAndDotUDAF tad = new TransposeAndDotUDAF();
+
+        final double[][] matrix0 = new double[][] { {1, -2}, {-1, 3}};
+        final double[][] matrix1 = new double[][] { {1, 2}, {3, 4}};
+
+        final ObjectInspector[] OIs = new ObjectInspector[] {
+                ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector),
+                ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector)};
+        final GenericUDAFEvaluator evaluator = tad.getEvaluator(new SimpleGenericUDAFParameterInfo(
+            OIs, false, false));
+        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, OIs);
+        final TransposeAndDotUDAF.TransposeAndDotUDAFEvaluator.TransposeAndDotAggregationBuffer agg = (TransposeAndDotUDAF.TransposeAndDotUDAFEvaluator.TransposeAndDotAggregationBuffer) evaluator.getNewAggregationBuffer();
+        evaluator.reset(agg);
+        for (int i = 0; i < matrix0.length; i++) {
+            evaluator.iterate(agg, new Object[] {WritableUtils.toWritableList(matrix0[i]),
+                    WritableUtils.toWritableList(matrix1[i])});
+        }
+
+        // [1,-2]^T[1,2] + [-1,3]^T[3,4] = [[1,2],[-2,-4]] + [[-3,-4],[9,12]]
+        final double[][] answer = new double[][] { {-2.0, -2.0}, {7.0, 8.0}};
+
+        // check the row count too, otherwise extra rows in aggMatrix would go unnoticed
+        Assert.assertEquals(answer.length, agg.aggMatrix.length);
+        for (int i = 0; i < answer.length; i++) {
+            Assert.assertArrayEquals(answer[i], agg.aggMatrix[i], 0.d);
+        }
+    }
+}