You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hivemall.apache.org by my...@apache.org on 2016/12/02 07:04:25 UTC
[23/50] [abbrv] incubator-hivemall git commit: add snr
add snr
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/22a608ee
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/22a608ee
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/22a608ee
Branch: refs/heads/JIRA-22/pr-385
Commit: 22a608ee1c7239b2953183b5341f80c58b1e7045
Parents: 5088ef3
Author: amaya <gi...@sapphire.in.net>
Authored: Mon Sep 26 17:07:55 2016 +0900
Committer: amaya <gi...@sapphire.in.net>
Committed: Mon Sep 26 17:15:22 2016 +0900
----------------------------------------------------------------------
.../ftvec/selection/SignalNoiseRatioUDAF.java | 327 +++++++++++++++++++
.../selection/SignalNoiseRatioUDAFTest.java | 174 ++++++++++
resources/ddl/define-all-as-permanent.hive | 3 +
resources/ddl/define-all.hive | 3 +
resources/ddl/define-all.spark | 3 +
resources/ddl/define-udfs.td.hql | 1 +
6 files changed, 511 insertions(+)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/22a608ee/core/src/main/java/hivemall/ftvec/selection/SignalNoiseRatioUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/ftvec/selection/SignalNoiseRatioUDAF.java b/core/src/main/java/hivemall/ftvec/selection/SignalNoiseRatioUDAF.java
new file mode 100644
index 0000000..b7b9126
--- /dev/null
+++ b/core/src/main/java/hivemall/ftvec/selection/SignalNoiseRatioUDAF.java
@@ -0,0 +1,327 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2015 Makoto YUI
+ * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package hivemall.ftvec.selection;
+
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.hadoop.WritableUtils;
+import hivemall.utils.lang.Preconditions;
+import org.apache.commons.math3.util.FastMath;
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.parse.SemanticException;
+import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFParameterInfo;
+import org.apache.hadoop.hive.serde2.io.DoubleWritable;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StructField;
+import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.LongObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+@Description(name = "snr", value = "_FUNC_(array<number> features, array<int> one-hot class label)"
+ + " - Returns SNR values of each feature as array<double>")
+public class SignalNoiseRatioUDAF extends AbstractGenericUDAFResolver {
+ @Override
+ public GenericUDAFEvaluator getEvaluator(GenericUDAFParameterInfo info)
+ throws SemanticException {
+ final ObjectInspector[] OIs = info.getParameterObjectInspectors();
+
+ if (OIs.length != 2) {
+ throw new UDFArgumentLengthException("Specify two arguments.");
+ }
+
+ if (!HiveUtils.isNumberListOI(OIs[0])) {
+ throw new UDFArgumentTypeException(0,
+ "Only array<number> type argument is acceptable but " + OIs[0].getTypeName()
+ + " was passed as `features`");
+ }
+
+ if (!HiveUtils.isListOI(OIs[1])
+ || !HiveUtils.isIntegerOI(((ListObjectInspector) OIs[1]).getListElementObjectInspector())) {
+ throw new UDFArgumentTypeException(1,
+ "Only array<int> type argument is acceptable but " + OIs[1].getTypeName()
+ + " was passed as `labels`");
+ }
+
+ return new SignalNoiseRatioUDAFEvaluator();
+ }
+
+ static class SignalNoiseRatioUDAFEvaluator extends GenericUDAFEvaluator {
+ // PARTIAL1 and COMPLETE
+ private ListObjectInspector featuresOI;
+ private PrimitiveObjectInspector featureOI;
+ private ListObjectInspector labelsOI;
+ private PrimitiveObjectInspector labelOI;
+
+ // PARTIAL2 and FINAL
+ private StructObjectInspector structOI;
+ private StructField nsField, meanssField, variancessField;
+ private ListObjectInspector nsOI;
+ private LongObjectInspector nOI;
+ private ListObjectInspector meanssOI;
+ private ListObjectInspector meansOI;
+ private DoubleObjectInspector meanOI;
+ private ListObjectInspector variancessOI;
+ private ListObjectInspector variancesOI;
+ private DoubleObjectInspector varianceOI;
+
+ @AggregationType(estimable = true)
+ static class SignalNoiseRatioAggregationBuffer extends AbstractAggregationBuffer {
+ long[] ns;
+ double[][] meanss;
+ double[][] variancess;
+
+ @Override
+ public int estimate() {
+ return ns == null ? 0 : 8 * ns.length + 8 * meanss.length * meanss[0].length + 8
+ * variancess.length * variancess[0].length;
+ }
+
+ public void init(int nClasses, int nFeatures) {
+ ns = new long[nClasses];
+ meanss = new double[nClasses][nFeatures];
+ variancess = new double[nClasses][nFeatures];
+ }
+
+ public void reset() {
+ if (ns != null) {
+ Arrays.fill(ns, 0);
+ for (double[] means : meanss) {
+ Arrays.fill(means, 0.d);
+ }
+ for (double[] variances : variancess) {
+ Arrays.fill(variances, 0.d);
+ }
+ }
+ }
+ }
+
+ @Override
+ public ObjectInspector init(Mode mode, ObjectInspector[] OIs) throws HiveException {
+ super.init(mode, OIs);
+
+ if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {
+ featuresOI = HiveUtils.asListOI(OIs[0]);
+ featureOI = HiveUtils.asDoubleCompatibleOI(featuresOI.getListElementObjectInspector());
+ labelsOI = HiveUtils.asListOI(OIs[1]);
+ labelOI = HiveUtils.asIntegerOI(labelsOI.getListElementObjectInspector());
+ } else {
+ structOI = (StructObjectInspector) OIs[0];
+ nsField = structOI.getStructFieldRef("ns");
+ nsOI = HiveUtils.asListOI(nsField.getFieldObjectInspector());
+ nOI = HiveUtils.asLongOI(nsOI.getListElementObjectInspector());
+ meanssField = structOI.getStructFieldRef("meanss");
+ meanssOI = HiveUtils.asListOI(meanssField.getFieldObjectInspector());
+ meansOI = HiveUtils.asListOI(meanssOI.getListElementObjectInspector());
+ meanOI = HiveUtils.asDoubleOI(meansOI.getListElementObjectInspector());
+ variancessField = structOI.getStructFieldRef("variancess");
+ variancessOI = HiveUtils.asListOI(variancessField.getFieldObjectInspector());
+ variancesOI = HiveUtils.asListOI(variancessOI.getListElementObjectInspector());
+ varianceOI = HiveUtils.asDoubleOI(variancesOI.getListElementObjectInspector());
+ }
+
+ if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) {
+ List<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
+ fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableLongObjectInspector));
+ fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector)));
+ fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector)));
+ return ObjectInspectorFactory.getStandardStructObjectInspector(
+ Arrays.asList("ns", "meanss", "variancess"), fieldOIs);
+ } else {
+ return ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
+ }
+ }
+
+ @Override
+ public AbstractAggregationBuffer getNewAggregationBuffer() throws HiveException {
+ final SignalNoiseRatioAggregationBuffer myAgg = new SignalNoiseRatioAggregationBuffer();
+ reset(myAgg);
+ return myAgg;
+ }
+
+ @Override
+ public void reset(AggregationBuffer agg) throws HiveException {
+ final SignalNoiseRatioAggregationBuffer myAgg = (SignalNoiseRatioAggregationBuffer) agg;
+ myAgg.reset();
+ }
+
+ @Override
+ public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException {
+ final Object featuresObj = parameters[0];
+ final Object labelsObj = parameters[1];
+
+ Preconditions.checkNotNull(featuresObj);
+ Preconditions.checkNotNull(labelsObj);
+
+ final SignalNoiseRatioAggregationBuffer myAgg = (SignalNoiseRatioAggregationBuffer) agg;
+
+ // read class
+ final List labels = labelsOI.getList(labelsObj);
+ final int nClasses = labels.size();
+
+ // to calc SNR between classes
+ Preconditions.checkArgument(nClasses >= 2);
+
+ int clazz = -1;
+ for (int i = 0; i < nClasses; i++) {
+ int label = PrimitiveObjectInspectorUtils.getInt(labels.get(i), labelOI);
+ if (label == 1 && clazz == -1) {
+ clazz = i;
+ } else if (label == 1) {
+ throw new UDFArgumentException(
+ "Specify one-hot vectorized array. Multiple hot elements found.");
+ }
+ }
+ if (clazz == -1) {
+ throw new UDFArgumentException(
+ "Specify one-hot vectorized array. Hot element not found.");
+ }
+
+ final List features = featuresOI.getList(featuresObj);
+ final int nFeatures = features.size();
+
+ Preconditions.checkArgument(nFeatures >= 1);
+
+ if (myAgg.ns == null) {
+ // init
+ myAgg.init(nClasses, nFeatures);
+ } else {
+ Preconditions.checkArgument(nClasses == myAgg.ns.length);
+ Preconditions.checkArgument(nFeatures == myAgg.meanss[0].length);
+ }
+
+ // calc incrementally
+ final long n = myAgg.ns[clazz];
+ myAgg.ns[clazz]++;
+ for (int i = 0; i < nFeatures; i++) {
+ final double x = PrimitiveObjectInspectorUtils.getDouble(features.get(i), featureOI);
+ final double meanN = myAgg.meanss[clazz][i];
+ final double varianceN = myAgg.variancess[clazz][i];
+ myAgg.meanss[clazz][i] = (n * meanN + x) / (n + 1.d);
+ myAgg.variancess[clazz][i] = (n * varianceN + (x - meanN)
+ * (x - myAgg.meanss[clazz][i]))
+ / (n + 1.d);
+ }
+ }
+
+ @Override
+ public void merge(AggregationBuffer agg, Object other) throws HiveException {
+ if (other == null) {
+ return;
+ }
+
+ final SignalNoiseRatioAggregationBuffer myAgg = (SignalNoiseRatioAggregationBuffer) agg;
+
+ final List ns = nsOI.getList(structOI.getStructFieldData(other, nsField));
+ final List meanss = meanssOI.getList(structOI.getStructFieldData(other, meanssField));
+ final List variancess = variancessOI.getList(structOI.getStructFieldData(other,
+ variancessField));
+
+ final int nClasses = ns.size();
+ final int nFeatures = meansOI.getListLength(meanss.get(0));
+ if (myAgg.ns == null) {
+ // init
+ myAgg.init(nClasses, nFeatures);
+ }
+ for (int i = 0; i < nClasses; i++) {
+ final long n = myAgg.ns[i];
+ final long m = PrimitiveObjectInspectorUtils.getLong(ns.get(i), nOI);
+ final List means = meansOI.getList(meanss.get(i));
+ final List variances = variancesOI.getList(variancess.get(i));
+
+ myAgg.ns[i] += m;
+ for (int j = 0; j < nFeatures; j++) {
+ final double meanN = myAgg.meanss[i][j];
+ final double meanM = PrimitiveObjectInspectorUtils.getDouble(means.get(j),
+ meanOI);
+ final double varianceN = myAgg.variancess[i][j];
+ final double varianceM = PrimitiveObjectInspectorUtils.getDouble(
+ variances.get(j), varianceOI);
+ myAgg.meanss[i][j] = (n * meanN + m * meanM) / (double) (n + m);
+ myAgg.variancess[i][j] = (varianceN * (n - 1) + varianceM * (m - 1) + FastMath.pow(
+ meanN - meanM, 2) * n * m / (n + m))
+ / (n + m - 1);
+ }
+ }
+ }
+
+ @Override
+ public Object terminatePartial(AggregationBuffer agg) throws HiveException {
+ final SignalNoiseRatioAggregationBuffer myAgg = (SignalNoiseRatioAggregationBuffer) agg;
+
+ final Object[] partialResult = new Object[3];
+ partialResult[0] = WritableUtils.toWritableList(myAgg.ns);
+ final List<List<DoubleWritable>> meanss = new ArrayList<List<DoubleWritable>>();
+ for (double[] means : myAgg.meanss) {
+ meanss.add(WritableUtils.toWritableList(means));
+ }
+ partialResult[1] = meanss;
+ final List<List<DoubleWritable>> variancess = new ArrayList<List<DoubleWritable>>();
+ for (double[] variances : myAgg.variancess) {
+ variancess.add(WritableUtils.toWritableList(variances));
+ }
+ partialResult[2] = variancess;
+ return partialResult;
+ }
+
+ @Override
+ public Object terminate(AggregationBuffer agg) throws HiveException {
+ final SignalNoiseRatioAggregationBuffer myAgg = (SignalNoiseRatioAggregationBuffer) agg;
+
+ final int nClasses = myAgg.ns.length;
+ final int nFeatures = myAgg.meanss[0].length;
+
+ // calc SNR between classes each feature
+ final double[] result = new double[nFeatures];
+ final double[] sds = new double[nClasses]; // memo
+ for (int i = 0; i < nFeatures; i++) {
+ sds[0] = FastMath.sqrt(myAgg.variancess[0][i]);
+ for (int j = 1; j < nClasses; j++) {
+ sds[j] = FastMath.sqrt(myAgg.variancess[j][i]);
+ if (Double.isNaN(sds[j])) {
+ continue;
+ }
+ for (int k = 0; k < j; k++) {
+ if (Double.isNaN(sds[k])) {
+ continue;
+ }
+ result[i] += FastMath.abs(myAgg.meanss[j][i] - myAgg.meanss[k][i])
+ / (sds[j] + sds[k]);
+ }
+ }
+ }
+
+ // SUM(snr) GROUP BY feature
+ return WritableUtils.toWritableList(result);
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/22a608ee/core/src/test/java/hivemall/ftvec/selection/SignalNoiseRatioUDAFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/ftvec/selection/SignalNoiseRatioUDAFTest.java b/core/src/test/java/hivemall/ftvec/selection/SignalNoiseRatioUDAFTest.java
new file mode 100644
index 0000000..4655545
--- /dev/null
+++ b/core/src/test/java/hivemall/ftvec/selection/SignalNoiseRatioUDAFTest.java
@@ -0,0 +1,174 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2016 Makoto YUI
+ * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package hivemall.ftvec.selection;
+
+import hivemall.utils.hadoop.WritableUtils;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
+import org.apache.hadoop.hive.ql.udf.generic.SimpleGenericUDAFParameterInfo;
+import org.apache.hadoop.hive.serde2.io.DoubleWritable;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.io.IntWritable;
+import org.junit.Assert;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+
+import java.util.ArrayList;
+import java.util.List;
+
+public class SignalNoiseRatioUDAFTest {
+    @Rule
+    public ExpectedException expectedException = ExpectedException.none();
+
+    /** Six rows, two per class, shared by tests that need well-formed features. */
+    private static final double[][] FEATURESS = new double[][] { {5.1, 3.5, 1.4, 0.2},
+            {4.9, 3.d, 1.4, 0.2}, {7.d, 3.2, 4.7, 1.4}, {6.4, 3.2, 4.5, 1.5},
+            {6.3, 3.3, 6.d, 2.5}, {5.8, 2.7, 5.1, 1.9}};
+
+    /** One-hot labels matching {@link #FEATURESS}: classes 0, 0, 1, 1, 2, 2. */
+    private static final int[][] LABELSS = new int[][] { {1, 0, 0}, {1, 0, 0}, {0, 1, 0},
+            {0, 1, 0}, {0, 0, 1}, {0, 0, 1}};
+
+    @Test
+    public void test() throws Exception {
+        final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator evaluator = newEvaluator();
+        final GenericUDAFEvaluator.AggregationBuffer agg = newBuffer(evaluator);
+
+        iterateAll(evaluator, agg, FEATURESS, LABELSS);
+
+        @SuppressWarnings("unchecked")
+        final List<DoubleWritable> resultObj = (List<DoubleWritable>) evaluator.terminate(agg);
+        final int size = resultObj.size();
+        final double[] result = new double[size];
+        for (int i = 0; i < size; i++) {
+            result[i] = resultObj.get(i).get();
+        }
+        final double[] answer = new double[] {8.431818181818192, 1.3212121212121217,
+                42.94949494949499, 33.80952380952378};
+        // small tolerance: exact floating-point equality is brittle across refactorings
+        Assert.assertArrayEquals(answer, result, 1e-9);
+    }
+
+    @Test
+    public void shouldFail0() throws Exception {
+        final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator evaluator = newEvaluator();
+        final GenericUDAFEvaluator.AggregationBuffer agg = newBuffer(evaluator);
+
+        final int[][] labelss = new int[][] { {0, 0, 0}, // cause UDFArgumentException
+                {1, 0, 0}, {0, 1, 0}, {0, 1, 0}, {0, 0, 1}, {0, 0, 1}};
+
+        // declared after setup so only iterate() can satisfy the expectation
+        expectedException.expect(UDFArgumentException.class);
+        iterateAll(evaluator, agg, FEATURESS, labelss);
+    }
+
+    @Test
+    public void shouldFail1() throws Exception {
+        final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator evaluator = newEvaluator();
+        final GenericUDAFEvaluator.AggregationBuffer agg = newBuffer(evaluator);
+
+        final double[][] featuress = new double[][] { {5.1, 3.5, 1.4, 0.2},
+                {4.9, 3.d, 1.4}, // cause IllegalArgumentException
+                {7.d, 3.2, 4.7, 1.4}, {6.4, 3.2, 4.5, 1.5}, {6.3, 3.3, 6.d, 2.5},
+                {5.8, 2.7, 5.1, 1.9}};
+
+        expectedException.expect(IllegalArgumentException.class);
+        iterateAll(evaluator, agg, featuress, LABELSS);
+    }
+
+    @Test
+    public void shouldFail2() throws Exception {
+        final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator evaluator = newEvaluator();
+        final GenericUDAFEvaluator.AggregationBuffer agg = newBuffer(evaluator);
+
+        final int[][] labelss = new int[][] { {1}, {1}, {1}, {1}, {1}, {1}}; // cause IllegalArgumentException
+
+        expectedException.expect(IllegalArgumentException.class);
+        iterateAll(evaluator, agg, FEATURESS, labelss);
+    }
+
+    /** Builds a PARTIAL1 evaluator for (array<double> features, array<int> labels). */
+    private static SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator newEvaluator()
+            throws Exception {
+        final SignalNoiseRatioUDAF snr = new SignalNoiseRatioUDAF();
+        final ObjectInspector[] OIs = new ObjectInspector[] {
+                ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector),
+                ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableIntObjectInspector)};
+        final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator evaluator = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator) snr.getEvaluator(new SimpleGenericUDAFParameterInfo(
+            OIs, false, false));
+        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, OIs);
+        return evaluator;
+    }
+
+    /** Allocates and resets a fresh aggregation buffer for {@code evaluator}. */
+    private static GenericUDAFEvaluator.AggregationBuffer newBuffer(
+            SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator evaluator) throws Exception {
+        final GenericUDAFEvaluator.AggregationBuffer agg = evaluator.getNewAggregationBuffer();
+        evaluator.reset(agg);
+        return agg;
+    }
+
+    /** Feeds every (features, one-hot labels) row into {@code agg} via iterate(). */
+    private static void iterateAll(SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator evaluator,
+            GenericUDAFEvaluator.AggregationBuffer agg, double[][] featuress, int[][] labelss)
+            throws Exception {
+        for (int i = 0; i < featuress.length; i++) {
+            final List<IntWritable> labels = new ArrayList<IntWritable>();
+            for (int label : labelss[i]) {
+                labels.add(new IntWritable(label));
+            }
+            evaluator.iterate(agg,
+                new Object[] {WritableUtils.toWritableList(featuress[i]), labels});
+        }
+    }
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/22a608ee/resources/ddl/define-all-as-permanent.hive
----------------------------------------------------------------------
diff --git a/resources/ddl/define-all-as-permanent.hive b/resources/ddl/define-all-as-permanent.hive
index b515b24..10e72b7 100644
--- a/resources/ddl/define-all-as-permanent.hive
+++ b/resources/ddl/define-all-as-permanent.hive
@@ -209,6 +209,9 @@ CREATE FUNCTION l2_normalize as 'hivemall.ftvec.scaling.L2NormalizationUDF' USIN
DROP FUNCTION IF EXISTS chi2;
CREATE FUNCTION chi2 as 'hivemall.ftvec.selection.ChiSquareUDF' USING JAR '${hivemall_jar}';
+DROP FUNCTION IF EXISTS snr;
+CREATE FUNCTION snr as 'hivemall.ftvec.selection.SignalNoiseRatioUDAF' USING JAR '${hivemall_jar}';
+
--------------------
-- misc functions --
--------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/22a608ee/resources/ddl/define-all.hive
----------------------------------------------------------------------
diff --git a/resources/ddl/define-all.hive b/resources/ddl/define-all.hive
index 2124892..04b519e 100644
--- a/resources/ddl/define-all.hive
+++ b/resources/ddl/define-all.hive
@@ -205,6 +205,9 @@ create temporary function l2_normalize as 'hivemall.ftvec.scaling.L2Normalizatio
drop temporary function chi2;
create temporary function chi2 as 'hivemall.ftvec.selection.ChiSquareUDF';
+drop temporary function snr;
+create temporary function snr as 'hivemall.ftvec.selection.SignalNoiseRatioUDAF';
+
-----------------------------------
-- Feature engineering functions --
-----------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/22a608ee/resources/ddl/define-all.spark
----------------------------------------------------------------------
diff --git a/resources/ddl/define-all.spark b/resources/ddl/define-all.spark
index 47f0ce5..65c2346 100644
--- a/resources/ddl/define-all.spark
+++ b/resources/ddl/define-all.spark
@@ -190,6 +190,9 @@ sqlContext.sql("CREATE TEMPORARY FUNCTION normalize AS 'hivemall.ftvec.scaling.L
sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS chi2")
sqlContext.sql("CREATE TEMPORARY FUNCTION chi2 AS 'hivemall.ftvec.selection.ChiSquareUDF'")
+sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS snr")
+sqlContext.sql("CREATE TEMPORARY FUNCTION snr AS 'hivemall.ftvec.selection.SignalNoiseRatioUDAF'")
+
/**
* misc functions
*/
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/22a608ee/resources/ddl/define-udfs.td.hql
----------------------------------------------------------------------
diff --git a/resources/ddl/define-udfs.td.hql b/resources/ddl/define-udfs.td.hql
index fd7dc1d..7aa537a 100644
--- a/resources/ddl/define-udfs.td.hql
+++ b/resources/ddl/define-udfs.td.hql
@@ -51,6 +51,7 @@ create temporary function rescale as 'hivemall.ftvec.scaling.RescaleUDF';
create temporary function zscore as 'hivemall.ftvec.scaling.ZScoreUDF';
create temporary function l2_normalize as 'hivemall.ftvec.scaling.L2NormalizationUDF';
create temporary function chi2 as 'hivemall.ftvec.selection.ChiSquareUDF';
+create temporary function snr as 'hivemall.ftvec.selection.SignalNoiseRatioUDAF';
create temporary function amplify as 'hivemall.ftvec.amplify.AmplifierUDTF';
create temporary function rand_amplify as 'hivemall.ftvec.amplify.RandomAmplifierUDTF';
create temporary function add_bias as 'hivemall.ftvec.AddBiasUDF';