Posted to commits@hivemall.apache.org by my...@apache.org on 2016/12/02 08:01:51 UTC
[01/50] [abbrv] incubator-hivemall git commit: add HiveUtils.asDoubleOI
Repository: incubator-hivemall
Updated Branches:
refs/heads/JIRA-22/pr-285 [created] 9ca8bce75
refs/heads/JIRA-22/pr-304 [created] b0a0179b0
refs/heads/JIRA-22/pr-336 [created] 075f93485
refs/heads/JIRA-22/pr-356 [created] cc3443515
refs/heads/JIRA-22/pr-385 [created] 4c8dcbfcd
add HiveUtils.asDoubleOI
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/56adf2d4
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/56adf2d4
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/56adf2d4
Branch: refs/heads/JIRA-22/pr-385
Commit: 56adf2d4e8b2591c31b846b8980016d3dafdbacc
Parents: 2dc176a
Author: amaya <gi...@sapphire.in.net>
Authored: Fri Sep 16 15:48:33 2016 +0900
Committer: amaya <gi...@sapphire.in.net>
Committed: Fri Sep 16 15:48:33 2016 +0900
----------------------------------------------------------------------
core/src/main/java/hivemall/utils/hadoop/HiveUtils.java | 9 +++++++++
1 file changed, 9 insertions(+)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/56adf2d4/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
index 32b60d0..7e8ea7b 100644
--- a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
+++ b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
@@ -57,6 +57,7 @@ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.BinaryObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.BooleanObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.IntObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.LongObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
@@ -675,6 +676,14 @@ public final class HiveUtils {
return (LongObjectInspector) argOI;
}
+ public static DoubleObjectInspector asDoubleOI(@Nonnull final ObjectInspector argOI)
+ throws UDFArgumentException {
+ if (!DOUBLE_TYPE_NAME.equals(argOI.getTypeName())) {
+ throw new UDFArgumentException("Argument type must be DOUBLE: " + argOI.getTypeName());
+ }
+ return (DoubleObjectInspector) argOI;
+ }
+
public static PrimitiveObjectInspector asIntCompatibleOI(@Nonnull final ObjectInspector argOI)
throws UDFArgumentTypeException {
if (argOI.getCategory() != Category.PRIMITIVE) {
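A minimal standalone sketch of calling the new asDoubleOI helper; the wrapper class below is hypothetical and only the HiveUtils.asDoubleOI call itself comes from the commit above.

import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;

import hivemall.utils.hadoop.HiveUtils;

// Hypothetical example class; not part of the Hivemall source tree.
public class AsDoubleOIExample {
    public static void main(String[] args) throws UDFArgumentException {
        // A DOUBLE object inspector passes the type-name check and is returned strongly typed.
        ObjectInspector doubleOI = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
        DoubleObjectInspector doi = HiveUtils.asDoubleOI(doubleOI);
        System.out.println(doi.getTypeName()); // double

        // Any other type (e.g. INT) is rejected with a UDFArgumentException.
        ObjectInspector intOI = PrimitiveObjectInspectorFactory.writableIntObjectInspector;
        try {
            HiveUtils.asDoubleOI(intOI);
        } catch (UDFArgumentException expected) {
            System.out.println(expected.getMessage());
        }
    }
}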
[28/50] [abbrv] incubator-hivemall git commit: refine tests
Posted by my...@apache.org.
refine tests
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/8e2842cf
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/8e2842cf
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/8e2842cf
Branch: refs/heads/JIRA-22/pr-385
Commit: 8e2842cf8c272642feaa76bf95e8fa463b0322dc
Parents: 1347de9
Author: amaya <gi...@sapphire.in.net>
Authored: Wed Sep 28 14:24:19 2016 +0900
Committer: amaya <gi...@sapphire.in.net>
Committed: Wed Sep 28 14:24:19 2016 +0900
----------------------------------------------------------------------
.../ftvec/selection/ChiSquareUDFTest.java | 12 ++--
.../selection/SignalNoiseRatioUDAFTest.java | 71 ++++++++++++++++----
2 files changed, 64 insertions(+), 19 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8e2842cf/core/src/test/java/hivemall/ftvec/selection/ChiSquareUDFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/ftvec/selection/ChiSquareUDFTest.java b/core/src/test/java/hivemall/ftvec/selection/ChiSquareUDFTest.java
index 38f7f57..d5880b8 100644
--- a/core/src/test/java/hivemall/ftvec/selection/ChiSquareUDFTest.java
+++ b/core/src/test/java/hivemall/ftvec/selection/ChiSquareUDFTest.java
@@ -69,12 +69,12 @@ public class ChiSquareUDFTest {
result1[i] = Double.valueOf(((List) result[1]).get(i).toString());
}
- final double[] answer0 = new double[] {10.817820878493995, 3.5944990176817315,
- 116.16984746363957, 67.24482558215503};
- final double[] answer1 = new double[] {0.004476514990225833, 0.16575416718561453, 0.d,
- 2.55351295663786e-15};
+ // compare with results by scikit-learn
+ final double[] answer0 = new double[] {10.81782088, 3.59449902, 116.16984746, 67.24482759};
+ final double[] answer1 = new double[] {4.47651499e-03, 1.65754167e-01, 5.94344354e-26,
+ 2.50017968e-15};
- Assert.assertArrayEquals(answer0, result0, 0.d);
- Assert.assertArrayEquals(answer1, result1, 0.d);
+ Assert.assertArrayEquals(answer0, result0, 1e-5);
+ Assert.assertArrayEquals(answer1, result1, 1e-5);
}
}
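The reference arrays above are rounded values taken from scikit-learn, so the assertions switch from an exact (0.d) delta to a 1e-5 tolerance. A self-contained JUnit sketch of the same assertion style, with made-up numbers:

import org.junit.Assert;
import org.junit.Test;

// Hypothetical test class; it only illustrates delta-based array comparison.
public class DeltaAssertExample {
    @Test
    public void toleratesRoundedReferenceValues() {
        // reference values pasted from an external tool with ~9 significant digits
        final double[] expected = new double[] {10.81782088, 3.59449902};
        // values produced by the code under test at full double precision
        final double[] actual = new double[] {10.817820878493995, 3.5944990176817315};

        // an exact comparison (delta 0.d) would fail; 1e-5 absorbs the rounding error
        Assert.assertArrayEquals(expected, actual, 1e-5);
    }
}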
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8e2842cf/core/src/test/java/hivemall/ftvec/selection/SignalNoiseRatioUDAFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/ftvec/selection/SignalNoiseRatioUDAFTest.java b/core/src/test/java/hivemall/ftvec/selection/SignalNoiseRatioUDAFTest.java
index 4655545..56a01d0 100644
--- a/core/src/test/java/hivemall/ftvec/selection/SignalNoiseRatioUDAFTest.java
+++ b/core/src/test/java/hivemall/ftvec/selection/SignalNoiseRatioUDAFTest.java
@@ -40,7 +40,8 @@ public class SignalNoiseRatioUDAFTest {
public ExpectedException expectedException = ExpectedException.none();
@Test
- public void test() throws Exception {
+ public void snrBinaryClass() throws Exception {
+ // this test is based on *subset* of iris data set
final SignalNoiseRatioUDAF snr = new SignalNoiseRatioUDAF();
final ObjectInspector[] OIs = new ObjectInspector[] {
ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector),
@@ -51,20 +52,62 @@ public class SignalNoiseRatioUDAFTest {
final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer agg = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer) evaluator.getNewAggregationBuffer();
evaluator.reset(agg);
- final double[][] featuress = new double[][] { {5.1, 3.5, 1.4, 0.2}, {4.9, 3.d, 1.4, 0.2},
+ final double[][] features = new double[][] { {5.1, 3.5, 1.4, 0.2}, {4.9, 3.d, 1.4, 0.2},
+ {4.7, 3.2, 1.3, 0.2}, {7.d, 3.2, 4.7, 1.4}, {6.4, 3.2, 4.5, 1.5},
+ {6.9, 3.1, 4.9, 1.5}};
+
+ final int[][] labels = new int[][] { {1, 0}, {1, 0}, {1, 0}, {0, 1}, {0, 1}, {0, 1}};
+
+ for (int i = 0; i < features.length; i++) {
+ final List<IntWritable> labelList = new ArrayList<IntWritable>();
+ for (int label : labels[i]) {
+ labelList.add(new IntWritable(label));
+ }
+ evaluator.iterate(agg, new Object[] {WritableUtils.toWritableList(features[i]),
+ labelList});
+ }
+
+ @SuppressWarnings("unchecked")
+ final List<DoubleWritable> resultObj = (ArrayList<DoubleWritable>) evaluator.terminate(agg);
+ final int size = resultObj.size();
+ final double[] result = new double[size];
+ for (int i = 0; i < size; i++) {
+ result[i] = resultObj.get(i).get();
+ }
+
+ // compare with result by numpy
+ final double[] answer = new double[] {4.38425236, 0.26390002, 15.83984511, 26.87005769};
+
+ Assert.assertArrayEquals(answer, result, 1e-5);
+ }
+
+ @Test
+ public void snrMultipleClass() throws Exception {
+ // this test is based on *subset* of iris data set
+ final SignalNoiseRatioUDAF snr = new SignalNoiseRatioUDAF();
+ final ObjectInspector[] OIs = new ObjectInspector[] {
+ ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector),
+ ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableIntObjectInspector)};
+ final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator evaluator = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator) snr.getEvaluator(new SimpleGenericUDAFParameterInfo(
+ OIs, false, false));
+ evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, OIs);
+ final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer agg = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer) evaluator.getNewAggregationBuffer();
+ evaluator.reset(agg);
+
+ final double[][] features = new double[][] { {5.1, 3.5, 1.4, 0.2}, {4.9, 3.d, 1.4, 0.2},
{7.d, 3.2, 4.7, 1.4}, {6.4, 3.2, 4.5, 1.5}, {6.3, 3.3, 6.d, 2.5},
{5.8, 2.7, 5.1, 1.9}};
- final int[][] labelss = new int[][] { {1, 0, 0}, {1, 0, 0}, {0, 1, 0}, {0, 1, 0},
- {0, 0, 1}, {0, 0, 1}};
+ final int[][] labels = new int[][] { {1, 0, 0}, {1, 0, 0}, {0, 1, 0}, {0, 1, 0}, {0, 0, 1},
+ {0, 0, 1}};
- for (int i = 0; i < featuress.length; i++) {
- final List<IntWritable> labels = new ArrayList<IntWritable>();
- for (int label : labelss[i]) {
- labels.add(new IntWritable(label));
+ for (int i = 0; i < features.length; i++) {
+ final List<IntWritable> labelList = new ArrayList<IntWritable>();
+ for (int label : labels[i]) {
+ labelList.add(new IntWritable(label));
}
- evaluator.iterate(agg,
- new Object[] {WritableUtils.toWritableList(featuress[i]), labels});
+ evaluator.iterate(agg, new Object[] {WritableUtils.toWritableList(features[i]),
+ labelList});
}
@SuppressWarnings("unchecked")
@@ -74,9 +117,11 @@ public class SignalNoiseRatioUDAFTest {
for (int i = 0; i < size; i++) {
result[i] = resultObj.get(i).get();
}
- final double[] answer = new double[] {8.431818181818192, 1.3212121212121217,
- 42.94949494949499, 33.80952380952378};
- Assert.assertArrayEquals(answer, result, 0.d);
+
+ // compare with result by scikit-learn
+ final double[] answer = new double[] {8.43181818, 1.32121212, 42.94949495, 33.80952381};
+
+ Assert.assertArrayEquals(answer, result, 1e-5);
}
@Test
[09/50] [abbrv] incubator-hivemall git commit: add license and format
Posted by my...@apache.org.
add license and format
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/ad81b3aa
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/ad81b3aa
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/ad81b3aa
Branch: refs/heads/JIRA-22/pr-385
Commit: ad81b3aa5a0bbb7c248d127ba44608578c01ae00
Parents: 1ab9b09
Author: amaya <gi...@sapphire.in.net>
Authored: Tue Sep 20 17:05:55 2016 +0900
Committer: amaya <gi...@sapphire.in.net>
Committed: Tue Sep 20 18:37:51 2016 +0900
----------------------------------------------------------------------
.../hivemall/ftvec/selection/ChiSquareUDF.java | 92 ++++++++++++--------
.../tools/array/ArrayTopKIndicesUDF.java | 29 ++++--
.../tools/array/SubarrayByIndicesUDF.java | 36 ++++++--
.../tools/matrix/TransposeAndDotUDAF.java | 64 +++++++++-----
.../java/hivemall/utils/hadoop/HiveUtils.java | 10 ++-
.../java/hivemall/utils/math/StatsUtils.java | 29 +++---
6 files changed, 171 insertions(+), 89 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ad81b3aa/core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java b/core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java
index 1954e33..e2b7494 100644
--- a/core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java
+++ b/core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java
@@ -1,3 +1,21 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2016 Makoto YUI
+ * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
package hivemall.ftvec.selection;
import hivemall.utils.hadoop.HiveUtils;
@@ -10,24 +28,20 @@ import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
-import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
-import org.apache.hadoop.hive.serde2.objectinspector.StructField;
-import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
-import javax.annotation.Nonnull;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
@Description(name = "chi2",
- value = "_FUNC_(array<array<number>> observed, array<array<number>> expected)" +
- " - Returns chi2_val and p_val of each columns as <array<double>, array<double>>")
+ value = "_FUNC_(array<array<number>> observed, array<array<number>> expected)"
+ + " - Returns chi2_val and p_val of each columns as <array<double>, array<double>>")
public class ChiSquareUDF extends GenericUDF {
private ListObjectInspector observedOI;
private ListObjectInspector observedRowOI;
@@ -42,31 +56,31 @@ public class ChiSquareUDF extends GenericUDF {
throw new UDFArgumentLengthException("Specify two arguments.");
}
- if (!HiveUtils.isNumberListListOI(OIs[0])){
- throw new UDFArgumentTypeException(0, "Only array<array<number>> type argument is acceptable but "
- + OIs[0].getTypeName() + " was passed as `observed`");
+ if (!HiveUtils.isNumberListListOI(OIs[0])) {
+ throw new UDFArgumentTypeException(0,
+ "Only array<array<number>> type argument is acceptable but " + OIs[0].getTypeName()
+ + " was passed as `observed`");
}
- if (!HiveUtils.isNumberListListOI(OIs[1])){
- throw new UDFArgumentTypeException(1, "Only array<array<number>> type argument is acceptable but "
- + OIs[1].getTypeName() + " was passed as `expected`");
+ if (!HiveUtils.isNumberListListOI(OIs[1])) {
+ throw new UDFArgumentTypeException(1,
+ "Only array<array<number>> type argument is acceptable but " + OIs[1].getTypeName()
+ + " was passed as `expected`");
}
observedOI = HiveUtils.asListOI(OIs[1]);
- observedRowOI=HiveUtils.asListOI(observedOI.getListElementObjectInspector());
- observedElOI = HiveUtils.asDoubleCompatibleOI( observedRowOI.getListElementObjectInspector());
- expectedOI = HiveUtils.asListOI(OIs[0]);
- expectedRowOI=HiveUtils.asListOI(expectedOI.getListElementObjectInspector());
+ observedRowOI = HiveUtils.asListOI(observedOI.getListElementObjectInspector());
+ observedElOI = HiveUtils.asDoubleCompatibleOI(observedRowOI.getListElementObjectInspector());
+ expectedOI = HiveUtils.asListOI(OIs[0]);
+ expectedRowOI = HiveUtils.asListOI(expectedOI.getListElementObjectInspector());
expectedElOI = HiveUtils.asDoubleCompatibleOI(expectedRowOI.getListElementObjectInspector());
List<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
- fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(
- PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
- fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(
- PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
+ fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
+ fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
return ObjectInspectorFactory.getStandardStructObjectInspector(
- Arrays.asList("chi2_vals", "p_vals"), fieldOIs);
+ Arrays.asList("chi2_vals", "p_vals"), fieldOIs);
}
@Override
@@ -76,40 +90,44 @@ public class ChiSquareUDF extends GenericUDF {
Preconditions.checkNotNull(observedObj);
Preconditions.checkNotNull(expectedObj);
- final int nClasses = observedObj.size();
+ final int nClasses = observedObj.size();
Preconditions.checkArgument(nClasses == expectedObj.size()); // same #rows
- int nFeatures=-1;
- double[] observedRow=null; // to reuse
- double[] expectedRow=null; // to reuse
- double[][] observed =null; // shape = (#features, #classes)
+ int nFeatures = -1;
+ double[] observedRow = null; // to reuse
+ double[] expectedRow = null; // to reuse
+ double[][] observed = null; // shape = (#features, #classes)
double[][] expected = null; // shape = (#features, #classes)
// explode and transpose matrix
- for(int i=0;i<nClasses;i++){
- if(i==0){
+ for (int i = 0; i < nClasses; i++) {
+ if (i == 0) {
// init
- observedRow=HiveUtils.asDoubleArray(observedObj.get(i),observedRowOI,observedElOI,false);
- expectedRow=HiveUtils.asDoubleArray(expectedObj.get(i),expectedRowOI,expectedElOI, false);
+ observedRow = HiveUtils.asDoubleArray(observedObj.get(i), observedRowOI,
+ observedElOI, false);
+ expectedRow = HiveUtils.asDoubleArray(expectedObj.get(i), expectedRowOI,
+ expectedElOI, false);
nFeatures = observedRow.length;
- observed=new double[nFeatures][nClasses];
+ observed = new double[nFeatures][nClasses];
expected = new double[nFeatures][nClasses];
- }else{
- HiveUtils.toDoubleArray(observedObj.get(i),observedRowOI,observedElOI,observedRow,false);
- HiveUtils.toDoubleArray(expectedObj.get(i),expectedRowOI,expectedElOI,expectedRow, false);
+ } else {
+ HiveUtils.toDoubleArray(observedObj.get(i), observedRowOI, observedElOI,
+ observedRow, false);
+ HiveUtils.toDoubleArray(expectedObj.get(i), expectedRowOI, expectedElOI,
+ expectedRow, false);
}
- for(int j=0;j<nFeatures;j++){
+ for (int j = 0; j < nFeatures; j++) {
observed[j][i] = observedRow[j];
expected[j][i] = expectedRow[j];
}
}
- final Map.Entry<double[],double[]> chi2 = StatsUtils.chiSquares(observed,expected);
+ final Map.Entry<double[], double[]> chi2 = StatsUtils.chiSquares(observed, expected);
final Object[] result = new Object[2];
result[0] = WritableUtils.toWritableList(chi2.getKey());
- result[1]=WritableUtils.toWritableList(chi2.getValue());
+ result[1] = WritableUtils.toWritableList(chi2.getValue());
return result;
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ad81b3aa/core/src/main/java/hivemall/tools/array/ArrayTopKIndicesUDF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/tools/array/ArrayTopKIndicesUDF.java b/core/src/main/java/hivemall/tools/array/ArrayTopKIndicesUDF.java
index bf9fe15..f895f9b 100644
--- a/core/src/main/java/hivemall/tools/array/ArrayTopKIndicesUDF.java
+++ b/core/src/main/java/hivemall/tools/array/ArrayTopKIndicesUDF.java
@@ -1,3 +1,21 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2016 Makoto YUI
+ * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
package hivemall.tools.array;
import hivemall.utils.hadoop.HiveUtils;
@@ -22,7 +40,8 @@ import java.util.Comparator;
import java.util.List;
import java.util.Map;
-@Description(name = "array_top_k_indices",
+@Description(
+ name = "array_top_k_indices",
value = "_FUNC_(array<number> array, const int k) - Returns indices array of top-k as array<int>")
public class ArrayTopKIndicesUDF extends GenericUDF {
private ListObjectInspector arrayOI;
@@ -36,8 +55,9 @@ public class ArrayTopKIndicesUDF extends GenericUDF {
}
if (!HiveUtils.isNumberListOI(OIs[0])) {
- throw new UDFArgumentTypeException(0, "Only array<number> type argument is acceptable but "
- + OIs[0].getTypeName() + " was passed as `array`");
+ throw new UDFArgumentTypeException(0,
+ "Only array<number> type argument is acceptable but " + OIs[0].getTypeName()
+ + " was passed as `array`");
}
if (!HiveUtils.isIntegerOI(OIs[1])) {
throw new UDFArgumentTypeException(1, "Only int type argument is acceptable but "
@@ -48,8 +68,7 @@ public class ArrayTopKIndicesUDF extends GenericUDF {
elementOI = HiveUtils.asDoubleCompatibleOI(arrayOI.getListElementObjectInspector());
kOI = HiveUtils.asIntegerOI(OIs[1]);
- return ObjectInspectorFactory.getStandardListObjectInspector(
- PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+ return ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
}
@Override
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ad81b3aa/core/src/main/java/hivemall/tools/array/SubarrayByIndicesUDF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/tools/array/SubarrayByIndicesUDF.java b/core/src/main/java/hivemall/tools/array/SubarrayByIndicesUDF.java
index f476589..07e158a 100644
--- a/core/src/main/java/hivemall/tools/array/SubarrayByIndicesUDF.java
+++ b/core/src/main/java/hivemall/tools/array/SubarrayByIndicesUDF.java
@@ -1,6 +1,23 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2016 Makoto YUI
+ * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
package hivemall.tools.array;
-
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.lang.Preconditions;
import org.apache.hadoop.hive.ql.exec.Description;
@@ -21,8 +38,8 @@ import java.util.ArrayList;
import java.util.List;
@Description(name = "subarray_by_indices",
- value = "_FUNC_(array<number> input, array<int> indices)" +
- " - Returns subarray selected by given indices as array<number>")
+ value = "_FUNC_(array<number> input, array<int> indices)"
+ + " - Returns subarray selected by given indices as array<number>")
public class SubarrayByIndicesUDF extends GenericUDF {
private ListObjectInspector inputOI;
private PrimitiveObjectInspector elementOI;
@@ -36,13 +53,15 @@ public class SubarrayByIndicesUDF extends GenericUDF {
}
if (!HiveUtils.isListOI(OIs[0])) {
- throw new UDFArgumentTypeException(0, "Only array<number> type argument is acceptable but "
- + OIs[0].getTypeName() + " was passed as `input`");
+ throw new UDFArgumentTypeException(0,
+ "Only array<number> type argument is acceptable but " + OIs[0].getTypeName()
+ + " was passed as `input`");
}
if (!HiveUtils.isListOI(OIs[1])
|| !HiveUtils.isIntegerOI(((ListObjectInspector) OIs[1]).getListElementObjectInspector())) {
- throw new UDFArgumentTypeException(0, "Only array<int> type argument is acceptable but "
- + OIs[0].getTypeName() + " was passed as `indices`");
+ throw new UDFArgumentTypeException(0,
+ "Only array<int> type argument is acceptable but " + OIs[0].getTypeName()
+ + " was passed as `indices`");
}
inputOI = HiveUtils.asListOI(OIs[0]);
@@ -50,8 +69,7 @@ public class SubarrayByIndicesUDF extends GenericUDF {
indicesOI = HiveUtils.asListOI(OIs[1]);
indexOI = HiveUtils.asIntegerOI(indicesOI.getListElementObjectInspector());
- return ObjectInspectorFactory.getStandardListObjectInspector(
- PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
+ return ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
}
@Override
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ad81b3aa/core/src/main/java/hivemall/tools/matrix/TransposeAndDotUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/tools/matrix/TransposeAndDotUDAF.java b/core/src/main/java/hivemall/tools/matrix/TransposeAndDotUDAF.java
index 3dcbb93..1e54004 100644
--- a/core/src/main/java/hivemall/tools/matrix/TransposeAndDotUDAF.java
+++ b/core/src/main/java/hivemall/tools/matrix/TransposeAndDotUDAF.java
@@ -1,3 +1,21 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2016 Makoto YUI
+ * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
package hivemall.tools.matrix;
import hivemall.utils.hadoop.HiveUtils;
@@ -23,12 +41,14 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
-@Description(name = "transpose_and_dot",
- value = "_FUNC_(array<number> matrix0_row, array<number> matrix1_row)" +
- " - Returns dot(matrix0.T, matrix1) as array<array<double>>, shape = (matrix0.#cols, matrix1.#cols)")
+@Description(
+ name = "transpose_and_dot",
+ value = "_FUNC_(array<number> matrix0_row, array<number> matrix1_row)"
+ + " - Returns dot(matrix0.T, matrix1) as array<array<double>>, shape = (matrix0.#cols, matrix1.#cols)")
public final class TransposeAndDotUDAF extends AbstractGenericUDAFResolver {
@Override
- public GenericUDAFEvaluator getEvaluator(GenericUDAFParameterInfo info) throws SemanticException {
+ public GenericUDAFEvaluator getEvaluator(GenericUDAFParameterInfo info)
+ throws SemanticException {
ObjectInspector[] OIs = info.getParameterObjectInspectors();
if (OIs.length != 2) {
@@ -36,13 +56,15 @@ public final class TransposeAndDotUDAF extends AbstractGenericUDAFResolver {
}
if (!HiveUtils.isNumberListOI(OIs[0])) {
- throw new UDFArgumentTypeException(0, "Only array<number> type argument is acceptable but "
- + OIs[0].getTypeName() + " was passed as `matrix0_row`");
+ throw new UDFArgumentTypeException(0,
+ "Only array<number> type argument is acceptable but " + OIs[0].getTypeName()
+ + " was passed as `matrix0_row`");
}
if (!HiveUtils.isNumberListOI(OIs[1])) {
- throw new UDFArgumentTypeException(1, "Only array<number> type argument is acceptable but "
- + OIs[1].getTypeName() + " was passed as `matrix1_row`");
+ throw new UDFArgumentTypeException(1,
+ "Only array<number> type argument is acceptable but " + OIs[1].getTypeName()
+ + " was passed as `matrix1_row`");
}
return new TransposeAndDotUDAFEvaluator();
@@ -69,9 +91,7 @@ public final class TransposeAndDotUDAF extends AbstractGenericUDAFResolver {
@Override
public int estimate() {
- return aggMatrix != null
- ? aggMatrix.length * aggMatrix[0].length * 8
- : 0;
+ return aggMatrix != null ? aggMatrix.length * aggMatrix[0].length * 8 : 0;
}
public void init(int n, int m) {
@@ -92,19 +112,17 @@ public final class TransposeAndDotUDAF extends AbstractGenericUDAFResolver {
super.init(mode, OIs);
if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {
- matrix0RowOI = HiveUtils.asListOI( OIs[0]);
+ matrix0RowOI = HiveUtils.asListOI(OIs[0]);
matrix0ElOI = HiveUtils.asDoubleCompatibleOI(matrix0RowOI.getListElementObjectInspector());
matrix1RowOI = HiveUtils.asListOI(OIs[1]);
matrix1ElOI = HiveUtils.asDoubleCompatibleOI(matrix1RowOI.getListElementObjectInspector());
} else {
- aggMatrixOI = HiveUtils.asListOI( OIs[0]);
- aggMatrixRowOI = HiveUtils.asListOI(aggMatrixOI.getListElementObjectInspector());
+ aggMatrixOI = HiveUtils.asListOI(OIs[0]);
+ aggMatrixRowOI = HiveUtils.asListOI(aggMatrixOI.getListElementObjectInspector());
aggMatrixElOI = HiveUtils.asDoubleOI(aggMatrixRowOI.getListElementObjectInspector());
}
- return ObjectInspectorFactory.getStandardListObjectInspector(
- ObjectInspectorFactory.getStandardListObjectInspector(
- PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
+ return ObjectInspectorFactory.getStandardListObjectInspector(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
}
@Override
@@ -124,11 +142,11 @@ public final class TransposeAndDotUDAF extends AbstractGenericUDAFResolver {
public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException {
TransposeAndDotAggregationBuffer myAgg = (TransposeAndDotAggregationBuffer) agg;
- if(matrix0Row==null){
- matrix0Row=new double[matrix0RowOI.getListLength(parameters[0])];
+ if (matrix0Row == null) {
+ matrix0Row = new double[matrix0RowOI.getListLength(parameters[0])];
}
- if(matrix1Row==null){
- matrix1Row=new double[matrix1RowOI.getListLength(parameters[1])];
+ if (matrix1Row == null) {
+ matrix1Row = new double[matrix1RowOI.getListLength(parameters[1])];
}
HiveUtils.toDoubleArray(parameters[0], matrix0RowOI, matrix0ElOI, matrix0Row, false);
@@ -158,9 +176,9 @@ public final class TransposeAndDotUDAF extends AbstractGenericUDAFResolver {
List matrix = aggMatrixOI.getList(other);
final int n = matrix.size();
- final double[] row =new double[ aggMatrixRowOI.getListLength(matrix.get(0))];
+ final double[] row = new double[aggMatrixRowOI.getListLength(matrix.get(0))];
for (int i = 0; i < n; i++) {
- HiveUtils.toDoubleArray(matrix.get(i), aggMatrixRowOI, aggMatrixElOI,row,false);
+ HiveUtils.toDoubleArray(matrix.get(i), aggMatrixRowOI, aggMatrixElOI, row, false);
if (myAgg.aggMatrix == null) {
myAgg.init(n, row.length);
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ad81b3aa/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
index dcbf534..9272e60 100644
--- a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
+++ b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
@@ -1,7 +1,7 @@
/*
* Hivemall: Hive scalable Machine Learning Library
*
- * Copyright (C) 2015 Makoto YUI
+ * Copyright (C) 2016 Makoto YUI
* Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
*
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -231,12 +231,14 @@ public final class HiveUtils {
return category == Category.LIST;
}
- public static boolean isNumberListOI(@Nonnull final ObjectInspector oi){
- return isListOI(oi) && isNumberOI(((ListObjectInspector)oi).getListElementObjectInspector());
+ public static boolean isNumberListOI(@Nonnull final ObjectInspector oi) {
+ return isListOI(oi)
+ && isNumberOI(((ListObjectInspector) oi).getListElementObjectInspector());
}
public static boolean isNumberListListOI(@Nonnull final ObjectInspector oi) {
- return isListOI(oi) && isNumberListOI(((ListObjectInspector)oi).getListElementObjectInspector());
+ return isListOI(oi)
+ && isNumberListOI(((ListObjectInspector) oi).getListElementObjectInspector());
}
public static boolean isPrimitiveTypeInfo(@Nonnull TypeInfo typeInfo) {
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ad81b3aa/core/src/main/java/hivemall/utils/math/StatsUtils.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/math/StatsUtils.java b/core/src/main/java/hivemall/utils/math/StatsUtils.java
index f9d0f30..d3b25c7 100644
--- a/core/src/main/java/hivemall/utils/math/StatsUtils.java
+++ b/core/src/main/java/hivemall/utils/math/StatsUtils.java
@@ -1,7 +1,7 @@
/*
* Hivemall: Hive scalable Machine Learning Library
*
- * Copyright (C) 2015 Makoto YUI
+ * Copyright (C) 2016 Makoto YUI
* Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
*
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -198,7 +198,8 @@ public final class StatsUtils {
* @param expected mean vector whose value is expected
* @return chi2 value
*/
- public static double chiSquare(@Nonnull final double[] observed, @Nonnull final double[] expected) {
+ public static double chiSquare(@Nonnull final double[] observed,
+ @Nonnull final double[] expected) {
Preconditions.checkArgument(observed.length == expected.length);
double sumObserved = 0.d;
@@ -237,32 +238,38 @@ public final class StatsUtils {
* @param expected means vector whose value is expected
* @return p value
*/
- public static double chiSquareTest(@Nonnull final double[] observed, @Nonnull final double[] expected) {
- ChiSquaredDistribution distribution = new ChiSquaredDistribution(null, (double)expected.length - 1.d);
- return 1.d - distribution.cumulativeProbability(chiSquare(observed,expected));
+ public static double chiSquareTest(@Nonnull final double[] observed,
+ @Nonnull final double[] expected) {
+ ChiSquaredDistribution distribution = new ChiSquaredDistribution(null,
+ (double) expected.length - 1.d);
+ return 1.d - distribution.cumulativeProbability(chiSquare(observed, expected));
}
/**
- * This method offers effective calculation for multiple entries rather than calculation individually
+ * This method offers effective calculation for multiple entries rather than calculation
+ * individually
+ *
* @param observeds means matrix whose values are observed
* @param expecteds means matrix
* @return (chi2 value[], p value[])
*/
- public static Map.Entry<double[],double[]> chiSquares(@Nonnull final double[][] observeds, @Nonnull final double[][] expecteds){
+ public static Map.Entry<double[], double[]> chiSquares(@Nonnull final double[][] observeds,
+ @Nonnull final double[][] expecteds) {
Preconditions.checkArgument(observeds.length == expecteds.length);
final int len = expecteds.length;
final int lenOfEach = expecteds[0].length;
- final ChiSquaredDistribution distribution = new ChiSquaredDistribution(null, (double)lenOfEach - 1.d);
+ final ChiSquaredDistribution distribution = new ChiSquaredDistribution(null,
+ (double) lenOfEach - 1.d);
final double[] chi2s = new double[len];
final double[] ps = new double[len];
- for(int i=0;i<len;i++){
- chi2s[i] = chiSquare(observeds[i],expecteds[i]);
+ for (int i = 0; i < len; i++) {
+ chi2s[i] = chiSquare(observeds[i], expecteds[i]);
ps[i] = 1.d - distribution.cumulativeProbability(chi2s[i]);
}
- return new AbstractMap.SimpleEntry<double[], double[]>(chi2s,ps);
+ return new AbstractMap.SimpleEntry<double[], double[]>(chi2s, ps);
}
}
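For context on the chiSquareTest/chiSquares methods reformatted above, a standalone sketch of the p-value calculation p = 1 - CDF_chi2(x; #categories - 1) using commons-math3; the input numbers are made up, and the statistic below is the plain Pearson form, which may differ from the exact normalization in StatsUtils.chiSquare.

import org.apache.commons.math3.distribution.ChiSquaredDistribution;

// Hypothetical standalone example; not part of the Hivemall source tree.
public class ChiSquarePValueExample {
    public static void main(String[] args) {
        final double[] observed = new double[] {16.d, 18.d, 16.d, 14.d, 12.d, 12.d};
        final double[] expected = new double[] {16.d, 16.d, 16.d, 16.d, 8.d, 16.d};

        // plain Pearson statistic: sum((o - e)^2 / e)
        double chi2 = 0.d;
        for (int i = 0; i < observed.length; i++) {
            final double d = observed[i] - expected[i];
            chi2 += d * d / expected[i];
        }

        // degrees of freedom = #categories - 1, as in the diff above
        final ChiSquaredDistribution distribution = new ChiSquaredDistribution(
            (double) expected.length - 1.d);
        final double p = 1.d - distribution.cumulativeProbability(chi2);
        System.out.println("chi2=" + chi2 + ", p=" + p);
    }
}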
[42/50] [abbrv] incubator-hivemall git commit: Update license header
Posted by my...@apache.org.
Update license header
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/798ec6a7
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/798ec6a7
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/798ec6a7
Branch: refs/heads/JIRA-22/pr-336
Commit: 798ec6a73ca37d474137fc82db1c22a92521307d
Parents: ddd8dc2
Author: amaya <gi...@sapphire.in.net>
Authored: Fri Nov 18 04:27:59 2016 +0900
Committer: amaya <gi...@sapphire.in.net>
Committed: Fri Nov 18 04:27:59 2016 +0900
----------------------------------------------------------------------
systemtest/pom.xml | 2 ++
1 file changed, 2 insertions(+)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/798ec6a7/systemtest/pom.xml
----------------------------------------------------------------------
diff --git a/systemtest/pom.xml b/systemtest/pom.xml
index e7345af..0debee0 100644
--- a/systemtest/pom.xml
+++ b/systemtest/pom.xml
@@ -6,7 +6,9 @@
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
+
http://www.apache.org/licenses/LICENSE-2.0
+
Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
[19/50] [abbrv] incubator-hivemall git commit: mod chi2 function name
Posted by my...@apache.org.
mod chi2 function name
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/a882c5f9
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/a882c5f9
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/a882c5f9
Branch: refs/heads/JIRA-22/pr-385
Commit: a882c5f9f8067b911254dfc43d268de06a5490f9
Parents: b8cf396
Author: amaya <gi...@sapphire.in.net>
Authored: Wed Sep 21 16:00:36 2016 +0900
Committer: amaya <gi...@sapphire.in.net>
Committed: Wed Sep 21 16:23:47 2016 +0900
----------------------------------------------------------------------
core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java | 2 +-
core/src/main/java/hivemall/utils/math/StatsUtils.java | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/a882c5f9/core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java b/core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java
index 70f0316..1583959 100644
--- a/core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java
+++ b/core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java
@@ -129,7 +129,7 @@ public class ChiSquareUDF extends GenericUDF {
}
}
- final Map.Entry<double[], double[]> chi2 = StatsUtils.chiSquares(observed, expected);
+ final Map.Entry<double[], double[]> chi2 = StatsUtils.chiSquare(observed, expected);
final Object[] result = new Object[2];
result[0] = WritableUtils.toWritableList(chi2.getKey());
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/a882c5f9/core/src/main/java/hivemall/utils/math/StatsUtils.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/math/StatsUtils.java b/core/src/main/java/hivemall/utils/math/StatsUtils.java
index e255b84..14adbff 100644
--- a/core/src/main/java/hivemall/utils/math/StatsUtils.java
+++ b/core/src/main/java/hivemall/utils/math/StatsUtils.java
@@ -262,7 +262,7 @@ public final class StatsUtils {
* @param expecteds means positive matrix
* @return (chi2 value[], p value[])
*/
- public static Map.Entry<double[], double[]> chiSquares(@Nonnull final double[][] observeds,
+ public static Map.Entry<double[], double[]> chiSquare(@Nonnull final double[][] observeds,
@Nonnull final double[][] expecteds) {
Preconditions.checkArgument(observeds.length == expecteds.length);
[44/50] [abbrv] incubator-hivemall git commit: Updated license headers
Posted by my...@apache.org.
Updated license headers
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/e44a413e
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/e44a413e
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/e44a413e
Branch: refs/heads/JIRA-22/pr-385
Commit: e44a413e5fd4270af53895fceec27ccff3d63a73
Parents: 67ba963
Author: myui <yu...@gmail.com>
Authored: Mon Nov 21 19:02:27 2016 +0900
Committer: myui <yu...@gmail.com>
Committed: Mon Nov 21 19:02:27 2016 +0900
----------------------------------------------------------------------
.../hivemall/ftvec/selection/ChiSquareUDF.java | 77 ++++++++++----------
.../ftvec/selection/SignalNoiseRatioUDAF.java | 39 +++++-----
.../hivemall/tools/array/SelectKBestUDF.java | 48 ++++++------
.../tools/matrix/TransposeAndDotUDAF.java | 38 +++++-----
.../ftvec/selection/ChiSquareUDFTest.java | 35 ++++-----
.../selection/SignalNoiseRatioUDAFTest.java | 36 ++++-----
.../tools/array/SelectKBeatUDFTest.java | 33 +++++----
.../tools/matrix/TransposeAndDotUDAFTest.java | 29 ++++----
8 files changed, 171 insertions(+), 164 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e44a413e/core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java b/core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java
index 1583959..91742bc 100644
--- a/core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java
+++ b/core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java
@@ -1,20 +1,20 @@
/*
- * Hivemall: Hive scalable Machine Learning Library
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
*
- * Copyright (C) 2016 Makoto YUI
- * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ * http://www.apache.org/licenses/LICENSE-2.0
*
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
*/
package hivemall.ftvec.selection;
@@ -22,11 +22,18 @@ import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.hadoop.WritableUtils;
import hivemall.utils.lang.Preconditions;
import hivemall.utils.math.StatsUtils;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.udf.UDFType;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
@@ -34,15 +41,12 @@ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.List;
-import java.util.Map;
-
@Description(name = "chi2",
value = "_FUNC_(array<array<number>> observed, array<array<number>> expected)"
+ " - Returns chi2_val and p_val of each columns as <array<double>, array<double>>")
-public class ChiSquareUDF extends GenericUDF {
+@UDFType(deterministic = true, stateful = false)
+public final class ChiSquareUDF extends GenericUDF {
+
private ListObjectInspector observedOI;
private ListObjectInspector observedRowOI;
private PrimitiveObjectInspector observedElOI;
@@ -61,27 +65,25 @@ public class ChiSquareUDF extends GenericUDF {
if (OIs.length != 2) {
throw new UDFArgumentLengthException("Specify two arguments.");
}
-
if (!HiveUtils.isNumberListListOI(OIs[0])) {
throw new UDFArgumentTypeException(0,
"Only array<array<number>> type argument is acceptable but " + OIs[0].getTypeName()
+ " was passed as `observed`");
}
-
if (!HiveUtils.isNumberListListOI(OIs[1])) {
throw new UDFArgumentTypeException(1,
"Only array<array<number>> type argument is acceptable but " + OIs[1].getTypeName()
+ " was passed as `expected`");
}
- observedOI = HiveUtils.asListOI(OIs[1]);
- observedRowOI = HiveUtils.asListOI(observedOI.getListElementObjectInspector());
- observedElOI = HiveUtils.asDoubleCompatibleOI(observedRowOI.getListElementObjectInspector());
- expectedOI = HiveUtils.asListOI(OIs[0]);
- expectedRowOI = HiveUtils.asListOI(expectedOI.getListElementObjectInspector());
- expectedElOI = HiveUtils.asDoubleCompatibleOI(expectedRowOI.getListElementObjectInspector());
+ this.observedOI = HiveUtils.asListOI(OIs[1]);
+ this.observedRowOI = HiveUtils.asListOI(observedOI.getListElementObjectInspector());
+ this.observedElOI = HiveUtils.asDoubleCompatibleOI(observedRowOI.getListElementObjectInspector());
+ this.expectedOI = HiveUtils.asListOI(OIs[0]);
+ this.expectedRowOI = HiveUtils.asListOI(expectedOI.getListElementObjectInspector());
+ this.expectedElOI = HiveUtils.asDoubleCompatibleOI(expectedRowOI.getListElementObjectInspector());
- final List<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
+ List<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
@@ -90,25 +92,26 @@ public class ChiSquareUDF extends GenericUDF {
}
@Override
- public Object evaluate(GenericUDF.DeferredObject[] dObj) throws HiveException {
- List observedObj = observedOI.getList(dObj[0].get()); // shape = (#classes, #features)
- List expectedObj = expectedOI.getList(dObj[1].get()); // shape = (#classes, #features)
+ public Object[] evaluate(DeferredObject[] dObj) throws HiveException {
+ List<?> observedObj = observedOI.getList(dObj[0].get()); // shape = (#classes, #features)
+ List<?> expectedObj = expectedOI.getList(dObj[1].get()); // shape = (#classes, #features)
+
+ if (observedObj == null || expectedObj == null) {
+ return null;
+ }
- Preconditions.checkNotNull(observedObj);
- Preconditions.checkNotNull(expectedObj);
final int nClasses = observedObj.size();
Preconditions.checkArgument(nClasses == expectedObj.size()); // same #rows
// explode and transpose matrix
for (int i = 0; i < nClasses; i++) {
- final Object observedObjRow = observedObj.get(i);
- final Object expectedObjRow = expectedObj.get(i);
+ Object observedObjRow = observedObj.get(i);
+ Object expectedObjRow = expectedObj.get(i);
Preconditions.checkNotNull(observedObjRow);
Preconditions.checkNotNull(expectedObjRow);
if (observedRow == null) {
- // init
observedRow = HiveUtils.asDoubleArray(observedObjRow, observedRowOI, observedElOI,
false);
expectedRow = HiveUtils.asDoubleArray(expectedObjRow, expectedRowOI, expectedElOI,
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e44a413e/core/src/main/java/hivemall/ftvec/selection/SignalNoiseRatioUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/ftvec/selection/SignalNoiseRatioUDAF.java b/core/src/main/java/hivemall/ftvec/selection/SignalNoiseRatioUDAF.java
index 96fdc5b..1727d2e 100644
--- a/core/src/main/java/hivemall/ftvec/selection/SignalNoiseRatioUDAF.java
+++ b/core/src/main/java/hivemall/ftvec/selection/SignalNoiseRatioUDAF.java
@@ -1,26 +1,31 @@
/*
- * Hivemall: Hive scalable Machine Learning Library
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
*
- * Copyright (C) 2015 Makoto YUI
- * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ * http://www.apache.org/licenses/LICENSE-2.0
*
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
*/
package hivemall.ftvec.selection;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.hadoop.WritableUtils;
import hivemall.utils.lang.Preconditions;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
@@ -42,13 +47,10 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.LongObjectInspect
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.List;
-
@Description(name = "snr", value = "_FUNC_(array<number> features, array<int> one-hot class label)"
+ " - Returns SNR values of each feature as array<double>")
public class SignalNoiseRatioUDAF extends AbstractGenericUDAFResolver {
+
@Override
public GenericUDAFEvaluator getEvaluator(GenericUDAFParameterInfo info)
throws SemanticException {
@@ -248,7 +250,6 @@ public class SignalNoiseRatioUDAF extends AbstractGenericUDAFResolver {
final int nClasses = ns.size();
final int nFeatures = meansOI.getListLength(meanss.get(0));
if (myAgg.ns == null) {
- // init
myAgg.init(nClasses, nFeatures);
}
for (int i = 0; i < nClasses; i++) {
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e44a413e/core/src/main/java/hivemall/tools/array/SelectKBestUDF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/tools/array/SelectKBestUDF.java b/core/src/main/java/hivemall/tools/array/SelectKBestUDF.java
index bdab5bb..0a383eb 100644
--- a/core/src/main/java/hivemall/tools/array/SelectKBestUDF.java
+++ b/core/src/main/java/hivemall/tools/array/SelectKBestUDF.java
@@ -1,25 +1,33 @@
/*
- * Hivemall: Hive scalable Machine Learning Library
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
*
- * Copyright (C) 2016 Makoto YUI
- * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ * http://www.apache.org/licenses/LICENSE-2.0
*
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
*/
package hivemall.tools.array;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.lang.Preconditions;
+
+import java.util.AbstractMap;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.List;
+import java.util.Map;
+
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
@@ -31,21 +39,9 @@ import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector;
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.IntObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
-import java.util.AbstractMap;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Collection;
-import java.util.Collections;
-import java.util.Comparator;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-
@Description(name = "select_k_best",
value = "_FUNC_(array<number> array, const array<number> importance_list, const int k)"
+ " - Returns selected top-k elements as array<double>")
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e44a413e/core/src/main/java/hivemall/tools/matrix/TransposeAndDotUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/tools/matrix/TransposeAndDotUDAF.java b/core/src/main/java/hivemall/tools/matrix/TransposeAndDotUDAF.java
index 9df9305..5925a0c 100644
--- a/core/src/main/java/hivemall/tools/matrix/TransposeAndDotUDAF.java
+++ b/core/src/main/java/hivemall/tools/matrix/TransposeAndDotUDAF.java
@@ -1,26 +1,31 @@
/*
- * Hivemall: Hive scalable Machine Learning Library
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
*
- * Copyright (C) 2016 Makoto YUI
- * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ * http://www.apache.org/licenses/LICENSE-2.0
*
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
*/
package hivemall.tools.matrix;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.hadoop.WritableUtils;
import hivemall.utils.lang.Preconditions;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
@@ -37,15 +42,12 @@ import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.List;
-
@Description(
name = "transpose_and_dot",
value = "_FUNC_(array<number> matrix0_row, array<number> matrix1_row)"
+ " - Returns dot(matrix0.T, matrix1) as array<array<double>>, shape = (matrix0.#cols, matrix1.#cols)")
public final class TransposeAndDotUDAF extends AbstractGenericUDAFResolver {
+
@Override
public GenericUDAFEvaluator getEvaluator(GenericUDAFParameterInfo info)
throws SemanticException {
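(Likewise for transpose_and_dot: the UDAF aggregates pairs of row arrays so that each row contributes the outer product of matrix0_row and matrix1_row, and the aggregate equals dot(matrix0.T, matrix1). A minimal usage sketch with hypothetical table/column names, not taken from the diff:)

    -- result shape: (#cols of matrix0) x (#cols of matrix1), as array<array<double>>
    SELECT transpose_and_dot(x_row, y_row)
    FROM matrix_rows;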
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e44a413e/core/src/test/java/hivemall/ftvec/selection/ChiSquareUDFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/ftvec/selection/ChiSquareUDFTest.java b/core/src/test/java/hivemall/ftvec/selection/ChiSquareUDFTest.java
index d5880b8..64e7693 100644
--- a/core/src/test/java/hivemall/ftvec/selection/ChiSquareUDFTest.java
+++ b/core/src/test/java/hivemall/ftvec/selection/ChiSquareUDFTest.java
@@ -1,24 +1,28 @@
/*
- * Hivemall: Hive scalable Machine Learning Library
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
*
- * Copyright (C) 2016 Makoto YUI
- * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ * http://www.apache.org/licenses/LICENSE-2.0
*
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
*/
package hivemall.ftvec.selection;
import hivemall.utils.hadoop.WritableUtils;
+
+import java.util.ArrayList;
+import java.util.List;
+
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
@@ -27,9 +31,6 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectIn
import org.junit.Assert;
import org.junit.Test;
-import java.util.ArrayList;
-import java.util.List;
-
public class ChiSquareUDFTest {
@Test
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e44a413e/core/src/test/java/hivemall/ftvec/selection/SignalNoiseRatioUDAFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/ftvec/selection/SignalNoiseRatioUDAFTest.java b/core/src/test/java/hivemall/ftvec/selection/SignalNoiseRatioUDAFTest.java
index a4744d9..ec08344 100644
--- a/core/src/test/java/hivemall/ftvec/selection/SignalNoiseRatioUDAFTest.java
+++ b/core/src/test/java/hivemall/ftvec/selection/SignalNoiseRatioUDAFTest.java
@@ -1,24 +1,28 @@
/*
- * Hivemall: Hive scalable Machine Learning Library
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
*
- * Copyright (C) 2016 Makoto YUI
- * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ * http://www.apache.org/licenses/LICENSE-2.0
*
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
*/
package hivemall.ftvec.selection;
import hivemall.utils.hadoop.WritableUtils;
+
+import java.util.ArrayList;
+import java.util.List;
+
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.ql.udf.generic.SimpleGenericUDAFParameterInfo;
@@ -32,10 +36,8 @@ import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
-import java.util.ArrayList;
-import java.util.List;
-
public class SignalNoiseRatioUDAFTest {
+
@Rule
public ExpectedException expectedException = ExpectedException.none();
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e44a413e/core/src/test/java/hivemall/tools/array/SelectKBeatUDFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/tools/array/SelectKBeatUDFTest.java b/core/src/test/java/hivemall/tools/array/SelectKBeatUDFTest.java
index b86db5c..da080ef 100644
--- a/core/src/test/java/hivemall/tools/array/SelectKBeatUDFTest.java
+++ b/core/src/test/java/hivemall/tools/array/SelectKBeatUDFTest.java
@@ -1,24 +1,27 @@
/*
- * Hivemall: Hive scalable Machine Learning Library
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
*
- * Copyright (C) 2016 Makoto YUI
- * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ * http://www.apache.org/licenses/LICENSE-2.0
*
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
*/
package hivemall.tools.array;
import hivemall.utils.hadoop.WritableUtils;
+
+import java.util.List;
+
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
@@ -27,8 +30,6 @@ import org.apache.hadoop.io.IntWritable;
import org.junit.Assert;
import org.junit.Test;
-import java.util.List;
-
public class SelectKBeatUDFTest {
@Test
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e44a413e/core/src/test/java/hivemall/tools/matrix/TransposeAndDotUDAFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/tools/matrix/TransposeAndDotUDAFTest.java b/core/src/test/java/hivemall/tools/matrix/TransposeAndDotUDAFTest.java
index 93c6ef1..f705a89 100644
--- a/core/src/test/java/hivemall/tools/matrix/TransposeAndDotUDAFTest.java
+++ b/core/src/test/java/hivemall/tools/matrix/TransposeAndDotUDAFTest.java
@@ -1,24 +1,25 @@
/*
- * Hivemall: Hive scalable Machine Learning Library
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
*
- * Copyright (C) 2016 Makoto YUI
- * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ * http://www.apache.org/licenses/LICENSE-2.0
*
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
*/
package hivemall.tools.matrix;
import hivemall.utils.hadoop.WritableUtils;
+
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.ql.udf.generic.SimpleGenericUDAFParameterInfo;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
[34/50] [abbrv] incubator-hivemall git commit: Update license headers
Posted by my...@apache.org.
Update license headers
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/43ca0c86
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/43ca0c86
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/43ca0c86
Branch: refs/heads/JIRA-22/pr-336
Commit: 43ca0c86936f3ccc7f825db3c4f4ecaa48087917
Parents: faebaf9
Author: amaya <gi...@sapphire.in.net>
Authored: Wed Nov 16 15:23:49 2016 +0900
Committer: amaya <gi...@sapphire.in.net>
Committed: Wed Nov 16 15:23:49 2016 +0900
----------------------------------------------------------------------
systemtest/README.md | 18 +++++++++++++
systemtest/pom.xml | 17 +++++++++++-
.../java/com/klarna/hiverunner/Extractor.java | 28 ++++++++++----------
.../hivemall/systemtest/MsgpackConverter.java | 28 ++++++++++----------
.../exception/QueryExecutionException.java | 28 ++++++++++----------
.../systemtest/model/CreateTableHQ.java | 28 ++++++++++----------
.../hivemall/systemtest/model/DropTableHQ.java | 28 ++++++++++----------
.../main/java/hivemall/systemtest/model/HQ.java | 28 ++++++++++----------
.../java/hivemall/systemtest/model/HQBase.java | 28 ++++++++++----------
.../hivemall/systemtest/model/InsertHQ.java | 28 ++++++++++----------
.../java/hivemall/systemtest/model/RawHQ.java | 28 ++++++++++----------
.../java/hivemall/systemtest/model/TableHQ.java | 28 ++++++++++----------
.../hivemall/systemtest/model/TableListHQ.java | 28 ++++++++++----------
.../model/UploadFileAsNewTableHQ.java | 28 ++++++++++----------
.../hivemall/systemtest/model/UploadFileHQ.java | 28 ++++++++++----------
.../model/UploadFileToExistingHQ.java | 28 ++++++++++----------
.../model/lazy/LazyMatchingResource.java | 28 ++++++++++----------
.../systemtest/runner/HiveSystemTestRunner.java | 28 ++++++++++----------
.../systemtest/runner/SystemTestCommonInfo.java | 28 ++++++++++----------
.../systemtest/runner/SystemTestRunner.java | 28 ++++++++++----------
.../systemtest/runner/SystemTestTeam.java | 28 ++++++++++----------
.../systemtest/runner/TDSystemTestRunner.java | 28 ++++++++++----------
.../main/java/hivemall/systemtest/utils/IO.java | 28 ++++++++++----------
23 files changed, 328 insertions(+), 295 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/43ca0c86/systemtest/README.md
----------------------------------------------------------------------
diff --git a/systemtest/README.md b/systemtest/README.md
index 9d1442a..4fca0c3 100644
--- a/systemtest/README.md
+++ b/systemtest/README.md
@@ -1,3 +1,21 @@
+<!--
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+-->
## Usage
### Initialization
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/43ca0c86/systemtest/pom.xml
----------------------------------------------------------------------
diff --git a/systemtest/pom.xml b/systemtest/pom.xml
index e59d2ce..e7345af 100644
--- a/systemtest/pom.xml
+++ b/systemtest/pom.xml
@@ -1,4 +1,19 @@
-<?xml version="1.0" encoding="UTF-8"?>
+<!--
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+-->
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/43ca0c86/systemtest/src/main/java/com/klarna/hiverunner/Extractor.java
----------------------------------------------------------------------
diff --git a/systemtest/src/main/java/com/klarna/hiverunner/Extractor.java b/systemtest/src/main/java/com/klarna/hiverunner/Extractor.java
index 99720f0..f7f372f 100644
--- a/systemtest/src/main/java/com/klarna/hiverunner/Extractor.java
+++ b/systemtest/src/main/java/com/klarna/hiverunner/Extractor.java
@@ -1,20 +1,20 @@
/*
- * Hivemall: Hive scalable Machine Learning Library
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
*
- * Copyright (C) 2016 Makoto YUI
- * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ * http://www.apache.org/licenses/LICENSE-2.0
*
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
*/
package com.klarna.hiverunner;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/43ca0c86/systemtest/src/main/java/hivemall/systemtest/MsgpackConverter.java
----------------------------------------------------------------------
diff --git a/systemtest/src/main/java/hivemall/systemtest/MsgpackConverter.java b/systemtest/src/main/java/hivemall/systemtest/MsgpackConverter.java
index c6383e3..b86c1cf 100644
--- a/systemtest/src/main/java/hivemall/systemtest/MsgpackConverter.java
+++ b/systemtest/src/main/java/hivemall/systemtest/MsgpackConverter.java
@@ -1,20 +1,20 @@
/*
- * Hivemall: Hive scalable Machine Learning Library
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
*
- * Copyright (C) 2016 Makoto YUI
- * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ * http://www.apache.org/licenses/LICENSE-2.0
*
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
*/
package hivemall.systemtest;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/43ca0c86/systemtest/src/main/java/hivemall/systemtest/exception/QueryExecutionException.java
----------------------------------------------------------------------
diff --git a/systemtest/src/main/java/hivemall/systemtest/exception/QueryExecutionException.java b/systemtest/src/main/java/hivemall/systemtest/exception/QueryExecutionException.java
index e17a32e..c2b0034 100644
--- a/systemtest/src/main/java/hivemall/systemtest/exception/QueryExecutionException.java
+++ b/systemtest/src/main/java/hivemall/systemtest/exception/QueryExecutionException.java
@@ -1,20 +1,20 @@
/*
- * Hivemall: Hive scalable Machine Learning Library
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
*
- * Copyright (C) 2016 Makoto YUI
- * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ * http://www.apache.org/licenses/LICENSE-2.0
*
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
*/
package hivemall.systemtest.exception;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/43ca0c86/systemtest/src/main/java/hivemall/systemtest/model/CreateTableHQ.java
----------------------------------------------------------------------
diff --git a/systemtest/src/main/java/hivemall/systemtest/model/CreateTableHQ.java b/systemtest/src/main/java/hivemall/systemtest/model/CreateTableHQ.java
index 40004b8..e0047a6 100644
--- a/systemtest/src/main/java/hivemall/systemtest/model/CreateTableHQ.java
+++ b/systemtest/src/main/java/hivemall/systemtest/model/CreateTableHQ.java
@@ -1,20 +1,20 @@
/*
- * Hivemall: Hive scalable Machine Learning Library
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
*
- * Copyright (C) 2016 Makoto YUI
- * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ * http://www.apache.org/licenses/LICENSE-2.0
*
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
*/
package hivemall.systemtest.model;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/43ca0c86/systemtest/src/main/java/hivemall/systemtest/model/DropTableHQ.java
----------------------------------------------------------------------
diff --git a/systemtest/src/main/java/hivemall/systemtest/model/DropTableHQ.java b/systemtest/src/main/java/hivemall/systemtest/model/DropTableHQ.java
index 4e9fe23..c09ae24 100644
--- a/systemtest/src/main/java/hivemall/systemtest/model/DropTableHQ.java
+++ b/systemtest/src/main/java/hivemall/systemtest/model/DropTableHQ.java
@@ -1,20 +1,20 @@
/*
- * Hivemall: Hive scalable Machine Learning Library
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
*
- * Copyright (C) 2016 Makoto YUI
- * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ * http://www.apache.org/licenses/LICENSE-2.0
*
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
*/
package hivemall.systemtest.model;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/43ca0c86/systemtest/src/main/java/hivemall/systemtest/model/HQ.java
----------------------------------------------------------------------
diff --git a/systemtest/src/main/java/hivemall/systemtest/model/HQ.java b/systemtest/src/main/java/hivemall/systemtest/model/HQ.java
index f847dfe..05933a4 100644
--- a/systemtest/src/main/java/hivemall/systemtest/model/HQ.java
+++ b/systemtest/src/main/java/hivemall/systemtest/model/HQ.java
@@ -1,20 +1,20 @@
/*
- * Hivemall: Hive scalable Machine Learning Library
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
*
- * Copyright (C) 2016 Makoto YUI
- * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ * http://www.apache.org/licenses/LICENSE-2.0
*
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
*/
package hivemall.systemtest.model;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/43ca0c86/systemtest/src/main/java/hivemall/systemtest/model/HQBase.java
----------------------------------------------------------------------
diff --git a/systemtest/src/main/java/hivemall/systemtest/model/HQBase.java b/systemtest/src/main/java/hivemall/systemtest/model/HQBase.java
index 4212a5a..0008100 100644
--- a/systemtest/src/main/java/hivemall/systemtest/model/HQBase.java
+++ b/systemtest/src/main/java/hivemall/systemtest/model/HQBase.java
@@ -1,20 +1,20 @@
/*
- * Hivemall: Hive scalable Machine Learning Library
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
*
- * Copyright (C) 2016 Makoto YUI
- * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ * http://www.apache.org/licenses/LICENSE-2.0
*
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
*/
package hivemall.systemtest.model;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/43ca0c86/systemtest/src/main/java/hivemall/systemtest/model/InsertHQ.java
----------------------------------------------------------------------
diff --git a/systemtest/src/main/java/hivemall/systemtest/model/InsertHQ.java b/systemtest/src/main/java/hivemall/systemtest/model/InsertHQ.java
index 30bf13b..3d69f13 100644
--- a/systemtest/src/main/java/hivemall/systemtest/model/InsertHQ.java
+++ b/systemtest/src/main/java/hivemall/systemtest/model/InsertHQ.java
@@ -1,20 +1,20 @@
/*
- * Hivemall: Hive scalable Machine Learning Library
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
*
- * Copyright (C) 2016 Makoto YUI
- * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ * http://www.apache.org/licenses/LICENSE-2.0
*
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
*/
package hivemall.systemtest.model;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/43ca0c86/systemtest/src/main/java/hivemall/systemtest/model/RawHQ.java
----------------------------------------------------------------------
diff --git a/systemtest/src/main/java/hivemall/systemtest/model/RawHQ.java b/systemtest/src/main/java/hivemall/systemtest/model/RawHQ.java
index 7671ca1..1d6f020 100644
--- a/systemtest/src/main/java/hivemall/systemtest/model/RawHQ.java
+++ b/systemtest/src/main/java/hivemall/systemtest/model/RawHQ.java
@@ -1,20 +1,20 @@
/*
- * Hivemall: Hive scalable Machine Learning Library
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
*
- * Copyright (C) 2016 Makoto YUI
- * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ * http://www.apache.org/licenses/LICENSE-2.0
*
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
*/
package hivemall.systemtest.model;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/43ca0c86/systemtest/src/main/java/hivemall/systemtest/model/TableHQ.java
----------------------------------------------------------------------
diff --git a/systemtest/src/main/java/hivemall/systemtest/model/TableHQ.java b/systemtest/src/main/java/hivemall/systemtest/model/TableHQ.java
index 8c98876..2f9f3c9 100644
--- a/systemtest/src/main/java/hivemall/systemtest/model/TableHQ.java
+++ b/systemtest/src/main/java/hivemall/systemtest/model/TableHQ.java
@@ -1,20 +1,20 @@
/*
- * Hivemall: Hive scalable Machine Learning Library
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
*
- * Copyright (C) 2016 Makoto YUI
- * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ * http://www.apache.org/licenses/LICENSE-2.0
*
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
*/
package hivemall.systemtest.model;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/43ca0c86/systemtest/src/main/java/hivemall/systemtest/model/TableListHQ.java
----------------------------------------------------------------------
diff --git a/systemtest/src/main/java/hivemall/systemtest/model/TableListHQ.java b/systemtest/src/main/java/hivemall/systemtest/model/TableListHQ.java
index a4283bb..146adbd 100644
--- a/systemtest/src/main/java/hivemall/systemtest/model/TableListHQ.java
+++ b/systemtest/src/main/java/hivemall/systemtest/model/TableListHQ.java
@@ -1,20 +1,20 @@
/*
- * Hivemall: Hive scalable Machine Learning Library
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
*
- * Copyright (C) 2016 Makoto YUI
- * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ * http://www.apache.org/licenses/LICENSE-2.0
*
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
*/
package hivemall.systemtest.model;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/43ca0c86/systemtest/src/main/java/hivemall/systemtest/model/UploadFileAsNewTableHQ.java
----------------------------------------------------------------------
diff --git a/systemtest/src/main/java/hivemall/systemtest/model/UploadFileAsNewTableHQ.java b/systemtest/src/main/java/hivemall/systemtest/model/UploadFileAsNewTableHQ.java
index d88976e..7091557 100644
--- a/systemtest/src/main/java/hivemall/systemtest/model/UploadFileAsNewTableHQ.java
+++ b/systemtest/src/main/java/hivemall/systemtest/model/UploadFileAsNewTableHQ.java
@@ -1,20 +1,20 @@
/*
- * Hivemall: Hive scalable Machine Learning Library
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
*
- * Copyright (C) 2016 Makoto YUI
- * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ * http://www.apache.org/licenses/LICENSE-2.0
*
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
*/
package hivemall.systemtest.model;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/43ca0c86/systemtest/src/main/java/hivemall/systemtest/model/UploadFileHQ.java
----------------------------------------------------------------------
diff --git a/systemtest/src/main/java/hivemall/systemtest/model/UploadFileHQ.java b/systemtest/src/main/java/hivemall/systemtest/model/UploadFileHQ.java
index 378521d..35b35c9 100644
--- a/systemtest/src/main/java/hivemall/systemtest/model/UploadFileHQ.java
+++ b/systemtest/src/main/java/hivemall/systemtest/model/UploadFileHQ.java
@@ -1,20 +1,20 @@
/*
- * Hivemall: Hive scalable Machine Learning Library
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
*
- * Copyright (C) 2016 Makoto YUI
- * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ * http://www.apache.org/licenses/LICENSE-2.0
*
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
*/
package hivemall.systemtest.model;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/43ca0c86/systemtest/src/main/java/hivemall/systemtest/model/UploadFileToExistingHQ.java
----------------------------------------------------------------------
diff --git a/systemtest/src/main/java/hivemall/systemtest/model/UploadFileToExistingHQ.java b/systemtest/src/main/java/hivemall/systemtest/model/UploadFileToExistingHQ.java
index fc7873e..1f5c28d 100644
--- a/systemtest/src/main/java/hivemall/systemtest/model/UploadFileToExistingHQ.java
+++ b/systemtest/src/main/java/hivemall/systemtest/model/UploadFileToExistingHQ.java
@@ -1,20 +1,20 @@
/*
- * Hivemall: Hive scalable Machine Learning Library
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
*
- * Copyright (C) 2016 Makoto YUI
- * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ * http://www.apache.org/licenses/LICENSE-2.0
*
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
*/
package hivemall.systemtest.model;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/43ca0c86/systemtest/src/main/java/hivemall/systemtest/model/lazy/LazyMatchingResource.java
----------------------------------------------------------------------
diff --git a/systemtest/src/main/java/hivemall/systemtest/model/lazy/LazyMatchingResource.java b/systemtest/src/main/java/hivemall/systemtest/model/lazy/LazyMatchingResource.java
index 3715a9e..16f14ea 100644
--- a/systemtest/src/main/java/hivemall/systemtest/model/lazy/LazyMatchingResource.java
+++ b/systemtest/src/main/java/hivemall/systemtest/model/lazy/LazyMatchingResource.java
@@ -1,20 +1,20 @@
/*
- * Hivemall: Hive scalable Machine Learning Library
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
*
- * Copyright (C) 2016 Makoto YUI
- * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ * http://www.apache.org/licenses/LICENSE-2.0
*
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
*/
package hivemall.systemtest.model.lazy;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/43ca0c86/systemtest/src/main/java/hivemall/systemtest/runner/HiveSystemTestRunner.java
----------------------------------------------------------------------
diff --git a/systemtest/src/main/java/hivemall/systemtest/runner/HiveSystemTestRunner.java b/systemtest/src/main/java/hivemall/systemtest/runner/HiveSystemTestRunner.java
index 09d242d..6b41855 100644
--- a/systemtest/src/main/java/hivemall/systemtest/runner/HiveSystemTestRunner.java
+++ b/systemtest/src/main/java/hivemall/systemtest/runner/HiveSystemTestRunner.java
@@ -1,20 +1,20 @@
/*
- * Hivemall: Hive scalable Machine Learning Library
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
*
- * Copyright (C) 2016 Makoto YUI
- * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ * http://www.apache.org/licenses/LICENSE-2.0
*
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
*/
package hivemall.systemtest.runner;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/43ca0c86/systemtest/src/main/java/hivemall/systemtest/runner/SystemTestCommonInfo.java
----------------------------------------------------------------------
diff --git a/systemtest/src/main/java/hivemall/systemtest/runner/SystemTestCommonInfo.java b/systemtest/src/main/java/hivemall/systemtest/runner/SystemTestCommonInfo.java
index 5b55466..60292fa 100644
--- a/systemtest/src/main/java/hivemall/systemtest/runner/SystemTestCommonInfo.java
+++ b/systemtest/src/main/java/hivemall/systemtest/runner/SystemTestCommonInfo.java
@@ -1,20 +1,20 @@
/*
- * Hivemall: Hive scalable Machine Learning Library
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
*
- * Copyright (C) 2016 Makoto YUI
- * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ * http://www.apache.org/licenses/LICENSE-2.0
*
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
*/
package hivemall.systemtest.runner;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/43ca0c86/systemtest/src/main/java/hivemall/systemtest/runner/SystemTestRunner.java
----------------------------------------------------------------------
diff --git a/systemtest/src/main/java/hivemall/systemtest/runner/SystemTestRunner.java b/systemtest/src/main/java/hivemall/systemtest/runner/SystemTestRunner.java
index e3d9412..77091f2 100644
--- a/systemtest/src/main/java/hivemall/systemtest/runner/SystemTestRunner.java
+++ b/systemtest/src/main/java/hivemall/systemtest/runner/SystemTestRunner.java
@@ -1,20 +1,20 @@
/*
- * Hivemall: Hive scalable Machine Learning Library
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
*
- * Copyright (C) 2016 Makoto YUI
- * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ * http://www.apache.org/licenses/LICENSE-2.0
*
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
*/
package hivemall.systemtest.runner;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/43ca0c86/systemtest/src/main/java/hivemall/systemtest/runner/SystemTestTeam.java
----------------------------------------------------------------------
diff --git a/systemtest/src/main/java/hivemall/systemtest/runner/SystemTestTeam.java b/systemtest/src/main/java/hivemall/systemtest/runner/SystemTestTeam.java
index babbfbb..86065e4 100644
--- a/systemtest/src/main/java/hivemall/systemtest/runner/SystemTestTeam.java
+++ b/systemtest/src/main/java/hivemall/systemtest/runner/SystemTestTeam.java
@@ -1,20 +1,20 @@
/*
- * Hivemall: Hive scalable Machine Learning Library
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
*
- * Copyright (C) 2016 Makoto YUI
- * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ * http://www.apache.org/licenses/LICENSE-2.0
*
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
*/
package hivemall.systemtest.runner;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/43ca0c86/systemtest/src/main/java/hivemall/systemtest/runner/TDSystemTestRunner.java
----------------------------------------------------------------------
diff --git a/systemtest/src/main/java/hivemall/systemtest/runner/TDSystemTestRunner.java b/systemtest/src/main/java/hivemall/systemtest/runner/TDSystemTestRunner.java
index d646e7e..6d6c85b 100644
--- a/systemtest/src/main/java/hivemall/systemtest/runner/TDSystemTestRunner.java
+++ b/systemtest/src/main/java/hivemall/systemtest/runner/TDSystemTestRunner.java
@@ -1,20 +1,20 @@
/*
- * Hivemall: Hive scalable Machine Learning Library
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
*
- * Copyright (C) 2016 Makoto YUI
- * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ * http://www.apache.org/licenses/LICENSE-2.0
*
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
*/
package hivemall.systemtest.runner;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/43ca0c86/systemtest/src/main/java/hivemall/systemtest/utils/IO.java
----------------------------------------------------------------------
diff --git a/systemtest/src/main/java/hivemall/systemtest/utils/IO.java b/systemtest/src/main/java/hivemall/systemtest/utils/IO.java
index d572ed0..0430945 100644
--- a/systemtest/src/main/java/hivemall/systemtest/utils/IO.java
+++ b/systemtest/src/main/java/hivemall/systemtest/utils/IO.java
@@ -1,20 +1,20 @@
/*
- * Hivemall: Hive scalable Machine Learning Library
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
*
- * Copyright (C) 2016 Makoto YUI
- * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ * http://www.apache.org/licenses/LICENSE-2.0
*
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
*/
package hivemall.systemtest.utils;
[40/50] [abbrv] incubator-hivemall git commit: Mod README
Posted by my...@apache.org.
Mod README
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/7447dde6
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/7447dde6
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/7447dde6
Branch: refs/heads/JIRA-22/pr-336
Commit: 7447dde61f3a9cb8e3ba5ab278a260d0a0615524
Parents: 144cb50
Author: amaya <gi...@sapphire.in.net>
Authored: Fri Nov 18 03:23:46 2016 +0900
Committer: amaya <gi...@sapphire.in.net>
Committed: Fri Nov 18 03:23:46 2016 +0900
----------------------------------------------------------------------
systemtest/README.md | 9 +++++----
1 file changed, 5 insertions(+), 4 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/7447dde6/systemtest/README.md
----------------------------------------------------------------------
diff --git a/systemtest/README.md b/systemtest/README.md
index 2805165..2b1167e 100644
--- a/systemtest/README.md
+++ b/systemtest/README.md
@@ -195,8 +195,9 @@ pink 255 192 203
```sql
-- write your hive queries
-- comments like this and multiple queries in one row are allowed
-SELECT blue FROM color WHERE name = 'lavender';SELECT green FROM color WHERE name LIKE 'orange%'
-SELECT name FROM color WHERE blue = 255
+SELECT blue FROM color WHERE name = 'lavender';
+SELECT green FROM color WHERE name LIKE 'orange%';
+SELECT name FROM color WHERE blue = 255;
```
* `systemtest/src/test/resources/hivemall/QuickExample/answer/test3` (`systemtest/src/test/resources/${path/to/package}/${className}/answer/${fileName}`)
@@ -205,6 +206,6 @@ tsv format is required
```tsv
250
-165 69
-azure blue magenta
+165 69
+azure blue magenta
```
[37/50] [abbrv] incubator-hivemall git commit: Make dir name static
Posted by my...@apache.org.
Make dir name static
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/1f3df54c
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/1f3df54c
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/1f3df54c
Branch: refs/heads/JIRA-22/pr-336
Commit: 1f3df54c0183a61390f58b94f58c12e531754a09
Parents: 33eab26
Author: amaya <gi...@sapphire.in.net>
Authored: Fri Nov 18 01:57:31 2016 +0900
Committer: amaya <gi...@sapphire.in.net>
Committed: Fri Nov 18 01:57:31 2016 +0900
----------------------------------------------------------------------
.../hivemall/systemtest/runner/SystemTestCommonInfo.java | 10 +++++++---
1 file changed, 7 insertions(+), 3 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/1f3df54c/systemtest/src/main/java/hivemall/systemtest/runner/SystemTestCommonInfo.java
----------------------------------------------------------------------
diff --git a/systemtest/src/main/java/hivemall/systemtest/runner/SystemTestCommonInfo.java b/systemtest/src/main/java/hivemall/systemtest/runner/SystemTestCommonInfo.java
index 60292fa..82b433f 100644
--- a/systemtest/src/main/java/hivemall/systemtest/runner/SystemTestCommonInfo.java
+++ b/systemtest/src/main/java/hivemall/systemtest/runner/SystemTestCommonInfo.java
@@ -21,6 +21,10 @@ package hivemall.systemtest.runner;
import javax.annotation.Nonnull;
public class SystemTestCommonInfo {
+ private static final String CASE = "case";
+ private static final String ANSWER = "answer";
+ private static final String INIT = "init";
+
@Nonnull
public final String baseDir;
@Nonnull
@@ -34,9 +38,9 @@ public class SystemTestCommonInfo {
public SystemTestCommonInfo(@Nonnull final Class<?> clazz) {
baseDir = clazz.getName().replace(".", "/");
- caseDir = baseDir + "/case/";
- answerDir = baseDir + "/answer/";
- initDir = baseDir + "/init/";
+ caseDir = baseDir + "/" + CASE + "/";
+ answerDir = baseDir + "/" + ANSWER + "/";
+ initDir = baseDir + "/" + INIT + "/";
dbName = clazz.getName().replace(".", "_").toLowerCase();
}
}
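For reference, a minimal standalone sketch (not part of this commit) of how the fields above resolve once the directory names are constants; the class name and printed paths are illustrative only.

// Illustration only: resolves the same strings as SystemTestCommonInfo above
// for a hypothetical test class named hivemall.QuickExample.
public final class CommonInfoPathsExample {
    public static void main(String[] args) {
        String className = "hivemall.QuickExample";                    // hypothetical test class
        String baseDir = className.replace(".", "/");                  // hivemall/QuickExample
        System.out.println(baseDir + "/case/");                        // hivemall/QuickExample/case/
        System.out.println(baseDir + "/answer/");                      // hivemall/QuickExample/answer/
        System.out.println(baseDir + "/init/");                        // hivemall/QuickExample/init/
        System.out.println(className.replace(".", "_").toLowerCase()); // hivemall_quickexample
    }
}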
[31/50] [abbrv] incubator-hivemall git commit: minor fix
Posted by my...@apache.org.
minor fix
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/80be81ec
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/80be81ec
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/80be81ec
Branch: refs/heads/JIRA-22/pr-385
Commit: 80be81ecf92cd4675dcdfaa5f456d84d484d6c44
Parents: 4cfa4e5
Author: amaya <gi...@sapphire.in.net>
Authored: Wed Sep 28 20:01:08 2016 +0900
Committer: amaya <gi...@sapphire.in.net>
Committed: Wed Sep 28 20:01:08 2016 +0900
----------------------------------------------------------------------
.../main/java/hivemall/ftvec/selection/SignalNoiseRatioUDAF.java | 2 +-
.../test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala | 3 ++-
2 files changed, 3 insertions(+), 2 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/80be81ec/core/src/main/java/hivemall/ftvec/selection/SignalNoiseRatioUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/ftvec/selection/SignalNoiseRatioUDAF.java b/core/src/main/java/hivemall/ftvec/selection/SignalNoiseRatioUDAF.java
index 507aefa..96fdc5b 100644
--- a/core/src/main/java/hivemall/ftvec/selection/SignalNoiseRatioUDAF.java
+++ b/core/src/main/java/hivemall/ftvec/selection/SignalNoiseRatioUDAF.java
@@ -335,7 +335,7 @@ public class SignalNoiseRatioUDAF extends AbstractGenericUDAFResolver {
final double snr = Math.abs(myAgg.meanss[j][i] - myAgg.meanss[k][i])
/ (sds[j] + sds[k]);
// if `NaN`(when diff between means and both sds are zero, IOW, all related values are equal),
- // regard feature `i` as meaningless between class `j` and `k` and skip
+ // regard feature `i` as meaningless between class `j` and `k`, skip
if (!Double.isNaN(snr)) {
result[i] += snr; // accept `Infinity`
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/80be81ec/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
----------------------------------------------------------------------
diff --git a/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala b/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
index 2e18280..7b62b92 100644
--- a/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
+++ b/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
@@ -740,7 +740,8 @@ final class HivemallOpsWithFeatureSuite extends HivemallFeatureQueryTest {
// | 1 2 3 |T | 5 6 7 |
// | 3 4 5 | * | 7 8 9 |
- val df0 = Seq((1, Seq(1, 2, 3), Seq(5, 6, 7)), (1, Seq(3, 4, 5), Seq(7, 8, 9))).toDF.as("c0", "arg0", "arg1")
+ val df0 = Seq((1, Seq(1, 2, 3), Seq(5, 6, 7)), (1, Seq(3, 4, 5), Seq(7, 8, 9)))
+ .toDF("c0", "arg0", "arg1")
df0.groupby($"c0").transpose_and_dot("arg0", "arg1").collect() shouldEqual
Seq(Row(1, Seq(Seq(26.0, 30.0, 34.0), Seq(38.0, 44.0, 50.0), Seq(50.0, 58.0, 66.0))))
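For reference, a minimal standalone sketch (not part of this commit) that reproduces the matrix expected by the test above, assuming grouped transpose_and_dot accumulates result[i][j] as the sum over rows k of arg0[k][i] * arg1[k][j].

// Illustration only: computes the matrix expected by the suite above.
import java.util.Arrays;

public final class TransposeAndDotExample {
    public static void main(String[] args) {
        double[][] arg0 = {{1, 2, 3}, {3, 4, 5}};
        double[][] arg1 = {{5, 6, 7}, {7, 8, 9}};
        double[][] result = new double[3][3];
        for (int k = 0; k < arg0.length; k++) {
            for (int i = 0; i < 3; i++) {
                for (int j = 0; j < 3; j++) {
                    result[i][j] += arg0[k][i] * arg1[k][j];
                }
            }
        }
        for (double[] row : result) {
            System.out.println(Arrays.toString(row)); // [26, 30, 34] / [38, 44, 50] / [50, 58, 66]
        }
    }
}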
[03/50] [abbrv] incubator-hivemall git commit: add chi2 and chi2_test
Posted by my...@apache.org.
add chi2 and chi2_test
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/d3009be5
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/d3009be5
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/d3009be5
Branch: refs/heads/JIRA-22/pr-385
Commit: d3009be59bcf314b373038e3db8903a041396931
Parents: 6f9b4fa
Author: amaya <gi...@sapphire.in.net>
Authored: Fri Sep 16 16:00:58 2016 +0900
Committer: amaya <gi...@sapphire.in.net>
Committed: Fri Sep 16 16:00:58 2016 +0900
----------------------------------------------------------------------
.../ftvec/selection/ChiSquareTestUDF.java | 21 +++++
.../hivemall/ftvec/selection/ChiSquareUDF.java | 21 +++++
.../ftvec/selection/DissociationDegreeUDF.java | 88 ++++++++++++++++++++
.../java/hivemall/utils/math/StatsUtils.java | 49 +++++++++++
4 files changed, 179 insertions(+)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/d3009be5/core/src/main/java/hivemall/ftvec/selection/ChiSquareTestUDF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/ftvec/selection/ChiSquareTestUDF.java b/core/src/main/java/hivemall/ftvec/selection/ChiSquareTestUDF.java
new file mode 100644
index 0000000..d367085
--- /dev/null
+++ b/core/src/main/java/hivemall/ftvec/selection/ChiSquareTestUDF.java
@@ -0,0 +1,21 @@
+package hivemall.ftvec.selection;
+
+import hivemall.utils.math.StatsUtils;
+import org.apache.hadoop.hive.ql.exec.Description;
+
+import javax.annotation.Nonnull;
+
+@Description(name = "chi2_test",
+ value = "_FUNC_(array<number> expected, array<number> observed) - Returns p-value as double")
+public class ChiSquareTestUDF extends DissociationDegreeUDF {
+ @Override
+ double calcDissociation(@Nonnull final double[] expected, @Nonnull final double[] observed) {
+ return StatsUtils.chiSquareTest(expected, observed);
+ }
+
+ @Override
+ @Nonnull
+ String getFuncName() {
+ return "chi2_test";
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/d3009be5/core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java b/core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java
new file mode 100644
index 0000000..937b1bd
--- /dev/null
+++ b/core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java
@@ -0,0 +1,21 @@
+package hivemall.ftvec.selection;
+
+import hivemall.utils.math.StatsUtils;
+import org.apache.hadoop.hive.ql.exec.Description;
+
+import javax.annotation.Nonnull;
+
+@Description(name = "chi2",
+ value = "_FUNC_(array<number> expected, array<number> observed) - Returns chi2-value as double")
+public class ChiSquareUDF extends DissociationDegreeUDF {
+ @Override
+ double calcDissociation(@Nonnull final double[] expected, @Nonnull final double[] observed) {
+ return StatsUtils.chiSquare(expected, observed);
+ }
+
+ @Override
+ @Nonnull
+ String getFuncName() {
+ return "chi2";
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/d3009be5/core/src/main/java/hivemall/ftvec/selection/DissociationDegreeUDF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/ftvec/selection/DissociationDegreeUDF.java b/core/src/main/java/hivemall/ftvec/selection/DissociationDegreeUDF.java
new file mode 100644
index 0000000..0acae82
--- /dev/null
+++ b/core/src/main/java/hivemall/ftvec/selection/DissociationDegreeUDF.java
@@ -0,0 +1,88 @@
+package hivemall.ftvec.selection;
+
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.lang.Preconditions;
+import hivemall.utils.math.StatsUtils;
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
+import org.apache.hadoop.hive.serde2.io.DoubleWritable;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+
+import javax.annotation.Nonnull;
+
+@Description(name = "",
+ value = "_FUNC_(array<number> expected, array<number> observed) - Returns dissociation degree as double")
+public abstract class DissociationDegreeUDF extends GenericUDF {
+ private ListObjectInspector expectedOI;
+ private DoubleObjectInspector expectedElOI;
+ private ListObjectInspector observedOI;
+ private DoubleObjectInspector observedElOI;
+
+ @Override
+ public ObjectInspector initialize(ObjectInspector[] OIs) throws UDFArgumentException {
+ if (OIs.length != 2) {
+ throw new UDFArgumentLengthException("Specify two arguments.");
+ }
+
+ if (!HiveUtils.isListOI(OIs[0])
+ || !HiveUtils.isNumberOI(((ListObjectInspector) OIs[0]).getListElementObjectInspector())) {
+ throw new UDFArgumentTypeException(0, "Only array<number> type argument is acceptable but "
+ + OIs[0].getTypeName() + " was passed as `expected`");
+ }
+
+ if (!HiveUtils.isListOI(OIs[1])
+ || !HiveUtils.isNumberOI(((ListObjectInspector) OIs[1]).getListElementObjectInspector())) {
+ throw new UDFArgumentTypeException(1, "Only array<number> type argument is acceptable but "
+ + OIs[1].getTypeName() + " was passed as `observed`");
+ }
+
+ expectedOI = (ListObjectInspector) OIs[0];
+ expectedElOI = (DoubleObjectInspector) expectedOI.getListElementObjectInspector();
+ observedOI = (ListObjectInspector) OIs[1];
+ observedElOI = (DoubleObjectInspector) observedOI.getListElementObjectInspector();
+
+ return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
+ }
+
+ @Override
+ public Object evaluate(GenericUDF.DeferredObject[] dObj) throws HiveException {
+ final double[] expected = HiveUtils.asDoubleArray(dObj[0].get(), expectedOI, expectedElOI);
+ final double[] observed = HiveUtils.asDoubleArray(dObj[1].get(), observedOI, observedElOI);
+
+ Preconditions.checkNotNull(expected);
+ Preconditions.checkNotNull(observed);
+ Preconditions.checkArgument(expected.length == observed.length);
+
+ final double dissociation = calcDissociation(expected, observed);
+
+ return new DoubleWritable(dissociation);
+ }
+
+ @Override
+ public String getDisplayString(String[] children) {
+ final StringBuilder sb = new StringBuilder();
+ sb.append(getFuncName());
+ sb.append("(");
+ if (children.length > 0) {
+ sb.append(children[0]);
+ for (int i = 1; i < children.length; i++) {
+ sb.append(", ");
+ sb.append(children[i]);
+ }
+ }
+ sb.append(")");
+ return sb.toString();
+ }
+
+ abstract double calcDissociation(@Nonnull final double[] expected, @Nonnull final double[] observed);
+
+ @Nonnull
+ abstract String getFuncName();
+}
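For reference, a minimal hypothetical subclass (not part of this commit) showing the extension point used by chi2 and chi2_test above: a new dissociation measure only needs to implement calcDissociation() and getFuncName(). The function name and measure here are made up for illustration.

// Hypothetical example only: a plain sum-of-absolute-differences measure
// built on the abstract DissociationDegreeUDF shown above.
package hivemall.ftvec.selection;

import javax.annotation.Nonnull;
import org.apache.hadoop.hive.ql.exec.Description;

@Description(name = "abs_diff_sum",
        value = "_FUNC_(array<number> expected, array<number> observed) - Returns sum of |e_i - o_i| as double")
public class AbsoluteDifferenceUDF extends DissociationDegreeUDF {
    @Override
    double calcDissociation(@Nonnull final double[] expected, @Nonnull final double[] observed) {
        double sum = 0.d;
        for (int i = 0; i < expected.length; i++) {
            sum += Math.abs(expected[i] - observed[i]);
        }
        return sum;
    }

    @Override
    @Nonnull
    String getFuncName() {
        return "abs_diff_sum";
    }
}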
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/d3009be5/core/src/main/java/hivemall/utils/math/StatsUtils.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/math/StatsUtils.java b/core/src/main/java/hivemall/utils/math/StatsUtils.java
index 42a2c90..ffccea3 100644
--- a/core/src/main/java/hivemall/utils/math/StatsUtils.java
+++ b/core/src/main/java/hivemall/utils/math/StatsUtils.java
@@ -22,6 +22,7 @@ import hivemall.utils.lang.Preconditions;
import javax.annotation.Nonnull;
+import org.apache.commons.math3.distribution.ChiSquaredDistribution;
import org.apache.commons.math3.linear.DecompositionSolver;
import org.apache.commons.math3.linear.LUDecomposition;
import org.apache.commons.math3.linear.RealMatrix;
@@ -189,4 +190,52 @@ public final class StatsUtils {
return 1.d - numerator / denominator;
}
+ /**
+ * @param expected vector of expected values
+ * @param observed vector of observed values
+ * @return chi2-value
+ */
+ public static double chiSquare(@Nonnull final double[] expected, @Nonnull final double[] observed) {
+ Preconditions.checkArgument(expected.length == observed.length);
+
+ double sumExpected = 0.0D;
+ double sumObserved = 0.0D;
+
+ for (int i = 0; i < observed.length; ++i) {
+ sumExpected += expected[i];
+ sumObserved += observed[i];
+ }
+
+ double scale = 1.0D;
+ boolean rescale = false;
+ if (Math.abs(sumExpected - sumObserved) > 1.0E-5D) {
+ scale = sumObserved / sumExpected;
+ rescale = true;
+ }
+
+ double sumSq = 0.0D;
+
+ for (int i = 0; i < observed.length; ++i) {
+ double dev;
+ if (rescale) {
+ dev = observed[i] - scale * expected[i];
+ sumSq += dev * dev / (scale * expected[i]);
+ } else {
+ dev = observed[i] - expected[i];
+ sumSq += dev * dev / expected[i];
+ }
+ }
+
+ return sumSq;
+ }
+
+ /**
+ * @param expected vector of expected values
+ * @param observed vector of observed values
+ * @return p-value
+ */
+ public static double chiSquareTest(@Nonnull final double[] expected, @Nonnull final double[] observed) {
+ ChiSquaredDistribution distribution = new ChiSquaredDistribution(null, (double)expected.length - 1.0D);
+ return 1.0D - distribution.cumulativeProbability(chiSquare(expected, observed));
+ }
}
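For reference, a minimal standalone sketch (not part of this commit, assumes commons-math3 on the classpath) of the same Pearson chi-square computation for the common case where expected and observed already have equal sums, followed by the corresponding p-value. The input values are made up for illustration.

// Illustration only: chi2 = 25/25 + 25/25 + 4/25 + 4/25 = 2.32 with 3 degrees
// of freedom, giving a p-value of roughly 0.51.
import org.apache.commons.math3.distribution.ChiSquaredDistribution;

public final class ChiSquareExample {
    public static void main(String[] args) {
        double[] expected = {25.d, 25.d, 25.d, 25.d};
        double[] observed = {30.d, 20.d, 27.d, 23.d};
        double chi2 = 0.d;
        for (int i = 0; i < expected.length; i++) {
            double dev = observed[i] - expected[i];
            chi2 += dev * dev / expected[i];
        }
        ChiSquaredDistribution dist = new ChiSquaredDistribution(expected.length - 1.d);
        double pValue = 1.d - dist.cumulativeProbability(chi2);
        System.out.println(chi2 + "\t" + pValue);
    }
}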
[35/50] [abbrv] incubator-hivemall git commit: Mod README
Posted by my...@apache.org.
Mod README
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/ba912677
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/ba912677
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/ba912677
Branch: refs/heads/JIRA-22/pr-336
Commit: ba91267796cbfdee53aaef02af882aff591fb8f7
Parents: 43ca0c8
Author: amaya <gi...@sapphire.in.net>
Authored: Thu Nov 17 14:15:03 2016 +0900
Committer: amaya <gi...@sapphire.in.net>
Committed: Thu Nov 17 14:15:03 2016 +0900
----------------------------------------------------------------------
systemtest/README.md | 12 ++++++------
1 file changed, 6 insertions(+), 6 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ba912677/systemtest/README.md
----------------------------------------------------------------------
diff --git a/systemtest/README.md b/systemtest/README.md
index 4fca0c3..2805165 100644
--- a/systemtest/README.md
+++ b/systemtest/README.md
@@ -157,7 +157,7 @@ public class QuickExample {
public void test3() throws Exception {
// test on HiveRunner once only
// auto matching by files which name is `test3` in `case/` and `answer/`
- team.set(HQ.autoMatchingByFileName("test3", ci)); // unordered test
+ team.set(HQ.autoMatchingByFileName("test3"), ci); // unordered test
team.run(); // this call is required
}
@@ -165,7 +165,7 @@ public class QuickExample {
public void test4() throws Exception {
// test on HiveRunner once only
predictor.expect(Throwable.class); // you can use systemtest w/ other rules
- team.set(HQ.fromStatement("invalid queryyy")); // this query throws an exception
+ team.set(HQ.fromStatement("invalid queryyy"), "never used"); // this query throws an exception
team.run(); // this call is required
// thrown exception will be caught by `ExpectedException` rule
}
@@ -174,7 +174,7 @@ public class QuickExample {
The above requires following files
-* `systemtest/src/test/resources/hivemall/HogeTest/init/color.tsv` (`systemtest/src/test/resources/${path/to/package}/${className}/init/${fileName}`)
+* `systemtest/src/test/resources/hivemall/QuickExample/init/color.tsv` (`systemtest/src/test/resources/${path/to/package}/${className}/init/${fileName}`)
```tsv
blue 0 0 255
@@ -190,7 +190,7 @@ red 255 0 0
pink 255 192 203
```
-* `systemtest/src/test/resources/hivemall/HogeTest/case/test3` (`systemtest/src/test/resources/${path/to/package}/${className}/case/${fileName}`)
+* `systemtest/src/test/resources/hivemall/QuickExample/case/test3` (`systemtest/src/test/resources/${path/to/package}/${className}/case/${fileName}`)
```sql
-- write your hive queries
@@ -199,12 +199,12 @@ SELECT blue FROM color WHERE name = 'lavender';SELECT green FROM color WHERE nam
SELECT name FROM color WHERE blue = 255
```
-* `systemtest/src/test/resources/hivemall/HogeTest/answer/test3` (`systemtest/src/test/resources/${path/to/package}/${className}/answer/${fileName}`)
+* `systemtest/src/test/resources/hivemall/QuickExample/answer/test3` (`systemtest/src/test/resources/${path/to/package}/${className}/answer/${fileName}`)
tsv format is required
```tsv
-230
+250
165 69
azure blue magenta
```
[25/50] [abbrv] incubator-hivemall git commit: Merge 'master' into
'feature/feature_selection'
Posted by my...@apache.org.
Merge 'master' into 'feature/feature_selection'
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/aa7d5299
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/aa7d5299
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/aa7d5299
Branch: refs/heads/JIRA-22/pr-385
Commit: aa7d5299739349b49ef4f50cc2c1969f5cb8a78f
Parents: a1f8f95 bc8b015
Author: amaya <gi...@sapphire.in.net>
Authored: Tue Sep 27 16:02:02 2016 +0900
Committer: amaya <gi...@sapphire.in.net>
Committed: Tue Sep 27 16:02:02 2016 +0900
----------------------------------------------------------------------
README.md | 7 +-
core/pom.xml | 2 +-
.../hivemall/ensemble/ArgminKLDistanceUDAF.java | 1 +
.../main/java/hivemall/ensemble/MaxRowUDAF.java | 21 +-
.../hivemall/ensemble/MaxValueLabelUDAF.java | 1 +
.../hivemall/ensemble/bagging/VotedAvgUDAF.java | 1 +
.../ensemble/bagging/WeightVotedAvgUDAF.java | 1 +
.../main/java/hivemall/evaluation/AUCUDAF.java | 37 +-
.../evaluation/BinaryResponsesMeasures.java | 31 +-
.../java/hivemall/evaluation/FMeasureUDAF.java | 1 +
.../evaluation/GradedResponsesMeasures.java | 7 +-
.../evaluation/LogarithmicLossUDAF.java | 1 +
.../main/java/hivemall/evaluation/MAPUDAF.java | 55 +--
.../main/java/hivemall/evaluation/MRRUDAF.java | 55 +--
.../evaluation/MeanAbsoluteErrorUDAF.java | 1 +
.../evaluation/MeanSquaredErrorUDAF.java | 1 +
.../main/java/hivemall/evaluation/NDCGUDAF.java | 45 +--
.../java/hivemall/evaluation/PrecisionUDAF.java | 55 +--
.../main/java/hivemall/evaluation/R2UDAF.java | 1 +
.../java/hivemall/evaluation/RecallUDAF.java | 55 +--
.../evaluation/RootMeanSquaredErrorUDAF.java | 1 +
.../java/hivemall/fm/FMPredictGenericUDAF.java | 23 +-
.../hivemall/ftvec/binning/BuildBinsUDAF.java | 45 ++-
.../ftvec/binning/FeatureBinningUDF.java | 26 +-
.../ftvec/binning/NumericHistogram.java | 28 +-
.../ftvec/conv/ConvertToDenseModelUDAF.java | 1 +
.../hivemall/ftvec/text/TermFrequencyUDAF.java | 1 +
.../ftvec/trans/OnehotEncodingUDAF.java | 335 +++++++++++++++++++
.../smile/tools/RandomForestEnsembleUDAF.java | 1 +
.../tools/array/ArrayAvgGenericUDAF.java | 27 +-
.../java/hivemall/tools/array/ArraySumUDAF.java | 1 +
.../hivemall/tools/bits/BitsCollectUDAF.java | 23 +-
.../main/java/hivemall/tools/map/UDAFToMap.java | 23 +-
.../hivemall/tools/map/UDAFToOrderedMap.java | 6 +-
.../java/hivemall/utils/hadoop/HiveUtils.java | 9 +
.../hivemall/utils/hadoop/WritableUtils.java | 15 +
.../java/hivemall/utils/lang/Identifier.java | 38 ++-
.../hive/ql/exec/MapredContextAccessor.java | 3 +
.../ftvec/trans/TestBinarizeLabelUDTF.java | 7 +-
mixserv/pom.xml | 2 +-
nlp/pom.xml | 2 +-
.../hivemall/nlp/tokenizer/KuromojiUDFTest.java | 31 +-
pom.xml | 1 +
resources/ddl/define-all-as-permanent.hive | 3 +
resources/ddl/define-all.hive | 3 +
resources/ddl/define-udfs.td.hql | 1 +
.../org/apache/spark/sql/hive/HivemallOps.scala | 5 +-
.../apache/spark/sql/hive/HiveUdfSuite.scala | 36 ++
.../spark/sql/hive/HivemallOpsSuite.scala | 47 ++-
.../sql/catalyst/expressions/EachTopK.scala | 108 ++++++
.../org/apache/spark/sql/hive/HivemallOps.scala | 43 ++-
.../apache/spark/sql/hive/HiveUdfSuite.scala | 43 ++-
.../spark/sql/hive/HivemallOpsSuite.scala | 70 ++--
.../sql/hive/benchmark/MiscBenchmark.scala | 72 ++--
spark/spark-common/pom.xml | 2 +-
xgboost/pom.xml | 2 +-
56 files changed, 1125 insertions(+), 338 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/aa7d5299/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
----------------------------------------------------------------------
diff --cc core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
index 9272e60,91f1dfa..c752188
--- a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
+++ b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
@@@ -55,9 -55,9 +55,10 @@@ import org.apache.hadoop.hive.serde2.ob
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+ import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.BinaryObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.BooleanObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.IntObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.LongObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/aa7d5299/resources/ddl/define-all-as-permanent.hive
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/aa7d5299/resources/ddl/define-all.hive
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/aa7d5299/resources/ddl/define-udfs.td.hql
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/aa7d5299/spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/aa7d5299/spark/spark-1.6/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/aa7d5299/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/aa7d5299/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
----------------------------------------------------------------------
[23/50] [abbrv] incubator-hivemall git commit: Rename SSTChangePoint
-> SingularSpectrumTransform
Posted by my...@apache.org.
Rename SSTChangePoint -> SingularSpectrumTransform
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/bde06e09
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/bde06e09
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/bde06e09
Branch: refs/heads/JIRA-22/pr-356
Commit: bde06e0952445bf60a9aef4bca182c0afe87e250
Parents: 3ebd771
Author: Takuya Kitazawa <k....@gmail.com>
Authored: Tue Sep 27 14:06:20 2016 +0900
Committer: Takuya Kitazawa <k....@gmail.com>
Committed: Tue Sep 27 14:06:20 2016 +0900
----------------------------------------------------------------------
.../java/hivemall/anomaly/SSTChangePoint.java | 118 -----------
.../hivemall/anomaly/SSTChangePointUDF.java | 197 -------------------
.../anomaly/SingularSpectrumTransform.java | 118 +++++++++++
.../anomaly/SingularSpectrumTransformUDF.java | 197 +++++++++++++++++++
.../hivemall/anomaly/SSTChangePointTest.java | 111 -----------
.../anomaly/SingularSpectrumTransformTest.java | 111 +++++++++++
6 files changed, 426 insertions(+), 426 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/bde06e09/core/src/main/java/hivemall/anomaly/SSTChangePoint.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/anomaly/SSTChangePoint.java b/core/src/main/java/hivemall/anomaly/SSTChangePoint.java
deleted file mode 100644
index e693bd4..0000000
--- a/core/src/main/java/hivemall/anomaly/SSTChangePoint.java
+++ /dev/null
@@ -1,118 +0,0 @@
-/*
- * Hivemall: Hive scalable Machine Learning Library
- *
- * Copyright (C) 2015 Makoto YUI
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package hivemall.anomaly;
-
-import hivemall.anomaly.SSTChangePointUDF.SSTChangePointInterface;
-import hivemall.anomaly.SSTChangePointUDF.Parameters;
-import hivemall.utils.collections.DoubleRingBuffer;
-import org.apache.commons.math3.linear.MatrixUtils;
-import org.apache.commons.math3.linear.RealMatrix;
-import org.apache.commons.math3.linear.SingularValueDecomposition;
-import org.apache.hadoop.hive.ql.metadata.HiveException;
-import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
-
-import java.util.Arrays;
-
-import javax.annotation.Nonnull;
-
-final class SSTChangePoint implements SSTChangePointInterface {
-
- @Nonnull
- private final PrimitiveObjectInspector oi;
-
- @Nonnull
- private final int window;
- @Nonnull
- private final int nPastWindow;
- @Nonnull
- private final int nCurrentWindow;
- @Nonnull
- private final int pastSize;
- @Nonnull
- private final int currentSize;
- @Nonnull
- private final int currentOffset;
- @Nonnull
- private final int r;
-
- @Nonnull
- private final DoubleRingBuffer xRing;
- @Nonnull
- private final double[] xSeries;
-
- SSTChangePoint(@Nonnull Parameters params, @Nonnull PrimitiveObjectInspector oi) {
- this.oi = oi;
-
- this.window = params.w;
- this.nPastWindow = params.n;
- this.nCurrentWindow = params.m;
- this.pastSize = window + nPastWindow;
- this.currentSize = window + nCurrentWindow;
- this.currentOffset = params.g;
- this.r = params.r;
-
- // (w + n) past samples for the n-past-windows
- // (w + m) current samples for the m-current-windows, starting from offset g
- // => need to hold past (w + n + g + w + m) samples from the latest sample
- int holdSampleSize = pastSize + currentOffset + currentSize;
-
- this.xRing = new DoubleRingBuffer(holdSampleSize);
- this.xSeries = new double[holdSampleSize];
- }
-
- @Override
- public void update(@Nonnull final Object arg, @Nonnull final double[] outScores)
- throws HiveException {
- double x = PrimitiveObjectInspectorUtils.getDouble(arg, oi);
- xRing.add(x).toArray(xSeries, true /* FIFO */);
-
- // need to wait until the buffer is filled
- if (!xRing.isFull()) {
- outScores[0] = 0.d;
- } else {
- outScores[0] = computeScore();
- }
- }
-
- private double computeScore() {
- // create past trajectory matrix and find its left singular vectors
- RealMatrix H = MatrixUtils.createRealMatrix(window, nPastWindow);
- for (int i = 0; i < nPastWindow; i++) {
- H.setColumn(i, Arrays.copyOfRange(xSeries, i, i + window));
- }
- SingularValueDecomposition svdH = new SingularValueDecomposition(H);
- RealMatrix UT = svdH.getUT();
-
- // create current trajectory matrix and find its left singular vectors
- RealMatrix G = MatrixUtils.createRealMatrix(window, nCurrentWindow);
- int currentHead = pastSize + currentOffset;
- for (int i = 0; i < nCurrentWindow; i++) {
- G.setColumn(i, Arrays.copyOfRange(xSeries, currentHead + i, currentHead + i + window));
- }
- SingularValueDecomposition svdG = new SingularValueDecomposition(G);
- RealMatrix Q = svdG.getU();
-
- // find the largest singular value for the r principal components
- RealMatrix UTQ = UT.getSubMatrix(0, r - 1, 0, window - 1).multiply(Q.getSubMatrix(0, window - 1, 0, r - 1));
- SingularValueDecomposition svdUTQ = new SingularValueDecomposition(UTQ);
- double[] s = svdUTQ.getSingularValues();
-
- return 1.d - s[0];
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/bde06e09/core/src/main/java/hivemall/anomaly/SSTChangePointUDF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/anomaly/SSTChangePointUDF.java b/core/src/main/java/hivemall/anomaly/SSTChangePointUDF.java
deleted file mode 100644
index 3ab5ae8..0000000
--- a/core/src/main/java/hivemall/anomaly/SSTChangePointUDF.java
+++ /dev/null
@@ -1,197 +0,0 @@
-/*
- * Hivemall: Hive scalable Machine Learning Library
- *
- * Copyright (C) 2015 Makoto YUI
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package hivemall.anomaly;
-
-import hivemall.UDFWithOptions;
-import hivemall.utils.collections.DoubleRingBuffer;
-import hivemall.utils.hadoop.HiveUtils;
-import hivemall.utils.lang.Preconditions;
-import hivemall.utils.lang.Primitives;
-import org.apache.commons.cli.CommandLine;
-import org.apache.commons.cli.Options;
-import org.apache.hadoop.hive.ql.exec.Description;
-import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
-import org.apache.hadoop.hive.ql.metadata.HiveException;
-import org.apache.hadoop.hive.serde2.io.DoubleWritable;
-import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
-import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
-import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
-import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
-import org.apache.hadoop.io.BooleanWritable;
-
-import javax.annotation.Nonnull;
-import javax.annotation.Nullable;
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.Arrays;
-
-@Description(
- name = "sst_changepoint",
- value = "_FUNC_(double|array<double> x [, const string options])"
- + " - Returns change-point scores and decisions using Singular Spectrum Transformation (SST)."
- + " It will return a tuple <double changepoint_score [, boolean is_changepoint]>")
-public final class SSTChangePointUDF extends UDFWithOptions {
-
- private transient Parameters _params;
- private transient SSTChangePoint _sst;
-
- private transient double[] _scores;
- private transient Object[] _result;
- private transient DoubleWritable _changepointScore;
- @Nullable
- private transient BooleanWritable _isChangepoint = null;
-
- public SSTChangePointUDF() {}
-
- // Visible for testing
- Parameters getParameters() {
- return _params;
- }
-
- @Override
- protected Options getOptions() {
- Options opts = new Options();
- opts.addOption("w", "window", true, "Number of samples which affects change-point score [default: 30]");
- opts.addOption("n", "n_past", true,
- "Number of past windows for change-point scoring [default: equal to `w` = 30]");
- opts.addOption("m", "n_current", true,
- "Number of current windows for change-point scoring [default: equal to `w` = 30]");
- opts.addOption("g", "current_offset", true,
- "Offset of the current windows from the updating sample [default: `-w` = -30]");
- opts.addOption("r", "n_component", true,
- "Number of singular vectors (i.e. principal components) [default: 3]");
- opts.addOption("k", "n_dim", true,
- "Number of dimensions for the Krylov subspaces [default: 5 (`2*r` if `r` is even, `2*r-1` otherwise)]");
- opts.addOption("th", "threshold", true,
- "Score threshold (inclusive) for determining change-point existence [default: -1, do not output decision]");
- return opts;
- }
-
- @Override
- protected CommandLine processOptions(String optionValues) throws UDFArgumentException {
- CommandLine cl = parseOptions(optionValues);
-
- this._params.w = Primitives.parseInt(cl.getOptionValue("w"), _params.w);
- this._params.n = Primitives.parseInt(cl.getOptionValue("n"), _params.w);
- this._params.m = Primitives.parseInt(cl.getOptionValue("m"), _params.w);
- this._params.g = Primitives.parseInt(cl.getOptionValue("g"), -1 * _params.w);
- this._params.r = Primitives.parseInt(cl.getOptionValue("r"), _params.r);
- this._params.k = Primitives.parseInt(
- cl.getOptionValue("k"), (_params.r % 2 == 0) ? (2 * _params.r) : (2 * _params.r - 1));
- this._params.changepointThreshold = Primitives.parseDouble(
- cl.getOptionValue("th"), _params.changepointThreshold);
-
- Preconditions.checkArgument(_params.w >= 2, "w must be greater than 1: " + _params.w);
- Preconditions.checkArgument(_params.r >= 1, "r must be greater than 0: " + _params.r);
- Preconditions.checkArgument(_params.k >= 1, "k must be greater than 0: " + _params.k);
- Preconditions.checkArgument(_params.changepointThreshold > 0.d && _params.changepointThreshold < 1.d,
- "changepointThreshold must be in range (0, 1): " + _params.changepointThreshold);
-
- return cl;
- }
-
- @Override
- public ObjectInspector initialize(@Nonnull ObjectInspector[] argOIs)
- throws UDFArgumentException {
- if (argOIs.length < 1 || argOIs.length > 2) {
- throw new UDFArgumentException(
- "_FUNC_(double|array<double> x [, const string options]) takes 1 or 2 arguments: "
- + Arrays.toString(argOIs));
- }
-
- this._params = new Parameters();
- if (argOIs.length == 2) {
- String options = HiveUtils.getConstString(argOIs[1]);
- processOptions(options);
- }
-
- ObjectInspector argOI0 = argOIs[0];
- PrimitiveObjectInspector xOI = HiveUtils.asDoubleCompatibleOI(argOI0);
- this._sst = new SSTChangePoint(_params, xOI);
-
- this._scores = new double[1];
-
- final Object[] result;
- final ArrayList<String> fieldNames = new ArrayList<String>();
- final ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
- fieldNames.add("changepoint_score");
- fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
- if (_params.changepointThreshold != -1d) {
- fieldNames.add("is_changepoint");
- fieldOIs.add(PrimitiveObjectInspectorFactory.writableBooleanObjectInspector);
- result = new Object[2];
- this._isChangepoint = new BooleanWritable(false);
- result[1] = _isChangepoint;
- } else {
- result = new Object[1];
- }
- this._changepointScore = new DoubleWritable(0.d);
- result[0] = _changepointScore;
- this._result = result;
-
- return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
- }
-
- @Override
- public Object[] evaluate(@Nonnull DeferredObject[] args) throws HiveException {
- Object x = args[0].get();
- if (x == null) {
- return _result;
- }
-
- _sst.update(x, _scores);
-
- double changepointScore = _scores[0];
- _changepointScore.set(changepointScore);
- if (_isChangepoint != null) {
- _isChangepoint.set(changepointScore >= _params.changepointThreshold);
- }
-
- return _result;
- }
-
- @Override
- public void close() throws IOException {
- this._result = null;
- this._changepointScore = null;
- this._isChangepoint = null;
- }
-
- @Override
- public String getDisplayString(String[] children) {
- return "sst(" + Arrays.toString(children) + ")";
- }
-
- static final class Parameters {
- int w = 30;
- int n = 30;
- int m = 30;
- int g = -30;
- int r = 3;
- int k = 5;
- double changepointThreshold = -1.d;
-
- Parameters() {}
- }
-
- public interface SSTChangePointInterface {
- void update(@Nonnull Object arg, @Nonnull double[] outScores) throws HiveException;
- }
-
-}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/bde06e09/core/src/main/java/hivemall/anomaly/SingularSpectrumTransform.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/anomaly/SingularSpectrumTransform.java b/core/src/main/java/hivemall/anomaly/SingularSpectrumTransform.java
new file mode 100644
index 0000000..c964129
--- /dev/null
+++ b/core/src/main/java/hivemall/anomaly/SingularSpectrumTransform.java
@@ -0,0 +1,118 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2015 Makoto YUI
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package hivemall.anomaly;
+
+import hivemall.anomaly.SingularSpectrumTransformUDF.SingularSpectrumTransformInterface;
+import hivemall.anomaly.SingularSpectrumTransformUDF.Parameters;
+import hivemall.utils.collections.DoubleRingBuffer;
+import org.apache.commons.math3.linear.MatrixUtils;
+import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.commons.math3.linear.SingularValueDecomposition;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
+
+import java.util.Arrays;
+
+import javax.annotation.Nonnull;
+
+final class SingularSpectrumTransform implements SingularSpectrumTransformInterface {
+
+ @Nonnull
+ private final PrimitiveObjectInspector oi;
+
+ @Nonnull
+ private final int window;
+ @Nonnull
+ private final int nPastWindow;
+ @Nonnull
+ private final int nCurrentWindow;
+ @Nonnull
+ private final int pastSize;
+ @Nonnull
+ private final int currentSize;
+ @Nonnull
+ private final int currentOffset;
+ @Nonnull
+ private final int r;
+
+ @Nonnull
+ private final DoubleRingBuffer xRing;
+ @Nonnull
+ private final double[] xSeries;
+
+ SingularSpectrumTransform(@Nonnull Parameters params, @Nonnull PrimitiveObjectInspector oi) {
+ this.oi = oi;
+
+ this.window = params.w;
+ this.nPastWindow = params.n;
+ this.nCurrentWindow = params.m;
+ this.pastSize = window + nPastWindow;
+ this.currentSize = window + nCurrentWindow;
+ this.currentOffset = params.g;
+ this.r = params.r;
+
+ // (w + n) past samples for the n-past-windows
+ // (w + m) current samples for the m-current-windows, starting from offset g
+ // => need to hold past (w + n + g + w + m) samples from the latest sample
+ int holdSampleSize = pastSize + currentOffset + currentSize;
+
+ this.xRing = new DoubleRingBuffer(holdSampleSize);
+ this.xSeries = new double[holdSampleSize];
+ }
+
+ @Override
+ public void update(@Nonnull final Object arg, @Nonnull final double[] outScores)
+ throws HiveException {
+ double x = PrimitiveObjectInspectorUtils.getDouble(arg, oi);
+ xRing.add(x).toArray(xSeries, true /* FIFO */);
+
+ // need to wait until the buffer is filled
+ if (!xRing.isFull()) {
+ outScores[0] = 0.d;
+ } else {
+ outScores[0] = computeScore();
+ }
+ }
+
+ private double computeScore() {
+ // create past trajectory matrix and find its left singular vectors
+ RealMatrix H = MatrixUtils.createRealMatrix(window, nPastWindow);
+ for (int i = 0; i < nPastWindow; i++) {
+ H.setColumn(i, Arrays.copyOfRange(xSeries, i, i + window));
+ }
+ SingularValueDecomposition svdH = new SingularValueDecomposition(H);
+ RealMatrix UT = svdH.getUT();
+
+ // create current trajectory matrix and find its left singular vectors
+ RealMatrix G = MatrixUtils.createRealMatrix(window, nCurrentWindow);
+ int currentHead = pastSize + currentOffset;
+ for (int i = 0; i < nCurrentWindow; i++) {
+ G.setColumn(i, Arrays.copyOfRange(xSeries, currentHead + i, currentHead + i + window));
+ }
+ SingularValueDecomposition svdG = new SingularValueDecomposition(G);
+ RealMatrix Q = svdG.getU();
+
+ // find the largest singular value for the r principal components
+ RealMatrix UTQ = UT.getSubMatrix(0, r - 1, 0, window - 1).multiply(Q.getSubMatrix(0, window - 1, 0, r - 1));
+ SingularValueDecomposition svdUTQ = new SingularValueDecomposition(UTQ);
+ double[] s = svdUTQ.getSingularValues();
+
+ return 1.d - s[0];
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/bde06e09/core/src/main/java/hivemall/anomaly/SingularSpectrumTransformUDF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/anomaly/SingularSpectrumTransformUDF.java b/core/src/main/java/hivemall/anomaly/SingularSpectrumTransformUDF.java
new file mode 100644
index 0000000..2ec0a91
--- /dev/null
+++ b/core/src/main/java/hivemall/anomaly/SingularSpectrumTransformUDF.java
@@ -0,0 +1,197 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2015 Makoto YUI
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package hivemall.anomaly;
+
+import hivemall.UDFWithOptions;
+import hivemall.utils.collections.DoubleRingBuffer;
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.lang.Preconditions;
+import hivemall.utils.lang.Primitives;
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.Options;
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.serde2.io.DoubleWritable;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.io.BooleanWritable;
+
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+
+@Description(
+ name = "sst",
+ value = "_FUNC_(double|array<double> x [, const string options])"
+ + " - Returns change-point scores and decisions using Singular Spectrum Transformation (SST)."
+ + " It will return a tuple <double changepoint_score [, boolean is_changepoint]>")
+public final class SingularSpectrumTransformUDF extends UDFWithOptions {
+
+ private transient Parameters _params;
+ private transient SingularSpectrumTransform _sst;
+
+ private transient double[] _scores;
+ private transient Object[] _result;
+ private transient DoubleWritable _changepointScore;
+ @Nullable
+ private transient BooleanWritable _isChangepoint = null;
+
+ public SingularSpectrumTransformUDF() {}
+
+ // Visible for testing
+ Parameters getParameters() {
+ return _params;
+ }
+
+ @Override
+ protected Options getOptions() {
+ Options opts = new Options();
+ opts.addOption("w", "window", true, "Number of samples which affects change-point score [default: 30]");
+ opts.addOption("n", "n_past", true,
+ "Number of past windows for change-point scoring [default: equal to `w` = 30]");
+ opts.addOption("m", "n_current", true,
+ "Number of current windows for change-point scoring [default: equal to `w` = 30]");
+ opts.addOption("g", "current_offset", true,
+ "Offset of the current windows from the updating sample [default: `-w` = -30]");
+ opts.addOption("r", "n_component", true,
+ "Number of singular vectors (i.e. principal components) [default: 3]");
+ opts.addOption("k", "n_dim", true,
+ "Number of dimensions for the Krylov subspaces [default: 5 (`2*r` if `r` is even, `2*r-1` otherwise)]");
+ opts.addOption("th", "threshold", true,
+ "Score threshold (inclusive) for determining change-point existence [default: -1, do not output decision]");
+ return opts;
+ }
+
+ @Override
+ protected CommandLine processOptions(String optionValues) throws UDFArgumentException {
+ CommandLine cl = parseOptions(optionValues);
+
+ this._params.w = Primitives.parseInt(cl.getOptionValue("w"), _params.w);
+ this._params.n = Primitives.parseInt(cl.getOptionValue("n"), _params.w);
+ this._params.m = Primitives.parseInt(cl.getOptionValue("m"), _params.w);
+ this._params.g = Primitives.parseInt(cl.getOptionValue("g"), -1 * _params.w);
+ this._params.r = Primitives.parseInt(cl.getOptionValue("r"), _params.r);
+ this._params.k = Primitives.parseInt(
+ cl.getOptionValue("k"), (_params.r % 2 == 0) ? (2 * _params.r) : (2 * _params.r - 1));
+ this._params.changepointThreshold = Primitives.parseDouble(
+ cl.getOptionValue("th"), _params.changepointThreshold);
+
+ Preconditions.checkArgument(_params.w >= 2, "w must be greater than 1: " + _params.w);
+ Preconditions.checkArgument(_params.r >= 1, "r must be greater than 0: " + _params.r);
+ Preconditions.checkArgument(_params.k >= 1, "k must be greater than 0: " + _params.k);
+ if (_params.changepointThreshold != -1d) {
+ Preconditions.checkArgument(_params.changepointThreshold > 0.d && _params.changepointThreshold < 1.d,
+ "changepointThreshold must be in range (0, 1): " + _params.changepointThreshold);
+ }
+
+ return cl;
+ }
+
+ @Override
+ public ObjectInspector initialize(@Nonnull ObjectInspector[] argOIs)
+ throws UDFArgumentException {
+ if (argOIs.length < 1 || argOIs.length > 2) {
+ throw new UDFArgumentException(
+ "_FUNC_(double|array<double> x [, const string options]) takes 1 or 2 arguments: "
+ + Arrays.toString(argOIs));
+ }
+
+ this._params = new Parameters();
+ if (argOIs.length == 2) {
+ String options = HiveUtils.getConstString(argOIs[1]);
+ processOptions(options);
+ }
+
+ ObjectInspector argOI0 = argOIs[0];
+ PrimitiveObjectInspector xOI = HiveUtils.asDoubleCompatibleOI(argOI0);
+ this._sst = new SingularSpectrumTransform(_params, xOI);
+
+ this._scores = new double[1];
+
+ final Object[] result;
+ final ArrayList<String> fieldNames = new ArrayList<String>();
+ final ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
+ fieldNames.add("changepoint_score");
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
+ if (_params.changepointThreshold != -1d) {
+ fieldNames.add("is_changepoint");
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableBooleanObjectInspector);
+ result = new Object[2];
+ this._isChangepoint = new BooleanWritable(false);
+ result[1] = _isChangepoint;
+ } else {
+ result = new Object[1];
+ }
+ this._changepointScore = new DoubleWritable(0.d);
+ result[0] = _changepointScore;
+ this._result = result;
+
+ return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
+ }
+
+ @Override
+ public Object[] evaluate(@Nonnull DeferredObject[] args) throws HiveException {
+ Object x = args[0].get();
+ if (x == null) {
+ return _result;
+ }
+
+ _sst.update(x, _scores);
+
+ double changepointScore = _scores[0];
+ _changepointScore.set(changepointScore);
+ if (_isChangepoint != null) {
+ _isChangepoint.set(changepointScore >= _params.changepointThreshold);
+ }
+
+ return _result;
+ }
+
+ @Override
+ public void close() throws IOException {
+ this._result = null;
+ this._changepointScore = null;
+ this._isChangepoint = null;
+ }
+
+ @Override
+ public String getDisplayString(String[] children) {
+ return "sst(" + Arrays.toString(children) + ")";
+ }
+
+ static final class Parameters {
+ int w = 30;
+ int n = 30;
+ int m = 30;
+ int g = -30;
+ int r = 3;
+ int k = 5;
+ double changepointThreshold = -1.d;
+
+ Parameters() {}
+ }
+
+ public interface SingularSpectrumTransformInterface {
+ void update(@Nonnull Object arg, @Nonnull double[] outScores) throws HiveException;
+ }
+
+}
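For reference, a minimal usage sketch of the UDF above, assuming it is registered under the name `sst` given in its @Description; the table `timeseries` and column `x` below are illustrative only and not part of this commit:

  -- feed samples in time order; without a threshold option the output struct
  -- carries only the score
  SELECT sst(x) AS r
  FROM timeseries;
  -- r.changepoint_score : double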
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/bde06e09/core/src/test/java/hivemall/anomaly/SSTChangePointTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/anomaly/SSTChangePointTest.java b/core/src/test/java/hivemall/anomaly/SSTChangePointTest.java
deleted file mode 100644
index b41d474..0000000
--- a/core/src/test/java/hivemall/anomaly/SSTChangePointTest.java
+++ /dev/null
@@ -1,111 +0,0 @@
-/*
- * Hivemall: Hive scalable Machine Learning Library
- *
- * Copyright (C) 2015 Makoto YUI
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package hivemall.anomaly;
-
-import hivemall.anomaly.SSTChangePointUDF.Parameters;
-
-import java.io.BufferedReader;
-import java.io.IOException;
-import java.io.InputStream;
-import java.io.InputStreamReader;
-import java.util.zip.GZIPInputStream;
-
-import javax.annotation.Nonnull;
-
-import org.apache.hadoop.hive.ql.metadata.HiveException;
-import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
-import org.junit.Assert;
-import org.junit.Test;
-
-public class SSTChangePointTest {
- private static final boolean DEBUG = false;
-
- @Test
- public void testSST() throws IOException, HiveException {
- Parameters params = new Parameters();
- PrimitiveObjectInspector oi = PrimitiveObjectInspectorFactory.javaDoubleObjectInspector;
- SSTChangePoint sst = new SSTChangePoint(params, oi);
- double[] outScores = new double[1];
-
- BufferedReader reader = readFile("cf1d.csv");
- println("x change");
- String line;
- int numChangepoints = 0;
- while ((line = reader.readLine()) != null) {
- double x = Double.parseDouble(line);
- sst.update(x, outScores);
- printf("%f %f%n", x, outScores[0]);
- if (outScores[0] > 0.95d) {
- numChangepoints++;
- }
- }
- Assert.assertTrue("#changepoints SHOULD be greater than 0: " + numChangepoints,
- numChangepoints > 0);
- Assert.assertTrue("#changepoints SHOULD be less than 5: " + numChangepoints,
- numChangepoints < 5);
- }
-
- @Test
- public void testTwitterData() throws IOException, HiveException {
- Parameters params = new Parameters();
- PrimitiveObjectInspector oi = PrimitiveObjectInspectorFactory.javaDoubleObjectInspector;
- SSTChangePoint sst = new SSTChangePoint(params, oi);
- double[] outScores = new double[1];
-
- BufferedReader reader = readFile("twitter.csv.gz");
- println("# time x change");
- String line;
- int i = 1, numChangepoints = 0;
- while ((line = reader.readLine()) != null) {
- double x = Double.parseDouble(line);
- sst.update(x, outScores);
- printf("%d %f %f%n", i, x, outScores[0]);
- if (outScores[0] > 0.005d) {
- numChangepoints++;
- }
- i++;
- }
- Assert.assertTrue("#changepoints SHOULD be greater than 0: " + numChangepoints,
- numChangepoints > 0);
- Assert.assertTrue("#changepoints SHOULD be less than 5: " + numChangepoints,
- numChangepoints < 5);
- }
-
- private static void println(String msg) {
- if (DEBUG) {
- System.out.println(msg);
- }
- }
-
- private static void printf(String format, Object... args) {
- if (DEBUG) {
- System.out.printf(format, args);
- }
- }
-
- @Nonnull
- private static BufferedReader readFile(@Nonnull String fileName) throws IOException {
- InputStream is = SSTChangePointTest.class.getResourceAsStream(fileName);
- if (fileName.endsWith(".gz")) {
- is = new GZIPInputStream(is);
- }
- return new BufferedReader(new InputStreamReader(is));
- }
-
-}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/bde06e09/core/src/test/java/hivemall/anomaly/SingularSpectrumTransformTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/anomaly/SingularSpectrumTransformTest.java b/core/src/test/java/hivemall/anomaly/SingularSpectrumTransformTest.java
new file mode 100644
index 0000000..d4f119f
--- /dev/null
+++ b/core/src/test/java/hivemall/anomaly/SingularSpectrumTransformTest.java
@@ -0,0 +1,111 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2015 Makoto YUI
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package hivemall.anomaly;
+
+import hivemall.anomaly.SingularSpectrumTransformUDF.Parameters;
+
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.util.zip.GZIPInputStream;
+
+import javax.annotation.Nonnull;
+
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.junit.Assert;
+import org.junit.Test;
+
+public class SingularSpectrumTransformTest {
+ private static final boolean DEBUG = false;
+
+ @Test
+ public void testSST() throws IOException, HiveException {
+ Parameters params = new Parameters();
+ PrimitiveObjectInspector oi = PrimitiveObjectInspectorFactory.javaDoubleObjectInspector;
+ SingularSpectrumTransform sst = new SingularSpectrumTransform(params, oi);
+ double[] outScores = new double[1];
+
+ BufferedReader reader = readFile("cf1d.csv");
+ println("x change");
+ String line;
+ int numChangepoints = 0;
+ while ((line = reader.readLine()) != null) {
+ double x = Double.parseDouble(line);
+ sst.update(x, outScores);
+ printf("%f %f%n", x, outScores[0]);
+ if (outScores[0] > 0.95d) {
+ numChangepoints++;
+ }
+ }
+ Assert.assertTrue("#changepoints SHOULD be greater than 0: " + numChangepoints,
+ numChangepoints > 0);
+ Assert.assertTrue("#changepoints SHOULD be less than 5: " + numChangepoints,
+ numChangepoints < 5);
+ }
+
+ @Test
+ public void testTwitterData() throws IOException, HiveException {
+ Parameters params = new Parameters();
+ PrimitiveObjectInspector oi = PrimitiveObjectInspectorFactory.javaDoubleObjectInspector;
+ SingularSpectrumTransform sst = new SingularSpectrumTransform(params, oi);
+ double[] outScores = new double[1];
+
+ BufferedReader reader = readFile("twitter.csv.gz");
+ println("# time x change");
+ String line;
+ int i = 1, numChangepoints = 0;
+ while ((line = reader.readLine()) != null) {
+ double x = Double.parseDouble(line);
+ sst.update(x, outScores);
+ printf("%d %f %f%n", i, x, outScores[0]);
+ if (outScores[0] > 0.005d) {
+ numChangepoints++;
+ }
+ i++;
+ }
+ Assert.assertTrue("#changepoints SHOULD be greater than 0: " + numChangepoints,
+ numChangepoints > 0);
+ Assert.assertTrue("#changepoints SHOULD be less than 5: " + numChangepoints,
+ numChangepoints < 5);
+ }
+
+ private static void println(String msg) {
+ if (DEBUG) {
+ System.out.println(msg);
+ }
+ }
+
+ private static void printf(String format, Object... args) {
+ if (DEBUG) {
+ System.out.printf(format, args);
+ }
+ }
+
+ @Nonnull
+ private static BufferedReader readFile(@Nonnull String fileName) throws IOException {
+ InputStream is = SingularSpectrumTransformTest.class.getResourceAsStream(fileName);
+ if (fileName.endsWith(".gz")) {
+ is = new GZIPInputStream(is);
+ }
+ return new BufferedReader(new InputStreamReader(is));
+ }
+
+}
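Complementing the tests above, a hedged sketch of the decision-mode output, using the same 0.005 cutoff that testTwitterData applies; table and column names are again illustrative:

  -- passing the -th option adds a boolean decision to the output struct
  SELECT sst(x, "-th 0.005") AS r
  FROM timeseries;
  -- r.changepoint_score : double
  -- r.is_changepoint    : boolean (true when the score is >= 0.005)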
[33/50] [abbrv] incubator-hivemall git commit: change method of
testing for spark
Posted by my...@apache.org.
change method of testing for spark
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/ce4a4898
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/ce4a4898
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/ce4a4898
Branch: refs/heads/JIRA-22/pr-385
Commit: ce4a48980e33b9f16c74a62fcea6878f28b9c08b
Parents: 8d9f0d4
Author: amaya <gi...@sapphire.in.net>
Authored: Fri Sep 30 17:05:20 2016 +0900
Committer: amaya <gi...@sapphire.in.net>
Committed: Fri Sep 30 17:05:20 2016 +0900
----------------------------------------------------------------------
.../spark/sql/hive/HivemallOpsSuite.scala | 23 ++++++++++----------
.../spark/sql/hive/HivemallOpsSuite.scala | 17 ++++++---------
2 files changed, 18 insertions(+), 22 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ce4a4898/spark/spark-1.6/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
----------------------------------------------------------------------
diff --git a/spark/spark-1.6/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala b/spark/spark-1.6/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
index cce22ce..c7016c0 100644
--- a/spark/spark-1.6/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
+++ b/spark/spark-1.6/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
@@ -24,7 +24,6 @@ import org.apache.spark.sql.{Column, Row}
import org.apache.spark.test.HivemallQueryTest
import org.apache.spark.test.TestDoubleWrapper._
import org.apache.spark.test.TestUtils._
-import org.scalatest.Matchers._
final class HivemallOpsSuite extends HivemallQueryTest {
@@ -189,7 +188,6 @@ final class HivemallOpsSuite extends HivemallQueryTest {
test("ftvec.selection - chi2") {
import hiveContext.implicits._
- implicit val doubleEquality = org.scalactic.TolerantNumerics.tolerantDoubleEquality(1e-5)
// see also hivemall.ftvec.selection.ChiSquareUDFTest
val df = Seq(
@@ -204,17 +202,17 @@ final class HivemallOpsSuite extends HivemallQueryTest {
.toDF("arg0", "arg1")
val result = df.select(chi2(df("arg0"), df("arg1"))).collect
- result should have length 1
+ assert(result.length == 1)
val chi2Val = result.head.getAs[Row](0).getAs[Seq[Double]](0)
val pVal = result.head.getAs[Row](0).getAs[Seq[Double]](1)
(chi2Val, Seq(10.81782088, 3.59449902, 116.16984746, 67.24482759))
.zipped
- .foreach((actual, expected) => actual shouldEqual expected)
+ .foreach((actual, expected) => assert(actual ~== expected))
(pVal, Seq(4.47651499e-03, 1.65754167e-01, 5.94344354e-26, 2.50017968e-15))
.zipped
- .foreach((actual, expected) => actual shouldEqual expected)
+ .foreach((actual, expected) => assert(actual ~== expected))
}
test("ftvec.conv - quantify") {
@@ -370,8 +368,9 @@ final class HivemallOpsSuite extends HivemallQueryTest {
val data = Seq(Seq(0, 1, 3), Seq(2, 4, 1), Seq(5, 4, 9))
val df = data.map(d => (d, Seq(3, 1, 2), 2)).toDF("features", "importance_list", "k")
- df.select(select_k_best(df("features"), df("importance_list"), df("k"))).collect shouldEqual
- data.map(s => Row(Seq(s(0).toDouble, s(2).toDouble)))
+ // checkAnswer fails here for some reason (possibly a type issue), but it works on spark-2.0
+ assert(df.select(select_k_best(df("features"), df("importance_list"), df("k"))).collect ===
+ data.map(s => Row(Seq(s(0).toDouble, s(2).toDouble))))
}
test("misc - sigmoid") {
@@ -573,7 +572,6 @@ final class HivemallOpsSuite extends HivemallQueryTest {
test("user-defined aggregators for ftvec.selection") {
import hiveContext.implicits._
- implicit val doubleEquality = org.scalactic.TolerantNumerics.tolerantDoubleEquality(1e-5)
// see also hivemall.ftvec.selection.SignalNoiseRatioUDAFTest
// binary class
@@ -595,7 +593,7 @@ final class HivemallOpsSuite extends HivemallQueryTest {
val row0 = df0.groupby($"c0").snr("arg0", "arg1").collect
(row0(0).getAs[Seq[Double]](1), Seq(4.38425236, 0.26390002, 15.83984511, 26.87005769))
.zipped
- .foreach((actual, expected) => actual shouldEqual expected)
+ .foreach((actual, expected) => assert(actual ~== expected))
// multiple class
// +-----------------+-------+
@@ -616,7 +614,7 @@ final class HivemallOpsSuite extends HivemallQueryTest {
val row1 = df1.groupby($"c0").snr("arg0", "arg1").collect
(row1(0).getAs[Seq[Double]](1), Seq(8.43181818, 1.32121212, 42.94949495, 33.80952381))
.zipped
- .foreach((actual, expected) => actual shouldEqual expected)
+ .foreach((actual, expected) => assert(actual ~== expected))
}
test("user-defined aggregators for tools.matrix") {
@@ -627,7 +625,8 @@ final class HivemallOpsSuite extends HivemallQueryTest {
val df0 = Seq((1, Seq(1, 2, 3), Seq(5, 6, 7)), (1, Seq(3, 4, 5), Seq(7, 8, 9)))
.toDF("c0", "arg0", "arg1")
- df0.groupby($"c0").transpose_and_dot("arg0", "arg1").collect() shouldEqual
- Seq(Row(1, Seq(Seq(26.0, 30.0, 34.0), Seq(38.0, 44.0, 50.0), Seq(50.0, 58.0, 66.0))))
+ // checkAnswer fails here for some reason (possibly a type issue), but it works on spark-2.0
+ assert(df0.groupby($"c0").transpose_and_dot("arg0", "arg1").collect() ===
+ Seq(Row(1, Seq(Seq(26.0, 30.0, 34.0), Seq(38.0, 44.0, 50.0), Seq(50.0, 58.0, 66.0)))))
}
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ce4a4898/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
----------------------------------------------------------------------
diff --git a/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala b/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
index fe73a1b..8446677 100644
--- a/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
+++ b/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
@@ -23,7 +23,6 @@ import org.apache.spark.sql.types._
import org.apache.spark.sql.{AnalysisException, Column, Row, functions}
import org.apache.spark.test.TestDoubleWrapper._
import org.apache.spark.test.{HivemallFeatureQueryTest, TestUtils, VectorQueryTest}
-import org.scalatest.Matchers._
final class HivemallOpsWithFeatureSuite extends HivemallFeatureQueryTest {
@@ -188,7 +187,6 @@ final class HivemallOpsWithFeatureSuite extends HivemallFeatureQueryTest {
test("ftvec.selection - chi2") {
import hiveContext.implicits._
- implicit val doubleEquality = org.scalactic.TolerantNumerics.tolerantDoubleEquality(1e-5)
// see also hivemall.ftvec.selection.ChiSquareUDFTest
val df = Seq(
@@ -203,17 +201,17 @@ final class HivemallOpsWithFeatureSuite extends HivemallFeatureQueryTest {
.toDF("arg0", "arg1")
val result = df.select(chi2(df("arg0"), df("arg1"))).collect
- result should have length 1
+ assert(result.length == 1)
val chi2Val = result.head.getAs[Row](0).getAs[Seq[Double]](0)
val pVal = result.head.getAs[Row](0).getAs[Seq[Double]](1)
(chi2Val, Seq(10.81782088, 3.59449902, 116.16984746, 67.24482759))
.zipped
- .foreach((actual, expected) => actual shouldEqual expected)
+ .foreach((actual, expected) => assert(actual ~== expected))
(pVal, Seq(4.47651499e-03, 1.65754167e-01, 5.94344354e-26, 2.50017968e-15))
.zipped
- .foreach((actual, expected) => actual shouldEqual expected)
+ .foreach((actual, expected) => assert(actual ~== expected))
}
test("ftvec.conv - quantify") {
@@ -393,8 +391,8 @@ final class HivemallOpsWithFeatureSuite extends HivemallFeatureQueryTest {
val data = Seq(Seq(0, 1, 3), Seq(2, 4, 1), Seq(5, 4, 9))
val df = data.map(d => (d, Seq(3, 1, 2), 2)).toDF("features", "importance_list", "k")
- df.select(select_k_best(df("features"), df("importance_list"), df("k"))).collect shouldEqual
- data.map(s => Row(Seq(s(0).toDouble, s(2).toDouble)))
+ checkAnswer(df.select(select_k_best(df("features"), df("importance_list"), df("k"))),
+ data.map(s => Row(Seq(s(0).toDouble, s(2).toDouble))))
}
test("misc - sigmoid") {
@@ -689,7 +687,6 @@ final class HivemallOpsWithFeatureSuite extends HivemallFeatureQueryTest {
test("user-defined aggregators for ftvec.selection") {
import hiveContext.implicits._
- implicit val doubleEquality = org.scalactic.TolerantNumerics.tolerantDoubleEquality(1e-5)
// see also hivemall.ftvec.selection.SignalNoiseRatioUDAFTest
// binary class
@@ -711,7 +708,7 @@ final class HivemallOpsWithFeatureSuite extends HivemallFeatureQueryTest {
val row0 = df0.groupby($"c0").snr("arg0", "arg1").collect
(row0(0).getAs[Seq[Double]](1), Seq(4.38425236, 0.26390002, 15.83984511, 26.87005769))
.zipped
- .foreach((actual, expected) => actual shouldEqual expected)
+ .foreach((actual, expected) => assert(actual ~== expected))
// multiple class
// +-----------------+-------+
@@ -732,7 +729,7 @@ final class HivemallOpsWithFeatureSuite extends HivemallFeatureQueryTest {
val row1 = df1.groupby($"c0").snr("arg0", "arg1").collect
(row1(0).getAs[Seq[Double]](1), Seq(8.43181818, 1.32121212, 42.94949495, 33.80952381))
.zipped
- .foreach((actual, expected) => actual shouldEqual expected)
+ .foreach((actual, expected) => assert(actual ~== expected))
}
test("user-defined aggregators for tools.matrix") {
[10/50] [abbrv] incubator-hivemall git commit: add ddl definitions
Posted by my...@apache.org.
add ddl definitions
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/be1ea37a
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/be1ea37a
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/be1ea37a
Branch: refs/heads/JIRA-22/pr-385
Commit: be1ea37a0f5048cde4284107c04e109f0f526b42
Parents: ad81b3a
Author: amaya <gi...@sapphire.in.net>
Authored: Tue Sep 20 18:00:49 2016 +0900
Committer: amaya <gi...@sapphire.in.net>
Committed: Tue Sep 20 18:38:01 2016 +0900
----------------------------------------------------------------------
resources/ddl/define-all-as-permanent.hive | 20 ++++++++++++++++++++
resources/ddl/define-all.hive | 20 ++++++++++++++++++++
resources/ddl/define-all.spark | 20 ++++++++++++++++++++
resources/ddl/define-udfs.td.hql | 4 ++++
4 files changed, 64 insertions(+)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/be1ea37a/resources/ddl/define-all-as-permanent.hive
----------------------------------------------------------------------
diff --git a/resources/ddl/define-all-as-permanent.hive b/resources/ddl/define-all-as-permanent.hive
index bab5a29..52b73a0 100644
--- a/resources/ddl/define-all-as-permanent.hive
+++ b/resources/ddl/define-all-as-permanent.hive
@@ -202,6 +202,13 @@ CREATE FUNCTION zscore as 'hivemall.ftvec.scaling.ZScoreUDF' USING JAR '${hivema
DROP FUNCTION IF EXISTS l2_normalize;
CREATE FUNCTION l2_normalize as 'hivemall.ftvec.scaling.L2NormalizationUDF' USING JAR '${hivemall_jar}';
+-------------------------
+-- selection functions --
+-------------------------
+
+DROP FUNCTION IF EXISTS chi_square;
+CREATE FUNCTION chi_square as 'hivemall.ftvec.selection.ChiSquareUDF' USING JAR '${hivemall_jar}';
+
--------------------
-- misc functions --
--------------------
@@ -364,6 +371,9 @@ CREATE FUNCTION subarray_endwith as 'hivemall.tools.array.SubarrayEndWithUDF' US
DROP FUNCTION IF EXISTS subarray_startwith;
CREATE FUNCTION subarray_startwith as 'hivemall.tools.array.SubarrayStartWithUDF' USING JAR '${hivemall_jar}';
+DROP FUNCTION IF EXISTS subarray_by_indices;
+CREATE FUNCTION subarray_by_indices as 'hivemall.tools.array.SubarrayByIndicesUDF' USING JAR '${hivemall_jar}';
+
DROP FUNCTION IF EXISTS array_concat;
CREATE FUNCTION array_concat as 'hivemall.tools.array.ArrayConcatUDF' USING JAR '${hivemall_jar}';
@@ -380,6 +390,9 @@ CREATE FUNCTION array_avg as 'hivemall.tools.array.ArrayAvgGenericUDAF' USING JA
DROP FUNCTION IF EXISTS array_sum;
CREATE FUNCTION array_sum as 'hivemall.tools.array.ArraySumUDAF' USING JAR '${hivemall_jar}';
+DROP FUNCTION IF EXISTS array_top_k_indices;
+CREATE FUNCTION array_top_k_indices as 'hivemall.tools.array.ArrayTopKIndicesUDF' USING JAR '${hivemall_jar}';
+
DROP FUNCTION IF EXISTS to_string_array;
CREATE FUNCTION to_string_array as 'hivemall.tools.array.ToStringArrayUDF' USING JAR '${hivemall_jar}';
@@ -436,6 +449,13 @@ DROP FUNCTION IF EXISTS sigmoid;
CREATE FUNCTION sigmoid as 'hivemall.tools.math.SigmoidGenericUDF' USING JAR '${hivemall_jar}';
----------------------
+-- Matrix functions --
+----------------------
+
+DROP FUNCTION IF EXISTS transpose_and_dot;
+CREATE FUNCTION transpose_and_dot as 'hivemall.tools.matrix.TransposeAndDotUDAF' USING JAR '${hivemall_jar}';
+
+----------------------
-- mapred functions --
----------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/be1ea37a/resources/ddl/define-all.hive
----------------------------------------------------------------------
diff --git a/resources/ddl/define-all.hive b/resources/ddl/define-all.hive
index 315b4d2..a70ae0f 100644
--- a/resources/ddl/define-all.hive
+++ b/resources/ddl/define-all.hive
@@ -198,6 +198,13 @@ create temporary function zscore as 'hivemall.ftvec.scaling.ZScoreUDF';
drop temporary function l2_normalize;
create temporary function l2_normalize as 'hivemall.ftvec.scaling.L2NormalizationUDF';
+-------------------------
+-- selection functions --
+-------------------------
+
+drop temporary function chi_square;
+create temporary function chi_square as 'hivemall.ftvec.selection.ChiSquareUDF';
+
-----------------------------------
-- Feature engineering functions --
-----------------------------------
@@ -360,6 +367,9 @@ create temporary function subarray_endwith as 'hivemall.tools.array.SubarrayEndW
drop temporary function subarray_startwith;
create temporary function subarray_startwith as 'hivemall.tools.array.SubarrayStartWithUDF';
+drop temporary function subarray_by_indices;
+create temporary function subarray_by_indices as 'hivemall.tools.array.SubarrayByIndicesUDF';
+
drop temporary function array_concat;
create temporary function array_concat as 'hivemall.tools.array.ArrayConcatUDF';
@@ -376,6 +386,9 @@ create temporary function array_avg as 'hivemall.tools.array.ArrayAvgGenericUDAF
drop temporary function array_sum;
create temporary function array_sum as 'hivemall.tools.array.ArraySumUDAF';
+drop temporary function array_top_k_indices;
+create temporary function array_top_k_indices as 'hivemall.tools.array.ArrayTopKIndicesUDF';
+
drop temporary function to_string_array;
create temporary function to_string_array as 'hivemall.tools.array.ToStringArrayUDF';
@@ -432,6 +445,13 @@ drop temporary function sigmoid;
create temporary function sigmoid as 'hivemall.tools.math.SigmoidGenericUDF';
----------------------
+-- Matrix functions --
+----------------------
+
+drop temporary function transpose_and_dot;
+create temporary function transpose_and_dot as 'hivemall.tools.matrix.TransposeAndDotUDAF';
+
+----------------------
-- mapred functions --
----------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/be1ea37a/resources/ddl/define-all.spark
----------------------------------------------------------------------
diff --git a/resources/ddl/define-all.spark b/resources/ddl/define-all.spark
index 4aed65b..e009511 100644
--- a/resources/ddl/define-all.spark
+++ b/resources/ddl/define-all.spark
@@ -184,6 +184,13 @@ sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS normalize")
sqlContext.sql("CREATE TEMPORARY FUNCTION normalize AS 'hivemall.ftvec.scaling.L2NormalizationUDF'")
/**
+ * selection functions
+ */
+
+sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS chi_square")
+sqlContext.sql("CREATE TEMPORARY FUNCTION chi_square AS 'hivemall.ftvec.selection.ChiSquareUDF'")
+
+/**
* misc functions
*/
@@ -309,6 +316,9 @@ sqlContext.sql("CREATE TEMPORARY FUNCTION subarray_endwith AS 'hivemall.tools.ar
sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS subarray_startwith")
sqlContext.sql("CREATE TEMPORARY FUNCTION subarray_startwith AS 'hivemall.tools.array.SubarrayStartWithUDF'")
+sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS subarray_by_indices")
+sqlContext.sql("CREATE TEMPORARY FUNCTION subarray_by_indices AS 'hivemall.tools.array.SubarrayByIndicesUDF'")
+
sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS collect_all")
sqlContext.sql("CREATE TEMPORARY FUNCTION collect_all AS 'hivemall.tools.array.CollectAllUDAF'")
@@ -321,6 +331,9 @@ sqlContext.sql("CREATE TEMPORARY FUNCTION subarray AS 'hivemall.tools.array.Suba
sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS array_avg")
sqlContext.sql("CREATE TEMPORARY FUNCTION array_avg AS 'hivemall.tools.array.ArrayAvgGenericUDAF'")
+sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS array_top_k_indices")
+sqlContext.sql("CREATE TEMPORARY FUNCTION array_top_k_indices AS 'hivemall.tools.array.ArrayTopKIndicesUDF'")
+
/**
* compression functions
*/
@@ -355,6 +368,13 @@ sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS sigmoid")
sqlContext.sql("CREATE TEMPORARY FUNCTION sigmoid AS 'hivemall.tools.math.SigmoidGenericUDF'")
/**
+ * Matrix functions
+ */
+
+sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS transpose_and_dot")
+sqlContext.sql("CREATE TEMPORARY FUNCTION transpose_and_dot AS 'hivemall.tools.matrix.TransposeAndDotUDAF'")
+
+/**
* mapred functions
*/
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/be1ea37a/resources/ddl/define-udfs.td.hql
----------------------------------------------------------------------
diff --git a/resources/ddl/define-udfs.td.hql b/resources/ddl/define-udfs.td.hql
index 18500aa..92e4003 100644
--- a/resources/ddl/define-udfs.td.hql
+++ b/resources/ddl/define-udfs.td.hql
@@ -50,6 +50,7 @@ create temporary function powered_features as 'hivemall.ftvec.pairing.PoweredFea
create temporary function rescale as 'hivemall.ftvec.scaling.RescaleUDF';
create temporary function zscore as 'hivemall.ftvec.scaling.ZScoreUDF';
create temporary function l2_normalize as 'hivemall.ftvec.scaling.L2NormalizationUDF';
+create temporary function chi_square as 'hivemall.ftvec.selection.ChiSquareUDF';
create temporary function amplify as 'hivemall.ftvec.amplify.AmplifierUDTF';
create temporary function rand_amplify as 'hivemall.ftvec.amplify.RandomAmplifierUDTF';
create temporary function add_bias as 'hivemall.ftvec.AddBiasUDF';
@@ -94,10 +95,12 @@ create temporary function array_remove as 'hivemall.tools.array.ArrayRemoveUDF';
create temporary function sort_and_uniq_array as 'hivemall.tools.array.SortAndUniqArrayUDF';
create temporary function subarray_endwith as 'hivemall.tools.array.SubarrayEndWithUDF';
create temporary function subarray_startwith as 'hivemall.tools.array.SubarrayStartWithUDF';
+create temporary function subarray_by_indices as 'hivemall.tools.array.SubarrayByIndicesUDF';
create temporary function array_concat as 'hivemall.tools.array.ArrayConcatUDF';
create temporary function subarray as 'hivemall.tools.array.SubarrayUDF';
create temporary function array_avg as 'hivemall.tools.array.ArrayAvgGenericUDAF';
create temporary function array_sum as 'hivemall.tools.array.ArraySumUDAF';
+create temporary function array_top_k_indices as 'hivemall.tools.array.ArrayTopKIndicesUDF';
create temporary function to_string_array as 'hivemall.tools.array.ToStringArrayUDF';
create temporary function array_intersect as 'hivemall.tools.array.ArrayIntersectUDF';
create temporary function bits_collect as 'hivemall.tools.bits.BitsCollectUDAF';
@@ -111,6 +114,7 @@ create temporary function map_tail_n as 'hivemall.tools.map.MapTailNUDF';
create temporary function to_map as 'hivemall.tools.map.UDAFToMap';
create temporary function to_ordered_map as 'hivemall.tools.map.UDAFToOrderedMap';
create temporary function sigmoid as 'hivemall.tools.math.SigmoidGenericUDF';
+create temporary function transpose_and_dot as 'hivemall.tools.matrix.TransposeAndDotUDAF';
create temporary function taskid as 'hivemall.tools.mapred.TaskIdUDF';
create temporary function jobid as 'hivemall.tools.mapred.JobIdUDF';
create temporary function rowid as 'hivemall.tools.mapred.RowIdUDF';
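As a rough usage sketch for one of the functions registered above: chi_square, as exercised through the chi2 wrapper in HivemallOpsSuite earlier in this series, takes per-class observed and expected feature arrays and yields chi-square statistics and p-values, one entry per feature. The table and column names below are hypothetical:

  -- observed / expected : array<array<double>>, one inner array per class
  SELECT chi_square(observed, expected)
  FROM feature_stats;
  -- result holds two array<double> values: chi2 statistics and p-values per feature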
[02/50] [abbrv] incubator-hivemall git commit: add transpose_and_dot
Posted by my...@apache.org.
add transpose_and_dot
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/6f9b4fa0
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/6f9b4fa0
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/6f9b4fa0
Branch: refs/heads/JIRA-22/pr-385
Commit: 6f9b4fa0acebf604882240ccd5507d9df45bab2d
Parents: 56adf2d
Author: amaya <gi...@sapphire.in.net>
Authored: Fri Sep 16 15:52:54 2016 +0900
Committer: amaya <gi...@sapphire.in.net>
Committed: Fri Sep 16 15:52:54 2016 +0900
----------------------------------------------------------------------
.../tools/matrix/TransposeAndDotUDAF.java | 191 +++++++++++++++++++
1 file changed, 191 insertions(+)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/6f9b4fa0/core/src/main/java/hivemall/tools/matrix/TransposeAndDotUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/tools/matrix/TransposeAndDotUDAF.java b/core/src/main/java/hivemall/tools/matrix/TransposeAndDotUDAF.java
new file mode 100644
index 0000000..4fa5ce4
--- /dev/null
+++ b/core/src/main/java/hivemall/tools/matrix/TransposeAndDotUDAF.java
@@ -0,0 +1,191 @@
+package hivemall.tools.matrix;
+
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.hadoop.WritableUtils;
+import hivemall.utils.lang.Preconditions;
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.parse.SemanticException;
+import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFParameterInfo;
+import org.apache.hadoop.hive.serde2.io.DoubleWritable;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+@Description(name = "transpose_and_dot",
+ value = "_FUNC_(array<number> matrix0_row, array<number> matrix1_row)" +
+ " - Returns dot(matrix0.T, matrix1) as array<array<double>>, shape = (matrix0.#cols, matrix1.#cols)")
+public final class TransposeAndDotUDAF extends AbstractGenericUDAFResolver {
+ @Override
+ public GenericUDAFEvaluator getEvaluator(GenericUDAFParameterInfo info) throws SemanticException {
+ ObjectInspector[] OIs = info.getParameterObjectInspectors();
+
+ if (OIs.length != 2) {
+ throw new UDFArgumentLengthException("Specify two arguments.");
+ }
+
+ if (!HiveUtils.isNumberListOI(OIs[0])) {
+ throw new UDFArgumentTypeException(0, "Only array<number> type argument is acceptable but "
+ + OIs[0].getTypeName() + " was passed as `matrix0_row`");
+ }
+
+ if (!HiveUtils.isNumberListOI(OIs[1])) {
+ throw new UDFArgumentTypeException(1, "Only array<number> type argument is acceptable but "
+ + OIs[1].getTypeName() + " was passed as `matrix1_row`");
+ }
+
+ return new TransposeAndDotUDAFEvaluator();
+ }
+
+ private static final class TransposeAndDotUDAFEvaluator extends GenericUDAFEvaluator {
+ // PARTIAL1 and COMPLETE
+ private ListObjectInspector matrix0RowOI;
+ private PrimitiveObjectInspector matrix0ElOI;
+ private ListObjectInspector matrix1RowOI;
+ private PrimitiveObjectInspector matrix1ElOI;
+
+ // PARTIAL2 and FINAL
+ private ListObjectInspector aggMatrixOI;
+ private ListObjectInspector aggMatrixRowOI;
+ private DoubleObjectInspector aggMatrixElOI;
+
+ private double[] matrix0Row;
+ private double[] matrix1Row;
+
+ @AggregationType(estimable = true)
+ static class TransposeAndDotAggregationBuffer extends AbstractAggregationBuffer {
+ double[][] aggMatrix;
+
+ @Override
+ public int estimate() {
+ return aggMatrix != null
+ ? aggMatrix.length * aggMatrix[0].length * 8
+ : 0;
+ }
+
+ public void init(int n, int m) {
+ aggMatrix = new double[n][m];
+ }
+
+ public void reset() {
+ if (aggMatrix != null) {
+ for (double[] row : aggMatrix) {
+ Arrays.fill(row, 0.0);
+ }
+ }
+ }
+ }
+
+ @Override
+ public ObjectInspector init(Mode mode, ObjectInspector[] OIs) throws HiveException {
+ super.init(mode, OIs);
+
+ if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {
+ matrix0RowOI = HiveUtils.asListOI(OIs[0]);
+ matrix0ElOI = HiveUtils.asDoubleCompatibleOI(matrix0RowOI.getListElementObjectInspector());
+ matrix1RowOI = HiveUtils.asListOI(OIs[1]);
+ matrix1ElOI = HiveUtils.asDoubleCompatibleOI(matrix1RowOI.getListElementObjectInspector());
+ } else {
+ aggMatrixOI = HiveUtils.asListOI(OIs[0]);
+ aggMatrixRowOI = HiveUtils.asListOI(aggMatrixOI.getListElementObjectInspector());
+ aggMatrixElOI = HiveUtils.asDoubleOI(aggMatrixRowOI.getListElementObjectInspector());
+ }
+
+ return ObjectInspectorFactory.getStandardListObjectInspector(
+ ObjectInspectorFactory.getStandardListObjectInspector(
+ PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
+ }
+
+ @Override
+ public AbstractAggregationBuffer getNewAggregationBuffer() throws HiveException {
+ TransposeAndDotAggregationBuffer myAgg = new TransposeAndDotAggregationBuffer();
+ reset(myAgg);
+ return myAgg;
+ }
+
+ @Override
+ public void reset(AggregationBuffer agg) throws HiveException {
+ TransposeAndDotAggregationBuffer myAgg = (TransposeAndDotAggregationBuffer) agg;
+ myAgg.reset();
+ }
+
+ @Override
+ public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException {
+ TransposeAndDotAggregationBuffer myAgg = (TransposeAndDotAggregationBuffer) agg;
+
+ if (matrix0Row == null) {
+ matrix0Row = new double[matrix0RowOI.getListLength(parameters[0])];
+ }
+ if (matrix1Row == null) {
+ matrix1Row = new double[matrix1RowOI.getListLength(parameters[1])];
+ }
+
+ HiveUtils.toDoubleArray(parameters[0], matrix0RowOI, matrix0ElOI, matrix0Row, false);
+ HiveUtils.toDoubleArray(parameters[1], matrix1RowOI, matrix1ElOI, matrix1Row, false);
+
+ Preconditions.checkNotNull(matrix0Row);
+ Preconditions.checkNotNull(matrix1Row);
+
+ if (myAgg.aggMatrix == null) {
+ myAgg.init(matrix0Row.length, matrix1Row.length);
+ }
+
+ for (int i = 0; i < matrix0Row.length; i++) {
+ for (int j = 0; j < matrix1Row.length; j++) {
+ myAgg.aggMatrix[i][j] += matrix0Row[i] * matrix1Row[j];
+ }
+ }
+ }
+
+ @Override
+ public void merge(AggregationBuffer agg, Object other) throws HiveException {
+ if (other == null) {
+ return;
+ }
+
+ TransposeAndDotAggregationBuffer myAgg = (TransposeAndDotAggregationBuffer) agg;
+
+ List matrix = aggMatrixOI.getList(other);
+ final int n = matrix.size();
+ final double[] row = new double[aggMatrixRowOI.getListLength(matrix.get(0))];
+ for (int i = 0; i < n; i++) {
+ HiveUtils.toDoubleArray(matrix.get(i), aggMatrixRowOI, aggMatrixElOI, row, false);
+
+ if (myAgg.aggMatrix == null) {
+ myAgg.init(n, row.length);
+ }
+
+ for (int j = 0; j < row.length; j++) {
+ myAgg.aggMatrix[i][j] += row[j];
+ }
+ }
+ }
+
+ @Override
+ public Object terminatePartial(AggregationBuffer agg) throws HiveException {
+ return terminate(agg);
+ }
+
+ @Override
+ public Object terminate(AggregationBuffer agg) throws HiveException {
+ TransposeAndDotAggregationBuffer myAgg = (TransposeAndDotAggregationBuffer) agg;
+
+ List<List<DoubleWritable>> result = new ArrayList<List<DoubleWritable>>();
+ for (double[] row : myAgg.aggMatrix) {
+ result.add(WritableUtils.toWritableList(row));
+ }
+ return result;
+ }
+ }
+}
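A small worked sketch of the aggregate above, using the same input rows and expected result that the "user-defined aggregators for tools.matrix" test in HivemallOpsSuite checks; the grouping column and inline data are illustrative, assuming the function is registered as transpose_and_dot per the DDL added in this series:

  SELECT c0, transpose_and_dot(arg0, arg1)
  FROM (
    SELECT 1 AS c0, array(1.0, 2.0, 3.0) AS arg0, array(5.0, 6.0, 7.0) AS arg1
    UNION ALL
    SELECT 1 AS c0, array(3.0, 4.0, 5.0) AS arg0, array(7.0, 8.0, 9.0) AS arg1
  ) t
  GROUP BY c0;
  -- expected: [[26.0,30.0,34.0],[38.0,44.0,50.0],[50.0,58.0,66.0]]
  -- i.e. dot(matrix0.T, matrix1), shape = (matrix0.#cols, matrix1.#cols) = (3, 3)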
[47/50] [abbrv] incubator-hivemall git commit: Merge branch
'sst-changepoint' of https://github.com/takuti/hivemall into JIRA-22/pr-356
Posted by my...@apache.org.
Merge branch 'sst-changepoint' of https://github.com/takuti/hivemall into JIRA-22/pr-356
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/cc344351
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/cc344351
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/cc344351
Branch: refs/heads/JIRA-22/pr-356
Commit: cc34435155e86718acb49fa42208aff730bb756c
Parents: 72d6a62 998203d
Author: myui <yu...@gmail.com>
Authored: Fri Dec 2 16:55:23 2016 +0900
Committer: myui <yu...@gmail.com>
Committed: Fri Dec 2 16:55:23 2016 +0900
----------------------------------------------------------------------
.../anomaly/SingularSpectrumTransform.java | 193 +++++++++++++++
.../anomaly/SingularSpectrumTransformUDF.java | 235 +++++++++++++++++++
.../java/hivemall/utils/math/MatrixUtils.java | 203 ++++++++++++++++
.../anomaly/SingularSpectrumTransformTest.java | 146 ++++++++++++
.../hivemall/utils/math/MatrixUtilsTest.java | 67 ++++++
5 files changed, 844 insertions(+)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/cc344351/core/src/main/java/hivemall/utils/math/MatrixUtils.java
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/cc344351/core/src/test/java/hivemall/utils/math/MatrixUtilsTest.java
----------------------------------------------------------------------
[13/50] [abbrv] incubator-hivemall git commit: Revert some
modifications
Posted by my...@apache.org.
Revert some modifications
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/3620eb89
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/3620eb89
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/3620eb89
Branch: refs/heads/JIRA-22/pr-285
Commit: 3620eb89993db22ce8aee924d3cc0df33a5f9618
Parents: f81948c
Author: Takeshi YAMAMURO <li...@gmail.com>
Authored: Wed Sep 21 01:52:22 2016 +0900
Committer: Takeshi YAMAMURO <li...@gmail.com>
Committed: Wed Sep 21 01:55:59 2016 +0900
----------------------------------------------------------------------
.../src/main/java/hivemall/LearnerBaseUDTF.java | 33 ++
.../hivemall/classifier/AROWClassifierUDTF.java | 2 +-
.../hivemall/classifier/AdaGradRDAUDTF.java | 125 +++++++-
.../classifier/BinaryOnlineClassifierUDTF.java | 10 +
.../classifier/GeneralClassifierUDTF.java | 1 +
.../classifier/PassiveAggressiveUDTF.java | 2 +-
.../main/java/hivemall/model/DenseModel.java | 86 ++++-
.../main/java/hivemall/model/NewDenseModel.java | 293 +++++++++++++++++
.../model/NewSpaceEfficientDenseModel.java | 317 +++++++++++++++++++
.../java/hivemall/model/NewSparseModel.java | 197 ++++++++++++
.../java/hivemall/model/PredictionModel.java | 3 +
.../model/SpaceEfficientDenseModel.java | 92 +++++-
.../main/java/hivemall/model/SparseModel.java | 19 +-
.../model/SynchronizedModelWrapper.java | 6 +
.../hivemall/regression/AROWRegressionUDTF.java | 2 +-
.../java/hivemall/regression/AdaDeltaUDTF.java | 118 ++++++-
.../java/hivemall/regression/AdaGradUDTF.java | 119 ++++++-
.../regression/GeneralRegressionUDTF.java | 1 +
.../java/hivemall/regression/LogressUDTF.java | 65 +++-
.../PassiveAggressiveRegressionUDTF.java | 2 +-
.../hivemall/regression/RegressionBaseUDTF.java | 12 +-
.../NewSpaceEfficientNewDenseModelTest.java | 60 ++++
.../model/SpaceEfficientDenseModelTest.java | 60 ----
.../java/hivemall/mix/server/MixServerTest.java | 14 +-
.../hivemall/mix/server/MixServerSuite.scala | 4 +-
25 files changed, 1512 insertions(+), 131 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/main/java/hivemall/LearnerBaseUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/LearnerBaseUDTF.java b/core/src/main/java/hivemall/LearnerBaseUDTF.java
index 7fd5190..4cf3c7f 100644
--- a/core/src/main/java/hivemall/LearnerBaseUDTF.java
+++ b/core/src/main/java/hivemall/LearnerBaseUDTF.java
@@ -25,6 +25,9 @@ import hivemall.model.DenseModel;
import hivemall.model.PredictionModel;
import hivemall.model.SpaceEfficientDenseModel;
import hivemall.model.SparseModel;
+import hivemall.model.NewDenseModel;
+import hivemall.model.NewSparseModel;
+import hivemall.model.NewSpaceEfficientDenseModel;
import hivemall.model.SynchronizedModelWrapper;
import hivemall.model.WeightValue;
import hivemall.model.WeightValue.WeightValueWithCovar;
@@ -199,6 +202,36 @@ public abstract class LearnerBaseUDTF extends UDTFWithOptions {
return model;
}
+ protected PredictionModel createNewModel(String label) {
+ PredictionModel model;
+ final boolean useCovar = useCovariance();
+ if (dense_model) {
+ if (disable_halffloat == false && model_dims > 16777216) {
+ logger.info("Build a space efficient dense model with " + model_dims
+ + " initial dimensions" + (useCovar ? " w/ covariances" : ""));
+ model = new NewSpaceEfficientDenseModel(model_dims, useCovar);
+ } else {
+ logger.info("Build a dense model with initial with " + model_dims
+ + " initial dimensions" + (useCovar ? " w/ covariances" : ""));
+ model = new NewDenseModel(model_dims, useCovar);
+ }
+ } else {
+ int initModelSize = getInitialModelSize();
+ logger.info("Build a sparse model with initial with " + initModelSize
+ + " initial dimensions");
+ model = new NewSparseModel(initModelSize, useCovar);
+ }
+ if (mixConnectInfo != null) {
+ model.configureClock();
+ model = new SynchronizedModelWrapper(model);
+ MixClient client = configureMixClient(mixConnectInfo, label, model);
+ model.configureMix(client, mixCancel);
+ this.mixClient = client;
+ }
+ assert (model != null);
+ return model;
+ }
+
// If a model implements a optimizer, it must override this
protected Map<String, String> getOptimzierOptions() {
return null;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/main/java/hivemall/classifier/AROWClassifierUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/classifier/AROWClassifierUDTF.java b/core/src/main/java/hivemall/classifier/AROWClassifierUDTF.java
index ac8afcb..b42ab05 100644
--- a/core/src/main/java/hivemall/classifier/AROWClassifierUDTF.java
+++ b/core/src/main/java/hivemall/classifier/AROWClassifierUDTF.java
@@ -18,11 +18,11 @@
*/
package hivemall.classifier;
-import hivemall.optimizer.LossFunctions;
import hivemall.model.FeatureValue;
import hivemall.model.IWeightValue;
import hivemall.model.PredictionResult;
import hivemall.model.WeightValue.WeightValueWithCovar;
+import hivemall.optimizer.LossFunctions;
import javax.annotation.Nonnull;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/main/java/hivemall/classifier/AdaGradRDAUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/classifier/AdaGradRDAUDTF.java b/core/src/main/java/hivemall/classifier/AdaGradRDAUDTF.java
index a6714f4..b512a34 100644
--- a/core/src/main/java/hivemall/classifier/AdaGradRDAUDTF.java
+++ b/core/src/main/java/hivemall/classifier/AdaGradRDAUDTF.java
@@ -18,13 +18,128 @@
*/
package hivemall.classifier;
+import hivemall.model.FeatureValue;
+import hivemall.model.IWeightValue;
+import hivemall.model.WeightValue.WeightValueParamsF2;
+import hivemall.optimizer.LossFunctions;
+import hivemall.utils.lang.Primitives;
+
+import javax.annotation.Nonnull;
+
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.Options;
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
+
+/**
+ * @deprecated Use {@link hivemall.classifier.GeneralClassifierUDTF} instead
+ */
@Deprecated
-public final class AdaGradRDAUDTF extends GeneralClassifierUDTF {
+@Description(name = "train_adagrad_rda",
+ value = "_FUNC_(list<string|int|bigint> features, int label [, const string options])"
+ + " - Returns a relation consists of <string|int|bigint feature, float weight>",
+ extended = "Build a prediction model by Adagrad+RDA regularization binary classifier")
+public final class AdaGradRDAUDTF extends BinaryOnlineClassifierUDTF {
+
+ private float eta;
+ private float lambda;
+ private float scaling;
+
+ @Override
+ public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
+ final int numArgs = argOIs.length;
+ if (numArgs != 2 && numArgs != 3) {
+ throw new UDFArgumentException(
+ "_FUNC_ takes 2 or 3 arguments: List<Text|Int|BitInt> features, int label [, constant string options]");
+ }
+
+ StructObjectInspector oi = super.initialize(argOIs);
+ model.configureParams(true, false, true);
+ return oi;
+ }
+
+ @Override
+ protected Options getOptions() {
+ Options opts = super.getOptions();
+ opts.addOption("eta", "eta0", true, "The learning rate \\eta [default 0.1]");
+ opts.addOption("lambda", true, "lambda constant of RDA [default: 1E-6f]");
+ opts.addOption("scale", true,
+ "Internal scaling/descaling factor for cumulative weights [default: 100]");
+ return opts;
+ }
- public AdaGradRDAUDTF() {
- optimizerOptions.put("optimizer", "AdaGrad");
- optimizerOptions.put("regularization", "RDA");
- optimizerOptions.put("lambda", "1e-6");
+ @Override
+ protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
+ CommandLine cl = super.processOptions(argOIs);
+ if (cl == null) {
+ this.eta = 0.1f;
+ this.lambda = 1E-6f;
+ this.scaling = 100f;
+ } else {
+ this.eta = Primitives.parseFloat(cl.getOptionValue("eta"), 0.1f);
+ this.lambda = Primitives.parseFloat(cl.getOptionValue("lambda"), 1E-6f);
+ this.scaling = Primitives.parseFloat(cl.getOptionValue("scale"), 100f);
+ }
+ return cl;
}
+ @Override
+ protected void train(@Nonnull final FeatureValue[] features, final int label) {
+ final float y = label > 0 ? 1.f : -1.f;
+
+ float p = predict(features);
+ float loss = LossFunctions.hingeLoss(p, y); // 1.0 - y * p
+ if (loss <= 0.f) { // max(0, 1 - y * p)
+ return;
+ }
+ // subgradient => -y * W dot xi
+ update(features, y, count);
+ }
+
+ protected void update(@Nonnull final FeatureValue[] features, final float y, final int t) {
+ for (FeatureValue f : features) {// w[f] += y * x[f]
+ if (f == null) {
+ continue;
+ }
+ Object x = f.getFeature();
+ float xi = f.getValueAsFloat();
+
+ updateWeight(x, xi, y, t);
+ }
+ }
+
+ protected void updateWeight(@Nonnull final Object x, final float xi, final float y,
+ final float t) {
+ final float gradient = -y * xi;
+ final float scaled_gradient = gradient * scaling;
+
+ float scaled_sum_sqgrad = 0.f;
+ float scaled_sum_grad = 0.f;
+ IWeightValue old = model.get(x);
+ if (old != null) {
+ scaled_sum_sqgrad = old.getSumOfSquaredGradients();
+ scaled_sum_grad = old.getSumOfGradients();
+ }
+ scaled_sum_grad += scaled_gradient;
+ scaled_sum_sqgrad += (scaled_gradient * scaled_gradient);
+
+ float sum_grad = scaled_sum_grad * scaling;
+ double sum_sqgrad = scaled_sum_sqgrad * scaling;
+
+ // sign(u_{t,i})
+ float sign = (sum_grad > 0.f) ? 1.f : -1.f;
+ // |u_{t,i}|/t - \lambda
+ float meansOfGradients = sign * sum_grad / t - lambda;
+ if (meansOfGradients < 0.f) {
+ // x_{t,i} = 0
+ model.delete(x);
+ } else {
+ // x_{t,i} = -sign(u_{t,i}) * \frac{\eta t}{\sqrt{G_{t,ii}}}(|u_{t,i}|/t - \lambda)
+ float weight = -1.f * sign * eta * t * meansOfGradients / (float) Math.sqrt(sum_sqgrad);
+ IWeightValue new_w = new WeightValueParamsF2(weight, scaled_sum_sqgrad, scaled_sum_grad);
+ model.set(x, new_w);
+ }
+ }
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java b/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java
index 0ee5d5f..efeeb9d 100644
--- a/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java
+++ b/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java
@@ -60,6 +60,16 @@ public abstract class BinaryOnlineClassifierUDTF extends LearnerBaseUDTF {
protected Optimizer optimizerImpl;
protected int count;
+ private boolean enableNewModel;
+
+ public BinaryOnlineClassifierUDTF() {
+ this.enableNewModel = false;
+ }
+
+ public BinaryOnlineClassifierUDTF(boolean enableNewModel) {
+ this.enableNewModel = enableNewModel;
+ }
+
@Override
public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
if (argOIs.length < 2) {
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/main/java/hivemall/classifier/GeneralClassifierUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/classifier/GeneralClassifierUDTF.java b/core/src/main/java/hivemall/classifier/GeneralClassifierUDTF.java
index feebadd..12bd481 100644
--- a/core/src/main/java/hivemall/classifier/GeneralClassifierUDTF.java
+++ b/core/src/main/java/hivemall/classifier/GeneralClassifierUDTF.java
@@ -39,6 +39,7 @@ public class GeneralClassifierUDTF extends BinaryOnlineClassifierUDTF {
protected final Map<String, String> optimizerOptions;
public GeneralClassifierUDTF() {
+ super(true); // This enables new model interfaces
this.optimizerOptions = new HashMap<String, String>();
// Set default values
optimizerOptions.put("optimizer", "adagrad");
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/main/java/hivemall/classifier/PassiveAggressiveUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/classifier/PassiveAggressiveUDTF.java b/core/src/main/java/hivemall/classifier/PassiveAggressiveUDTF.java
index 9e404cd..191a7b5 100644
--- a/core/src/main/java/hivemall/classifier/PassiveAggressiveUDTF.java
+++ b/core/src/main/java/hivemall/classifier/PassiveAggressiveUDTF.java
@@ -18,9 +18,9 @@
*/
package hivemall.classifier;
-import hivemall.optimizer.LossFunctions;
import hivemall.model.FeatureValue;
import hivemall.model.PredictionResult;
+import hivemall.optimizer.LossFunctions;
import javax.annotation.Nonnull;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/main/java/hivemall/model/DenseModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/model/DenseModel.java b/core/src/main/java/hivemall/model/DenseModel.java
index 6956875..f142cc1 100644
--- a/core/src/main/java/hivemall/model/DenseModel.java
+++ b/core/src/main/java/hivemall/model/DenseModel.java
@@ -18,18 +18,21 @@
*/
package hivemall.model;
-import java.util.Arrays;
-import javax.annotation.Nonnull;
-
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
-
+import hivemall.model.WeightValue.WeightValueParamsF1;
+import hivemall.model.WeightValue.WeightValueParamsF2;
import hivemall.model.WeightValue.WeightValueWithCovar;
import hivemall.utils.collections.IMapIterator;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.lang.Copyable;
import hivemall.utils.math.MathUtils;
+import java.util.Arrays;
+
+import javax.annotation.Nonnull;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+
public final class DenseModel extends AbstractPredictionModel {
private static final Log logger = LogFactory.getLog(DenseModel.class);
@@ -37,6 +40,13 @@ public final class DenseModel extends AbstractPredictionModel {
private float[] weights;
private float[] covars;
+ // optional values for adagrad
+ private float[] sum_of_squared_gradients;
+ // optional value for adadelta
+ private float[] sum_of_squared_delta_x;
+ // optional value for adagrad+rda
+ private float[] sum_of_gradients;
+
// optional value for MIX
private short[] clocks;
private byte[] deltaUpdates;
@@ -57,6 +67,9 @@ public final class DenseModel extends AbstractPredictionModel {
} else {
this.covars = null;
}
+ this.sum_of_squared_gradients = null;
+ this.sum_of_squared_delta_x = null;
+ this.sum_of_gradients = null;
this.clocks = null;
this.deltaUpdates = null;
}
@@ -72,6 +85,20 @@ public final class DenseModel extends AbstractPredictionModel {
}
@Override
+ public void configureParams(boolean sum_of_squared_gradients, boolean sum_of_squared_delta_x,
+ boolean sum_of_gradients) {
+ if (sum_of_squared_gradients) {
+ this.sum_of_squared_gradients = new float[size];
+ }
+ if (sum_of_squared_delta_x) {
+ this.sum_of_squared_delta_x = new float[size];
+ }
+ if (sum_of_gradients) {
+ this.sum_of_gradients = new float[size];
+ }
+ }
+
+ @Override
public void configureClock() {
if (clocks == null) {
this.clocks = new short[size];
@@ -102,7 +129,16 @@ public final class DenseModel extends AbstractPredictionModel {
this.covars = Arrays.copyOf(covars, newSize);
Arrays.fill(covars, oldSize, newSize, 1.f);
}
- if(clocks != null) {
+ if (sum_of_squared_gradients != null) {
+ this.sum_of_squared_gradients = Arrays.copyOf(sum_of_squared_gradients, newSize);
+ }
+ if (sum_of_squared_delta_x != null) {
+ this.sum_of_squared_delta_x = Arrays.copyOf(sum_of_squared_delta_x, newSize);
+ }
+ if (sum_of_gradients != null) {
+ this.sum_of_gradients = Arrays.copyOf(sum_of_gradients, newSize);
+ }
+ if (clocks != null) {
this.clocks = Arrays.copyOf(clocks, newSize);
this.deltaUpdates = Arrays.copyOf(deltaUpdates, newSize);
}
@@ -116,7 +152,17 @@ public final class DenseModel extends AbstractPredictionModel {
if (i >= size) {
return null;
}
- if(covars != null) {
+ if (sum_of_squared_gradients != null) {
+ if (sum_of_squared_delta_x != null) {
+ return (T) new WeightValueParamsF2(weights[i], sum_of_squared_gradients[i],
+ sum_of_squared_delta_x[i]);
+ } else if (sum_of_gradients != null) {
+ return (T) new WeightValueParamsF2(weights[i], sum_of_squared_gradients[i],
+ sum_of_gradients[i]);
+ } else {
+ return (T) new WeightValueParamsF1(weights[i], sum_of_squared_gradients[i]);
+ }
+ } else if (covars != null) {
return (T) new WeightValueWithCovar(weights[i], covars[i]);
} else {
return (T) new WeightValue(weights[i]);
@@ -135,6 +181,15 @@ public final class DenseModel extends AbstractPredictionModel {
covar = value.getCovariance();
covars[i] = covar;
}
+ if (sum_of_squared_gradients != null) {
+ sum_of_squared_gradients[i] = value.getSumOfSquaredGradients();
+ }
+ if (sum_of_squared_delta_x != null) {
+ sum_of_squared_delta_x[i] = value.getSumOfSquaredDeltaX();
+ }
+ if (sum_of_gradients != null) {
+ sum_of_gradients[i] = value.getSumOfGradients();
+ }
short clock = 0;
int delta = 0;
if (clocks != null && value.isTouched()) {
@@ -158,6 +213,15 @@ public final class DenseModel extends AbstractPredictionModel {
if (covars != null) {
covars[i] = 1.f;
}
+ if (sum_of_squared_gradients != null) {
+ sum_of_squared_gradients[i] = 0.f;
+ }
+ if (sum_of_squared_delta_x != null) {
+ sum_of_squared_delta_x[i] = 0.f;
+ }
+ if (sum_of_gradients != null) {
+ sum_of_gradients[i] = 0.f;
+ }
// avoid clock/delta
}
@@ -171,10 +235,8 @@ public final class DenseModel extends AbstractPredictionModel {
}
@Override
- public void setWeight(Object feature, float value) {
- int i = HiveUtils.parseInt(feature);
- ensureCapacity(i);
- weights[i] = value;
+ public void setWeight(@Nonnull Object feature, float value) {
+ throw new UnsupportedOperationException();
}
@Override
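For orientation: the DenseModel changes above add optional per-dimension optimizer state (sum of squared gradients for AdaGrad, sum of squared delta-x for AdaDelta, sum of gradients for AdaGrad+RDA) that stays unallocated until a learner requests it through the new configureParams hook. A minimal sketch of the intended call pattern, mirroring the configureParams calls made by the AdaDelta and AdaGrad UDTFs later in this commit (the model size of 1000 is an arbitrary example value, not from the patch):

    // Sketch only; mirrors AdaDeltaUDTF/AdaGradUDTF.initialize() in this commit.
    PredictionModel model = new DenseModel(1000);
    model.configureParams(true, true, false);     // AdaDelta: needs E[g^2] and E[dx^2]
    // model.configureParams(true, false, false); // AdaGrad: only needs E[g^2]
    // The third flag allocates the sum-of-gradients array kept for AdaGrad+RDA.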
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/main/java/hivemall/model/NewDenseModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/model/NewDenseModel.java b/core/src/main/java/hivemall/model/NewDenseModel.java
new file mode 100644
index 0000000..920794c
--- /dev/null
+++ b/core/src/main/java/hivemall/model/NewDenseModel.java
@@ -0,0 +1,293 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2015 Makoto YUI
+ * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package hivemall.model;
+
+import java.util.Arrays;
+import javax.annotation.Nonnull;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+
+import hivemall.model.WeightValue.WeightValueWithCovar;
+import hivemall.utils.collections.IMapIterator;
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.lang.Copyable;
+import hivemall.utils.math.MathUtils;
+
+public final class NewDenseModel extends AbstractPredictionModel {
+ private static final Log logger = LogFactory.getLog(NewDenseModel.class);
+
+ private int size;
+ private float[] weights;
+ private float[] covars;
+
+ // optional value for MIX
+ private short[] clocks;
+ private byte[] deltaUpdates;
+
+ public NewDenseModel(int ndims) {
+ this(ndims, false);
+ }
+
+ public NewDenseModel(int ndims, boolean withCovar) {
+ super();
+ int size = ndims + 1;
+ this.size = size;
+ this.weights = new float[size];
+ if (withCovar) {
+ float[] covars = new float[size];
+ Arrays.fill(covars, 1f);
+ this.covars = covars;
+ } else {
+ this.covars = null;
+ }
+ this.clocks = null;
+ this.deltaUpdates = null;
+ }
+
+ @Override
+ protected boolean isDenseModel() {
+ return true;
+ }
+
+ @Override
+ public boolean hasCovariance() {
+ return covars != null;
+ }
+
+ @Override
+ public void configureParams(boolean sum_of_squared_gradients, boolean sum_of_squared_delta_x,
+ boolean sum_of_gradients) {}
+
+ @Override
+ public void configureClock() {
+ if (clocks == null) {
+ this.clocks = new short[size];
+ this.deltaUpdates = new byte[size];
+ }
+ }
+
+ @Override
+ public boolean hasClock() {
+ return clocks != null;
+ }
+
+ @Override
+ public void resetDeltaUpdates(int feature) {
+ deltaUpdates[feature] = 0;
+ }
+
+ private void ensureCapacity(final int index) {
+ if (index >= size) {
+ int bits = MathUtils.bitsRequired(index);
+ int newSize = (1 << bits) + 1;
+ int oldSize = size;
+ logger.info("Expands internal array size from " + oldSize + " to " + newSize + " ("
+ + bits + " bits)");
+ this.size = newSize;
+ this.weights = Arrays.copyOf(weights, newSize);
+ if (covars != null) {
+ this.covars = Arrays.copyOf(covars, newSize);
+ Arrays.fill(covars, oldSize, newSize, 1.f);
+ }
+ if(clocks != null) {
+ this.clocks = Arrays.copyOf(clocks, newSize);
+ this.deltaUpdates = Arrays.copyOf(deltaUpdates, newSize);
+ }
+ }
+ }
+
+ @SuppressWarnings("unchecked")
+ @Override
+ public <T extends IWeightValue> T get(Object feature) {
+ final int i = HiveUtils.parseInt(feature);
+ if (i >= size) {
+ return null;
+ }
+ if(covars != null) {
+ return (T) new WeightValueWithCovar(weights[i], covars[i]);
+ } else {
+ return (T) new WeightValue(weights[i]);
+ }
+ }
+
+ @Override
+ public <T extends IWeightValue> void set(Object feature, T value) {
+ int i = HiveUtils.parseInt(feature);
+ ensureCapacity(i);
+ float weight = value.get();
+ weights[i] = weight;
+ float covar = 1.f;
+ boolean hasCovar = value.hasCovariance();
+ if (hasCovar) {
+ covar = value.getCovariance();
+ covars[i] = covar;
+ }
+ short clock = 0;
+ int delta = 0;
+ if (clocks != null && value.isTouched()) {
+ clock = (short) (clocks[i] + 1);
+ clocks[i] = clock;
+ delta = deltaUpdates[i] + 1;
+ assert (delta > 0) : delta;
+ deltaUpdates[i] = (byte) delta;
+ }
+
+ onUpdate(i, weight, covar, clock, delta, hasCovar);
+ }
+
+ @Override
+ public void delete(@Nonnull Object feature) {
+ final int i = HiveUtils.parseInt(feature);
+ if (i >= size) {
+ return;
+ }
+ weights[i] = 0.f;
+ if (covars != null) {
+ covars[i] = 1.f;
+ }
+ // avoid clock/delta
+ }
+
+ @Override
+ public float getWeight(Object feature) {
+ int i = HiveUtils.parseInt(feature);
+ if (i >= size) {
+ return 0f;
+ }
+ return weights[i];
+ }
+
+ @Override
+ public void setWeight(Object feature, float value) {
+ int i = HiveUtils.parseInt(feature);
+ ensureCapacity(i);
+ weights[i] = value;
+ }
+
+ @Override
+ public float getCovariance(Object feature) {
+ int i = HiveUtils.parseInt(feature);
+ if (i >= size) {
+ return 1f;
+ }
+ return covars[i];
+ }
+
+ @Override
+ protected void _set(Object feature, float weight, short clock) {
+ int i = ((Integer) feature).intValue();
+ ensureCapacity(i);
+ weights[i] = weight;
+ clocks[i] = clock;
+ deltaUpdates[i] = 0;
+ }
+
+ @Override
+ protected void _set(Object feature, float weight, float covar, short clock) {
+ int i = ((Integer) feature).intValue();
+ ensureCapacity(i);
+ weights[i] = weight;
+ covars[i] = covar;
+ clocks[i] = clock;
+ deltaUpdates[i] = 0;
+ }
+
+ @Override
+ public int size() {
+ return size;
+ }
+
+ @Override
+ public boolean contains(Object feature) {
+ int i = HiveUtils.parseInt(feature);
+ if (i >= size) {
+ return false;
+ }
+ float w = weights[i];
+ return w != 0.f;
+ }
+
+ @SuppressWarnings("unchecked")
+ @Override
+ public <K, V extends IWeightValue> IMapIterator<K, V> entries() {
+ return (IMapIterator<K, V>) new Itr();
+ }
+
+ private final class Itr implements IMapIterator<Number, IWeightValue> {
+
+ private int cursor;
+ private final WeightValueWithCovar tmpWeight;
+
+ private Itr() {
+ this.cursor = -1;
+ this.tmpWeight = new WeightValueWithCovar();
+ }
+
+ @Override
+ public boolean hasNext() {
+ return cursor < size;
+ }
+
+ @Override
+ public int next() {
+ ++cursor;
+ if (!hasNext()) {
+ return -1;
+ }
+ return cursor;
+ }
+
+ @Override
+ public Integer getKey() {
+ return cursor;
+ }
+
+ @Override
+ public IWeightValue getValue() {
+ if (covars == null) {
+ float w = weights[cursor];
+ WeightValue v = new WeightValue(w);
+ v.setTouched(w != 0f);
+ return v;
+ } else {
+ float w = weights[cursor];
+ float cov = covars[cursor];
+ WeightValueWithCovar v = new WeightValueWithCovar(w, cov);
+ v.setTouched(w != 0.f || cov != 1.f);
+ return v;
+ }
+ }
+
+ @Override
+ public <T extends Copyable<IWeightValue>> void getValue(T probe) {
+ float w = weights[cursor];
+ tmpWeight.value = w;
+ float cov = 1.f;
+ if (covars != null) {
+ cov = covars[cursor];
+ tmpWeight.setCovariance(cov);
+ }
+ tmpWeight.setTouched(w != 0.f || cov != 1.f);
+ probe.copyFrom(tmpWeight);
+ }
+
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/main/java/hivemall/model/NewSpaceEfficientDenseModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/model/NewSpaceEfficientDenseModel.java b/core/src/main/java/hivemall/model/NewSpaceEfficientDenseModel.java
new file mode 100644
index 0000000..48eb62a
--- /dev/null
+++ b/core/src/main/java/hivemall/model/NewSpaceEfficientDenseModel.java
@@ -0,0 +1,317 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2015 Makoto YUI
+ * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package hivemall.model;
+
+import hivemall.model.WeightValue.WeightValueWithCovar;
+import hivemall.utils.collections.IMapIterator;
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.lang.Copyable;
+import hivemall.utils.lang.HalfFloat;
+import hivemall.utils.math.MathUtils;
+
+import java.util.Arrays;
+import javax.annotation.Nonnull;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+
+public final class NewSpaceEfficientDenseModel extends AbstractPredictionModel {
+ private static final Log logger = LogFactory.getLog(NewSpaceEfficientDenseModel.class);
+
+ private int size;
+ private short[] weights;
+ private short[] covars;
+
+ // optional value for MIX
+ private short[] clocks;
+ private byte[] deltaUpdates;
+
+ public NewSpaceEfficientDenseModel(int ndims) {
+ this(ndims, false);
+ }
+
+ public NewSpaceEfficientDenseModel(int ndims, boolean withCovar) {
+ super();
+ int size = ndims + 1;
+ this.size = size;
+ this.weights = new short[size];
+ if (withCovar) {
+ short[] covars = new short[size];
+ Arrays.fill(covars, HalfFloat.ONE);
+ this.covars = covars;
+ } else {
+ this.covars = null;
+ }
+ this.clocks = null;
+ this.deltaUpdates = null;
+ }
+
+ @Override
+ protected boolean isDenseModel() {
+ return true;
+ }
+
+ @Override
+ public boolean hasCovariance() {
+ return covars != null;
+ }
+
+ @Override
+ public void configureParams(boolean sum_of_squared_gradients, boolean sum_of_squared_delta_x,
+ boolean sum_of_gradients) {}
+
+ @Override
+ public void configureClock() {
+ if (clocks == null) {
+ this.clocks = new short[size];
+ this.deltaUpdates = new byte[size];
+ }
+ }
+
+ @Override
+ public boolean hasClock() {
+ return clocks != null;
+ }
+
+ @Override
+ public void resetDeltaUpdates(int feature) {
+ deltaUpdates[feature] = 0;
+ }
+
+ private float getWeight(final int i) {
+ final short w = weights[i];
+ return (w == HalfFloat.ZERO) ? HalfFloat.ZERO : HalfFloat.halfFloatToFloat(w);
+ }
+
+ private float getCovar(final int i) {
+ return HalfFloat.halfFloatToFloat(covars[i]);
+ }
+
+ private void _setWeight(final int i, final float v) {
+ if(Math.abs(v) >= HalfFloat.MAX_FLOAT) {
+ throw new IllegalArgumentException("Acceptable maximum weight is "
+ + HalfFloat.MAX_FLOAT + ": " + v);
+ }
+ weights[i] = HalfFloat.floatToHalfFloat(v);
+ }
+
+ private void setCovar(final int i, final float v) {
+ HalfFloat.checkRange(v);
+ covars[i] = HalfFloat.floatToHalfFloat(v);
+ }
+
+ private void ensureCapacity(final int index) {
+ if (index >= size) {
+ int bits = MathUtils.bitsRequired(index);
+ int newSize = (1 << bits) + 1;
+ int oldSize = size;
+ logger.info("Expands internal array size from " + oldSize + " to " + newSize + " ("
+ + bits + " bits)");
+ this.size = newSize;
+ this.weights = Arrays.copyOf(weights, newSize);
+ if (covars != null) {
+ this.covars = Arrays.copyOf(covars, newSize);
+ Arrays.fill(covars, oldSize, newSize, HalfFloat.ONE);
+ }
+ if(clocks != null) {
+ this.clocks = Arrays.copyOf(clocks, newSize);
+ this.deltaUpdates = Arrays.copyOf(deltaUpdates, newSize);
+ }
+ }
+ }
+
+ @SuppressWarnings("unchecked")
+ @Override
+ public <T extends IWeightValue> T get(Object feature) {
+ final int i = HiveUtils.parseInt(feature);
+ if (i >= size) {
+ return null;
+ }
+
+ if(covars != null) {
+ return (T) new WeightValueWithCovar(getWeight(i), getCovar(i));
+ } else {
+ return (T) new WeightValue(getWeight(i));
+ }
+ }
+
+ @Override
+ public <T extends IWeightValue> void set(Object feature, T value) {
+ int i = HiveUtils.parseInt(feature);
+ ensureCapacity(i);
+ float weight = value.get();
+ _setWeight(i, weight);
+ float covar = 1.f;
+ boolean hasCovar = value.hasCovariance();
+ if (hasCovar) {
+ covar = value.getCovariance();
+ setCovar(i, covar);
+ }
+ short clock = 0;
+ int delta = 0;
+ if (clocks != null && value.isTouched()) {
+ clock = (short) (clocks[i] + 1);
+ clocks[i] = clock;
+ delta = deltaUpdates[i] + 1;
+ assert (delta > 0) : delta;
+ deltaUpdates[i] = (byte) delta;
+ }
+
+ onUpdate(i, weight, covar, clock, delta, hasCovar);
+ }
+
+ @Override
+ public void delete(@Nonnull Object feature) {
+ final int i = HiveUtils.parseInt(feature);
+ if (i >= size) {
+ return;
+ }
+ _setWeight(i, 0.f);
+ if(covars != null) {
+ setCovar(i, 1.f);
+ }
+ // avoid clock/delta
+ }
+
+ @Override
+ public float getWeight(Object feature) {
+ int i = HiveUtils.parseInt(feature);
+ if (i >= size) {
+ return 0f;
+ }
+ return getWeight(i);
+ }
+
+ @Override
+ public void setWeight(Object feature, float value) {
+ int i = HiveUtils.parseInt(feature);
+ ensureCapacity(i);
+ _setWeight(i, value);
+ }
+
+ @Override
+ public float getCovariance(Object feature) {
+ int i = HiveUtils.parseInt(feature);
+ if (i >= size) {
+ return 1f;
+ }
+ return getCovar(i);
+ }
+
+ @Override
+ protected void _set(Object feature, float weight, short clock) {
+ int i = ((Integer) feature).intValue();
+ ensureCapacity(i);
+ _setWeight(i, weight);
+ clocks[i] = clock;
+ deltaUpdates[i] = 0;
+ }
+
+ @Override
+ protected void _set(Object feature, float weight, float covar, short clock) {
+ int i = ((Integer) feature).intValue();
+ ensureCapacity(i);
+ _setWeight(i, weight);
+ setCovar(i, covar);
+ clocks[i] = clock;
+ deltaUpdates[i] = 0;
+ }
+
+ @Override
+ public int size() {
+ return size;
+ }
+
+ @Override
+ public boolean contains(Object feature) {
+ int i = HiveUtils.parseInt(feature);
+ if (i >= size) {
+ return false;
+ }
+ float w = getWeight(i);
+ return w != 0.f;
+ }
+
+ @SuppressWarnings("unchecked")
+ @Override
+ public <K, V extends IWeightValue> IMapIterator<K, V> entries() {
+ return (IMapIterator<K, V>) new Itr();
+ }
+
+ private final class Itr implements IMapIterator<Number, IWeightValue> {
+
+ private int cursor;
+ private final WeightValueWithCovar tmpWeight;
+
+ private Itr() {
+ this.cursor = -1;
+ this.tmpWeight = new WeightValueWithCovar();
+ }
+
+ @Override
+ public boolean hasNext() {
+ return cursor < size;
+ }
+
+ @Override
+ public int next() {
+ ++cursor;
+ if (!hasNext()) {
+ return -1;
+ }
+ return cursor;
+ }
+
+ @Override
+ public Integer getKey() {
+ return cursor;
+ }
+
+ @Override
+ public IWeightValue getValue() {
+ if (covars == null) {
+ float w = getWeight(cursor);
+ WeightValue v = new WeightValue(w);
+ v.setTouched(w != 0f);
+ return v;
+ } else {
+ float w = getWeight(cursor);
+ float cov = getCovar(cursor);
+ WeightValueWithCovar v = new WeightValueWithCovar(w, cov);
+ v.setTouched(w != 0.f || cov != 1.f);
+ return v;
+ }
+ }
+
+ @Override
+ public <T extends Copyable<IWeightValue>> void getValue(T probe) {
+ float w = getWeight(cursor);
+ tmpWeight.value = w;
+ float cov = 1.f;
+ if (covars != null) {
+ cov = getCovar(cursor);
+ tmpWeight.setCovariance(cov);
+ }
+ tmpWeight.setTouched(w != 0.f || cov != 1.f);
+ probe.copyFrom(tmpWeight);
+ }
+
+ }
+
+}
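A note on the storage trick shared by NewSpaceEfficientDenseModel above and SpaceEfficientDenseModel below: weights and covariances are kept as 16-bit half floats to halve the memory footprint, at the cost of resolution. Assuming HalfFloat follows IEEE 754 binary16 (maximum value 65504), adjacent representable values near HalfFloat.MAX_FLOAT are 32 apart, which is why the round-trip test added later in this commit compares weights with a 32f tolerance. A tiny round-trip sketch with an illustrative value:

    // Illustrative only; assumes IEEE 754 binary16 semantics for HalfFloat.
    float original = 60000.5f;                           // below HalfFloat.MAX_FLOAT
    short packed = HalfFloat.floatToHalfFloat(original);
    float restored = HalfFloat.halfFloatToFloat(packed);
    // 'restored' is the nearest representable half float; near the top of the
    // range the spacing is 32, so the rounding error can be as large as 16.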
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/main/java/hivemall/model/NewSparseModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/model/NewSparseModel.java b/core/src/main/java/hivemall/model/NewSparseModel.java
new file mode 100644
index 0000000..4c21830
--- /dev/null
+++ b/core/src/main/java/hivemall/model/NewSparseModel.java
@@ -0,0 +1,197 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2015 Makoto YUI
+ * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package hivemall.model;
+
+import hivemall.model.WeightValueWithClock.WeightValueParamsF1Clock;
+import hivemall.model.WeightValueWithClock.WeightValueParamsF2Clock;
+import hivemall.model.WeightValueWithClock.WeightValueWithCovarClock;
+import hivemall.utils.collections.IMapIterator;
+import hivemall.utils.collections.OpenHashMap;
+
+import javax.annotation.Nonnull;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+
+public final class NewSparseModel extends AbstractPredictionModel {
+ private static final Log logger = LogFactory.getLog(NewSparseModel.class);
+
+ private final OpenHashMap<Object, IWeightValue> weights;
+ private final boolean hasCovar;
+ private boolean clockEnabled;
+
+ public NewSparseModel(int size) {
+ this(size, false);
+ }
+
+ public NewSparseModel(int size, boolean hasCovar) {
+ super();
+ this.weights = new OpenHashMap<Object, IWeightValue>(size);
+ this.hasCovar = hasCovar;
+ this.clockEnabled = false;
+ }
+
+ @Override
+ protected boolean isDenseModel() {
+ return false;
+ }
+
+ @Override
+ public boolean hasCovariance() {
+ return hasCovar;
+ }
+
+ @Override
+ public void configureParams(boolean sum_of_squared_gradients, boolean sum_of_squared_delta_x,
+ boolean sum_of_gradients) {}
+
+ @Override
+ public void configureClock() {
+ this.clockEnabled = true;
+ }
+
+ @Override
+ public boolean hasClock() {
+ return clockEnabled;
+ }
+
+ @SuppressWarnings("unchecked")
+ @Override
+ public <T extends IWeightValue> T get(final Object feature) {
+ return (T) weights.get(feature);
+ }
+
+ @Override
+ public <T extends IWeightValue> void set(final Object feature, final T value) {
+ assert (feature != null);
+ assert (value != null);
+
+ final IWeightValue wrapperValue = wrapIfRequired(value);
+
+ if (clockEnabled && value.isTouched()) {
+ IWeightValue old = weights.get(feature);
+ if (old != null) {
+ short newclock = (short) (old.getClock() + (short) 1);
+ wrapperValue.setClock(newclock);
+ int newDelta = old.getDeltaUpdates() + 1;
+ wrapperValue.setDeltaUpdates((byte) newDelta);
+ }
+ }
+ weights.put(feature, wrapperValue);
+
+ onUpdate(feature, wrapperValue);
+ }
+
+ @Override
+ public void delete(@Nonnull Object feature) {
+ weights.remove(feature);
+ }
+
+ private IWeightValue wrapIfRequired(final IWeightValue value) {
+ final IWeightValue wrapper;
+ if (clockEnabled) {
+ switch (value.getType()) {
+ case NoParams:
+ wrapper = new WeightValueWithClock(value);
+ break;
+ case ParamsCovar:
+ wrapper = new WeightValueWithCovarClock(value);
+ break;
+ case ParamsF1:
+ wrapper = new WeightValueParamsF1Clock(value);
+ break;
+ case ParamsF2:
+ wrapper = new WeightValueParamsF2Clock(value);
+ break;
+ default:
+ throw new IllegalStateException("Unexpected value type: " + value.getType());
+ }
+ } else {
+ wrapper = value;
+ }
+ return wrapper;
+ }
+
+ @Override
+ public float getWeight(final Object feature) {
+ IWeightValue v = weights.get(feature);
+ return v == null ? 0.f : v.get();
+ }
+
+ @Override
+ public void setWeight(Object feature, float value) {
+ if(weights.containsKey(feature)) {
+ IWeightValue weight = weights.get(feature);
+ weight.set(value);
+ } else {
+ IWeightValue weight = new WeightValue(value);
+ weight.set(value);
+ weights.put(feature, weight);
+ }
+ }
+
+ @Override
+ public float getCovariance(final Object feature) {
+ IWeightValue v = weights.get(feature);
+ return v == null ? 1.f : v.getCovariance();
+ }
+
+ @Override
+ protected void _set(final Object feature, final float weight, final short clock) {
+ final IWeightValue w = weights.get(feature);
+ if (w == null) {
+ logger.warn("Previous weight not found: " + feature);
+ throw new IllegalStateException("Previous weight not found " + feature);
+ }
+ w.set(weight);
+ w.setClock(clock);
+ w.setDeltaUpdates(BYTE0);
+ }
+
+ @Override
+ protected void _set(final Object feature, final float weight, final float covar,
+ final short clock) {
+ final IWeightValue w = weights.get(feature);
+ if (w == null) {
+ logger.warn("Previous weight not found: " + feature);
+ throw new IllegalStateException("Previous weight not found: " + feature);
+ }
+ w.set(weight);
+ w.setCovariance(covar);
+ w.setClock(clock);
+ w.setDeltaUpdates(BYTE0);
+ }
+
+ @Override
+ public int size() {
+ return weights.size();
+ }
+
+ @Override
+ public boolean contains(final Object feature) {
+ return weights.containsKey(feature);
+ }
+
+ @SuppressWarnings("unchecked")
+ @Override
+ public <K, V extends IWeightValue> IMapIterator<K, V> entries() {
+ return (IMapIterator<K, V>) weights.entries();
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/main/java/hivemall/model/PredictionModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/model/PredictionModel.java b/core/src/main/java/hivemall/model/PredictionModel.java
index 8d8dd2b..ea82f62 100644
--- a/core/src/main/java/hivemall/model/PredictionModel.java
+++ b/core/src/main/java/hivemall/model/PredictionModel.java
@@ -34,6 +34,9 @@ public interface PredictionModel extends MixedModel {
boolean hasCovariance();
+ void configureParams(boolean sum_of_squared_gradients, boolean sum_of_squared_delta_x,
+ boolean sum_of_gradients);
+
void configureClock();
boolean hasClock();
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/main/java/hivemall/model/SpaceEfficientDenseModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/model/SpaceEfficientDenseModel.java b/core/src/main/java/hivemall/model/SpaceEfficientDenseModel.java
index 8b668e7..caa9fea 100644
--- a/core/src/main/java/hivemall/model/SpaceEfficientDenseModel.java
+++ b/core/src/main/java/hivemall/model/SpaceEfficientDenseModel.java
@@ -18,6 +18,8 @@
*/
package hivemall.model;
+import hivemall.model.WeightValue.WeightValueParamsF1;
+import hivemall.model.WeightValue.WeightValueParamsF2;
import hivemall.model.WeightValue.WeightValueWithCovar;
import hivemall.utils.collections.IMapIterator;
import hivemall.utils.hadoop.HiveUtils;
@@ -26,6 +28,7 @@ import hivemall.utils.lang.HalfFloat;
import hivemall.utils.math.MathUtils;
import java.util.Arrays;
+
import javax.annotation.Nonnull;
import org.apache.commons.logging.Log;
@@ -38,6 +41,13 @@ public final class SpaceEfficientDenseModel extends AbstractPredictionModel {
private short[] weights;
private short[] covars;
+ // optional value for adagrad
+ private float[] sum_of_squared_gradients;
+ // optional value for adadelta
+ private float[] sum_of_squared_delta_x;
+ // optional value for adagrad+rda
+ private float[] sum_of_gradients;
+
// optional value for MIX
private short[] clocks;
private byte[] deltaUpdates;
@@ -58,6 +68,9 @@ public final class SpaceEfficientDenseModel extends AbstractPredictionModel {
} else {
this.covars = null;
}
+ this.sum_of_squared_gradients = null;
+ this.sum_of_squared_delta_x = null;
+ this.sum_of_gradients = null;
this.clocks = null;
this.deltaUpdates = null;
}
@@ -73,6 +86,20 @@ public final class SpaceEfficientDenseModel extends AbstractPredictionModel {
}
@Override
+ public void configureParams(boolean sum_of_squared_gradients, boolean sum_of_squared_delta_x,
+ boolean sum_of_gradients) {
+ if (sum_of_squared_gradients) {
+ this.sum_of_squared_gradients = new float[size];
+ }
+ if (sum_of_squared_delta_x) {
+ this.sum_of_squared_delta_x = new float[size];
+ }
+ if (sum_of_gradients) {
+ this.sum_of_gradients = new float[size];
+ }
+ }
+
+ @Override
public void configureClock() {
if (clocks == null) {
this.clocks = new short[size];
@@ -99,11 +126,8 @@ public final class SpaceEfficientDenseModel extends AbstractPredictionModel {
return HalfFloat.halfFloatToFloat(covars[i]);
}
- private void _setWeight(final int i, final float v) {
- if(Math.abs(v) >= HalfFloat.MAX_FLOAT) {
- throw new IllegalArgumentException("Acceptable maximum weight is "
- + HalfFloat.MAX_FLOAT + ": " + v);
- }
+ private void setWeight(final int i, final float v) {
+ HalfFloat.checkRange(v);
weights[i] = HalfFloat.floatToHalfFloat(v);
}
@@ -125,7 +149,16 @@ public final class SpaceEfficientDenseModel extends AbstractPredictionModel {
this.covars = Arrays.copyOf(covars, newSize);
Arrays.fill(covars, oldSize, newSize, HalfFloat.ONE);
}
- if(clocks != null) {
+ if (sum_of_squared_gradients != null) {
+ this.sum_of_squared_gradients = Arrays.copyOf(sum_of_squared_gradients, newSize);
+ }
+ if (sum_of_squared_delta_x != null) {
+ this.sum_of_squared_delta_x = Arrays.copyOf(sum_of_squared_delta_x, newSize);
+ }
+ if (sum_of_gradients != null) {
+ this.sum_of_gradients = Arrays.copyOf(sum_of_gradients, newSize);
+ }
+ if (clocks != null) {
this.clocks = Arrays.copyOf(clocks, newSize);
this.deltaUpdates = Arrays.copyOf(deltaUpdates, newSize);
}
@@ -139,8 +172,17 @@ public final class SpaceEfficientDenseModel extends AbstractPredictionModel {
if (i >= size) {
return null;
}
-
- if(covars != null) {
+ if (sum_of_squared_gradients != null) {
+ if (sum_of_squared_delta_x != null) {
+ return (T) new WeightValueParamsF2(getWeight(i), sum_of_squared_gradients[i],
+ sum_of_squared_delta_x[i]);
+ } else if (sum_of_gradients != null) {
+ return (T) new WeightValueParamsF2(getWeight(i), sum_of_squared_gradients[i],
+ sum_of_gradients[i]);
+ } else {
+ return (T) new WeightValueParamsF1(getWeight(i), sum_of_squared_gradients[i]);
+ }
+ } else if (covars != null) {
return (T) new WeightValueWithCovar(getWeight(i), getCovar(i));
} else {
return (T) new WeightValue(getWeight(i));
@@ -152,13 +194,22 @@ public final class SpaceEfficientDenseModel extends AbstractPredictionModel {
int i = HiveUtils.parseInt(feature);
ensureCapacity(i);
float weight = value.get();
- _setWeight(i, weight);
+ setWeight(i, weight);
float covar = 1.f;
boolean hasCovar = value.hasCovariance();
if (hasCovar) {
covar = value.getCovariance();
setCovar(i, covar);
}
+ if (sum_of_squared_gradients != null) {
+ sum_of_squared_gradients[i] = value.getSumOfSquaredGradients();
+ }
+ if (sum_of_squared_delta_x != null) {
+ sum_of_squared_delta_x[i] = value.getSumOfSquaredDeltaX();
+ }
+ if (sum_of_gradients != null) {
+ sum_of_gradients[i] = value.getSumOfGradients();
+ }
short clock = 0;
int delta = 0;
if (clocks != null && value.isTouched()) {
@@ -178,10 +229,19 @@ public final class SpaceEfficientDenseModel extends AbstractPredictionModel {
if (i >= size) {
return;
}
- _setWeight(i, 0.f);
- if(covars != null) {
+ setWeight(i, 0.f);
+ if (covars != null) {
setCovar(i, 1.f);
}
+ if (sum_of_squared_gradients != null) {
+ sum_of_squared_gradients[i] = 0.f;
+ }
+ if (sum_of_squared_delta_x != null) {
+ sum_of_squared_delta_x[i] = 0.f;
+ }
+ if (sum_of_gradients != null) {
+ sum_of_gradients[i] = 0.f;
+ }
// avoid clock/delta
}
@@ -195,10 +255,8 @@ public final class SpaceEfficientDenseModel extends AbstractPredictionModel {
}
@Override
- public void setWeight(Object feature, float value) {
- int i = HiveUtils.parseInt(feature);
- ensureCapacity(i);
- _setWeight(i, value);
+ public void setWeight(@Nonnull Object feature, float value) {
+ throw new UnsupportedOperationException();
}
@Override
@@ -214,7 +272,7 @@ public final class SpaceEfficientDenseModel extends AbstractPredictionModel {
protected void _set(Object feature, float weight, short clock) {
int i = ((Integer) feature).intValue();
ensureCapacity(i);
- _setWeight(i, weight);
+ setWeight(i, weight);
clocks[i] = clock;
deltaUpdates[i] = 0;
}
@@ -223,7 +281,7 @@ public final class SpaceEfficientDenseModel extends AbstractPredictionModel {
protected void _set(Object feature, float weight, float covar, short clock) {
int i = ((Integer) feature).intValue();
ensureCapacity(i);
- _setWeight(i, weight);
+ setWeight(i, weight);
setCovar(i, covar);
clocks[i] = clock;
deltaUpdates[i] = 0;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/main/java/hivemall/model/SparseModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/model/SparseModel.java b/core/src/main/java/hivemall/model/SparseModel.java
index bab982f..f4c4c55 100644
--- a/core/src/main/java/hivemall/model/SparseModel.java
+++ b/core/src/main/java/hivemall/model/SparseModel.java
@@ -36,10 +36,6 @@ public final class SparseModel extends AbstractPredictionModel {
private final boolean hasCovar;
private boolean clockEnabled;
- public SparseModel(int size) {
- this(size, false);
- }
-
public SparseModel(int size, boolean hasCovar) {
super();
this.weights = new OpenHashMap<Object, IWeightValue>(size);
@@ -58,6 +54,10 @@ public final class SparseModel extends AbstractPredictionModel {
}
@Override
+ public void configureParams(boolean sum_of_squared_gradients, boolean sum_of_squared_delta_x,
+ boolean sum_of_gradients) {}
+
+ @Override
public void configureClock() {
this.clockEnabled = true;
}
@@ -131,15 +131,8 @@ public final class SparseModel extends AbstractPredictionModel {
}
@Override
- public void setWeight(Object feature, float value) {
- if(weights.containsKey(feature)) {
- IWeightValue weight = weights.get(feature);
- weight.set(value);
- } else {
- IWeightValue weight = new WeightValue(value);
- weight.set(value);
- weights.put(feature, weight);
- }
+ public void setWeight(@Nonnull Object feature, float value) {
+ throw new UnsupportedOperationException();
}
@Override
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/main/java/hivemall/model/SynchronizedModelWrapper.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/model/SynchronizedModelWrapper.java b/core/src/main/java/hivemall/model/SynchronizedModelWrapper.java
index 87e89b6..dcb0bc9 100644
--- a/core/src/main/java/hivemall/model/SynchronizedModelWrapper.java
+++ b/core/src/main/java/hivemall/model/SynchronizedModelWrapper.java
@@ -63,6 +63,12 @@ public final class SynchronizedModelWrapper implements PredictionModel {
}
@Override
+ public void configureParams(boolean sum_of_squared_gradients, boolean sum_of_squared_delta_x,
+ boolean sum_of_gradients) {
+ model.configureParams(sum_of_squared_gradients, sum_of_squared_delta_x, sum_of_gradients);
+ }
+
+ @Override
public void configureClock() {
model.configureClock();
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/main/java/hivemall/regression/AROWRegressionUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/regression/AROWRegressionUDTF.java b/core/src/main/java/hivemall/regression/AROWRegressionUDTF.java
index 0c964c8..0503145 100644
--- a/core/src/main/java/hivemall/regression/AROWRegressionUDTF.java
+++ b/core/src/main/java/hivemall/regression/AROWRegressionUDTF.java
@@ -18,12 +18,12 @@
*/
package hivemall.regression;
-import hivemall.optimizer.LossFunctions;
import hivemall.common.OnlineVariance;
import hivemall.model.FeatureValue;
import hivemall.model.IWeightValue;
import hivemall.model.PredictionResult;
import hivemall.model.WeightValue.WeightValueWithCovar;
+import hivemall.optimizer.LossFunctions;
import javax.annotation.Nonnull;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/main/java/hivemall/regression/AdaDeltaUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/regression/AdaDeltaUDTF.java b/core/src/main/java/hivemall/regression/AdaDeltaUDTF.java
index 50dc9b5..93453c1 100644
--- a/core/src/main/java/hivemall/regression/AdaDeltaUDTF.java
+++ b/core/src/main/java/hivemall/regression/AdaDeltaUDTF.java
@@ -18,14 +18,126 @@
*/
package hivemall.regression;
+import hivemall.model.FeatureValue;
+import hivemall.model.IWeightValue;
+import hivemall.model.WeightValue.WeightValueParamsF2;
+import hivemall.optimizer.LossFunctions;
+import hivemall.utils.lang.Primitives;
+
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.Options;
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
+
/**
* ADADELTA: AN ADAPTIVE LEARNING RATE METHOD.
+ *
+ * @deprecated Use {@link hivemall.regression.GeneralRegressionUDTF} instead
*/
@Deprecated
-public final class AdaDeltaUDTF extends GeneralRegressionUDTF {
+@Description(
+ name = "train_adadelta_regr",
+ value = "_FUNC_(array<int|bigint|string> features, float target [, constant string options])"
+ + " - Returns a relation consists of <{int|bigint|string} feature, float weight>")
+public final class AdaDeltaUDTF extends RegressionBaseUDTF {
+
+ private float decay;
+ private float eps;
+ private float scaling;
+
+ @Override
+ public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
+ final int numArgs = argOIs.length;
+ if (numArgs != 2 && numArgs != 3) {
+ throw new UDFArgumentException(
+ "AdaDeltaUDTF takes 2 or 3 arguments: List<Text|Int|BitInt> features, float target [, constant string options]");
+ }
+
+ StructObjectInspector oi = super.initialize(argOIs);
+ model.configureParams(true, true, false);
+ return oi;
+ }
+
+ @Override
+ protected Options getOptions() {
+ Options opts = super.getOptions();
+ opts.addOption("rho", "decay", true, "Decay rate [default 0.95]");
+ opts.addOption("eps", true, "A constant used in the denominator of AdaGrad [default 1e-6]");
+ opts.addOption("scale", true,
+ "Internal scaling/descaling factor for cumulative weights [100]");
+ return opts;
+ }
+
+ @Override
+ protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
+ CommandLine cl = super.processOptions(argOIs);
+ if (cl == null) {
+ this.decay = 0.95f;
+ this.eps = 1e-6f;
+ this.scaling = 100f;
+ } else {
+ this.decay = Primitives.parseFloat(cl.getOptionValue("decay"), 0.95f);
+ this.eps = Primitives.parseFloat(cl.getOptionValue("eps"), 1E-6f);
+ this.scaling = Primitives.parseFloat(cl.getOptionValue("scale"), 100f);
+ }
+ return cl;
+ }
+
+ @Override
+ protected final void checkTargetValue(final float target) throws UDFArgumentException {
+ if (target < 0.f || target > 1.f) {
+ throw new UDFArgumentException("target must be in range 0 to 1: " + target);
+ }
+ }
+
+ @Override
+ protected void update(@Nonnull final FeatureValue[] features, float target, float predicted) {
+ float gradient = LossFunctions.logisticLoss(target, predicted);
+ onlineUpdate(features, gradient);
+ }
+
+ @Override
+ protected void onlineUpdate(@Nonnull final FeatureValue[] features, float gradient) {
+ final float g_g = gradient * (gradient / scaling);
+
+ for (FeatureValue f : features) {// w[i] += y * x[i]
+ if (f == null) {
+ continue;
+ }
+ Object x = f.getFeature();
+ float xi = f.getValueAsFloat();
+
+ IWeightValue old_w = model.get(x);
+ IWeightValue new_w = getNewWeight(old_w, xi, gradient, g_g);
+ model.set(x, new_w);
+ }
+ }
+
+ @Nonnull
+ protected IWeightValue getNewWeight(@Nullable final IWeightValue old, final float xi,
+ final float gradient, final float g_g) {
+ float old_w = 0.f;
+ float old_scaled_sum_sqgrad = 0.f;
+ float old_sum_squared_delta_x = 0.f;
+ if (old != null) {
+ old_w = old.get();
+ old_scaled_sum_sqgrad = old.getSumOfSquaredGradients();
+ old_sum_squared_delta_x = old.getSumOfSquaredDeltaX();
+ }
- public AdaDeltaUDTF() {
- optimizerOptions.put("optimizer", "AdaDelta");
+ float new_scaled_sum_sq_grad = (decay * old_scaled_sum_sqgrad) + ((1.f - decay) * g_g);
+ float dx = (float) Math.sqrt((old_sum_squared_delta_x + eps)
+ / (old_scaled_sum_sqgrad * scaling + eps))
+ * gradient;
+ float new_sum_squared_delta_x = (decay * old_sum_squared_delta_x)
+ + ((1.f - decay) * dx * dx);
+ float new_w = old_w + (dx * xi);
+ return new WeightValueParamsF2(new_w, new_scaled_sum_sq_grad, new_sum_squared_delta_x);
}
}
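For reference, getNewWeight above implements the following AdaDelta-style per-dimension update, written here with the internal scaling factor folded out (rho is the -decay option, epsilon the -eps option, g_t the logistic-loss gradient from update(), and x_i the feature value); as in the code, the step size is taken from the accumulators of the previous step:

    E[g^2]_t = \rho \, E[g^2]_{t-1} + (1 - \rho) \, g_t^2
    \Delta x_t = \sqrt{ \frac{E[\Delta x^2]_{t-1} + \epsilon}{E[g^2]_{t-1} + \epsilon} } \; g_t
    E[\Delta x^2]_t = \rho \, E[\Delta x^2]_{t-1} + (1 - \rho) \, (\Delta x_t)^2
    w_i \leftarrow w_i + \Delta x_t \, x_i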
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/main/java/hivemall/regression/AdaGradUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/regression/AdaGradUDTF.java b/core/src/main/java/hivemall/regression/AdaGradUDTF.java
index 4b5f019..87188fc 100644
--- a/core/src/main/java/hivemall/regression/AdaGradUDTF.java
+++ b/core/src/main/java/hivemall/regression/AdaGradUDTF.java
@@ -18,14 +18,127 @@
*/
package hivemall.regression;
+import hivemall.model.FeatureValue;
+import hivemall.model.IWeightValue;
+import hivemall.model.WeightValue.WeightValueParamsF1;
+import hivemall.optimizer.LossFunctions;
+import hivemall.utils.lang.Primitives;
+
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.Options;
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
+
/**
* ADAGRAD algorithm with element-wise adaptive learning rates.
+ *
+ * @deprecated Use {@link hivemall.regression.GeneralRegressionUDTF} instead
*/
@Deprecated
-public final class AdaGradUDTF extends GeneralRegressionUDTF {
+@Description(
+ name = "train_adagrad_regr",
+ value = "_FUNC_(array<int|bigint|string> features, float target [, constant string options])"
+ + " - Returns a relation consists of <{int|bigint|string} feature, float weight>")
+public final class AdaGradUDTF extends RegressionBaseUDTF {
+
+ private float eta;
+ private float eps;
+ private float scaling;
+
+ @Override
+ public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
+ final int numArgs = argOIs.length;
+ if (numArgs != 2 && numArgs != 3) {
+ throw new UDFArgumentException(
+ "_FUNC_ takes 2 or 3 arguments: List<Text|Int|BitInt> features, float target [, constant string options]");
+ }
+
+ StructObjectInspector oi = super.initialize(argOIs);
+ model.configureParams(true, false, false);
+ return oi;
+ }
+
+ @Override
+ protected Options getOptions() {
+ Options opts = super.getOptions();
+ opts.addOption("eta", "eta0", true, "The initial learning rate [default 1.0]");
+ opts.addOption("eps", true, "A constant used in the denominator of AdaGrad [default 1.0]");
+ opts.addOption("scale", true,
+ "Internal scaling/descaling factor for cumulative weights [100]");
+ return opts;
+ }
+
+ @Override
+ protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
+ CommandLine cl = super.processOptions(argOIs);
+ if (cl == null) {
+ this.eta = 1.f;
+ this.eps = 1.f;
+ this.scaling = 100f;
+ } else {
+ this.eta = Primitives.parseFloat(cl.getOptionValue("eta"), 1.f);
+ this.eps = Primitives.parseFloat(cl.getOptionValue("eps"), 1.f);
+ this.scaling = Primitives.parseFloat(cl.getOptionValue("scale"), 100f);
+ }
+ return cl;
+ }
+
+ @Override
+ protected final void checkTargetValue(final float target) throws UDFArgumentException {
+ if (target < 0.f || target > 1.f) {
+ throw new UDFArgumentException("target must be in range 0 to 1: " + target);
+ }
+ }
+
+ @Override
+ protected void update(@Nonnull final FeatureValue[] features, float target, float predicted) {
+ float gradient = LossFunctions.logisticLoss(target, predicted);
+ onlineUpdate(features, gradient);
+ }
+
+ @Override
+ protected void onlineUpdate(@Nonnull final FeatureValue[] features, float gradient) {
+ final float g_g = gradient * (gradient / scaling);
+
+ for (FeatureValue f : features) {// w[i] += y * x[i]
+ if (f == null) {
+ continue;
+ }
+ Object x = f.getFeature();
+ float xi = f.getValueAsFloat();
+
+ IWeightValue old_w = model.get(x);
+ IWeightValue new_w = getNewWeight(old_w, xi, gradient, g_g);
+ model.set(x, new_w);
+ }
+ }
+
+ @Nonnull
+ protected IWeightValue getNewWeight(@Nullable final IWeightValue old, final float xi,
+ final float gradient, final float g_g) {
+ float old_w = 0.f;
+ float scaled_sum_sqgrad = 0.f;
+
+ if (old != null) {
+ old_w = old.get();
+ scaled_sum_sqgrad = old.getSumOfSquaredGradients();
+ }
+ scaled_sum_sqgrad += g_g;
+
+ float coeff = eta(scaled_sum_sqgrad) * gradient;
+ float new_w = old_w + (coeff * xi);
+ return new WeightValueParamsF1(new_w, scaled_sum_sqgrad);
+ }
- public AdaGradUDTF() {
- optimizerOptions.put("optimizer", "AdaGrad");
+ protected float eta(final double scaledSumOfSquaredGradients) {
+ double sumOfSquaredGradients = scaledSumOfSquaredGradients * scaling;
+ //return eta / (float) Math.sqrt(sumOfSquaredGradients);
+ return eta / (float) Math.sqrt(eps + sumOfSquaredGradients); // always less than eta0
}
}
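Similarly, getNewWeight and eta above amount to the standard AdaGrad per-dimension update, with the internal scaling factor folded out (eta_0 is -eta0, epsilon is -eps, g_t the logistic-loss gradient, x_i the feature value):

    G_i \leftarrow G_i + g_t^2
    \eta_t = \frac{\eta_0}{\sqrt{\epsilon + G_i}}
    w_i \leftarrow w_i + \eta_t \, g_t \, x_i

With the default eps of 1.0 the denominator is at least 1, so the effective rate never exceeds eta0, as the inline comment notes.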
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/main/java/hivemall/regression/GeneralRegressionUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/regression/GeneralRegressionUDTF.java b/core/src/main/java/hivemall/regression/GeneralRegressionUDTF.java
index 2a8b543..21a784e 100644
--- a/core/src/main/java/hivemall/regression/GeneralRegressionUDTF.java
+++ b/core/src/main/java/hivemall/regression/GeneralRegressionUDTF.java
@@ -40,6 +40,7 @@ public class GeneralRegressionUDTF extends RegressionBaseUDTF {
protected final Map<String, String> optimizerOptions;
public GeneralRegressionUDTF() {
+ super(true); // This enables new model interfaces
this.optimizerOptions = new HashMap<String, String>();
// Set default values
optimizerOptions.put("optimizer", "adadelta");
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/main/java/hivemall/regression/LogressUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/regression/LogressUDTF.java b/core/src/main/java/hivemall/regression/LogressUDTF.java
index ea05da3..78e617d 100644
--- a/core/src/main/java/hivemall/regression/LogressUDTF.java
+++ b/core/src/main/java/hivemall/regression/LogressUDTF.java
@@ -18,12 +18,69 @@
*/
package hivemall.regression;
+import hivemall.optimizer.EtaEstimator;
+import hivemall.optimizer.LossFunctions;
+
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.Options;
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
+
+/**
+ * @deprecated Use {@link hivemall.regression.GeneralRegressionUDTF} instead
+ */
@Deprecated
-public final class LogressUDTF extends GeneralRegressionUDTF {
+@Description(
+ name = "logress",
+ value = "_FUNC_(array<int|bigint|string> features, float target [, constant string options])"
+ + " - Returns a relation consists of <{int|bigint|string} feature, float weight>")
+public final class LogressUDTF extends RegressionBaseUDTF {
+
+ private EtaEstimator etaEstimator;
+
+ @Override
+ public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
+ final int numArgs = argOIs.length;
+ if (numArgs != 2 && numArgs != 3) {
+ throw new UDFArgumentException(
+ "LogressUDTF takes 2 or 3 arguments: List<Text|Int|BitInt> features, float target [, constant string options]");
+ }
+
+ return super.initialize(argOIs);
+ }
+
+ @Override
+ protected Options getOptions() {
+ Options opts = super.getOptions();
+ opts.addOption("t", "total_steps", true, "a total of n_samples * epochs time steps");
+ opts.addOption("power_t", true,
+ "The exponent for inverse scaling learning rate [default 0.1]");
+ opts.addOption("eta0", true, "The initial learning rate [default 0.1]");
+ return opts;
+ }
+
+ @Override
+ protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
+ CommandLine cl = super.processOptions(argOIs);
+
+ this.etaEstimator = EtaEstimator.get(cl);
+ return cl;
+ }
+
+ @Override
+ protected void checkTargetValue(final float target) throws UDFArgumentException {
+ if (target < 0.f || target > 1.f) {
+ throw new UDFArgumentException("target must be in range 0 to 1: " + target);
+ }
+ }
- public LogressUDTF() {
- optimizerOptions.put("optimizer", "SGD");
- optimizerOptions.put("eta", "fixed");
+ @Override
+ protected float computeGradient(final float target, final float predicted) {
+ float eta = etaEstimator.eta(count);
+ float gradient = LossFunctions.logisticLoss(target, predicted);
+ return eta * gradient;
}
}
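LogressUDTF thus reduces to plain SGD on the logistic loss, with the learning rate supplied by EtaEstimator through the -eta0, -power_t and -total_steps options shown above. Schematically, per training example, with g = LossFunctions.logisticLoss(target, predicted) and eta(t) the estimator's rate at the current example count:

    w_i \leftarrow w_i + \eta(t) \, g \, x_i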
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/main/java/hivemall/regression/PassiveAggressiveRegressionUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/regression/PassiveAggressiveRegressionUDTF.java b/core/src/main/java/hivemall/regression/PassiveAggressiveRegressionUDTF.java
index e1afe2f..3de56fd 100644
--- a/core/src/main/java/hivemall/regression/PassiveAggressiveRegressionUDTF.java
+++ b/core/src/main/java/hivemall/regression/PassiveAggressiveRegressionUDTF.java
@@ -18,10 +18,10 @@
*/
package hivemall.regression;
-import hivemall.optimizer.LossFunctions;
import hivemall.common.OnlineVariance;
import hivemall.model.FeatureValue;
import hivemall.model.PredictionResult;
+import hivemall.optimizer.LossFunctions;
import javax.annotation.Nonnull;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/main/java/hivemall/regression/RegressionBaseUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/regression/RegressionBaseUDTF.java b/core/src/main/java/hivemall/regression/RegressionBaseUDTF.java
index 7dc8538..24b0556 100644
--- a/core/src/main/java/hivemall/regression/RegressionBaseUDTF.java
+++ b/core/src/main/java/hivemall/regression/RegressionBaseUDTF.java
@@ -72,6 +72,16 @@ public abstract class RegressionBaseUDTF extends LearnerBaseUDTF {
protected transient Map<Object, FloatAccumulator> accumulated;
protected int sampled;
+ private boolean enableNewModel;
+
+ public RegressionBaseUDTF() {
+ this.enableNewModel = false;
+ }
+
+ public RegressionBaseUDTF(boolean enableNewModel) {
+ this.enableNewModel = enableNewModel;
+ }
+
@Override
public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
if (argOIs.length < 2) {
@@ -85,7 +95,7 @@ public abstract class RegressionBaseUDTF extends LearnerBaseUDTF {
PrimitiveObjectInspector featureOutputOI = dense_model ? PrimitiveObjectInspectorFactory.javaIntObjectInspector
: featureInputOI;
- this.model = createModel();
+ this.model = enableNewModel? createNewModel(null) : createModel();
if (preloadedModelFile != null) {
loadPredictionModel(model, preloadedModelFile, featureOutputOI);
}
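The new constructor flag lets subclasses opt into the new model code path while the deprecated UDTFs above stay on the legacy one: GeneralRegressionUDTF passes true, so initialize() builds its model via createNewModel(null) instead of createModel(). A hypothetical subclass (not part of this commit) would opt in the same way:

    // Hypothetical example; this class does not exist in the patch.
    public final class MyRegressionUDTF extends RegressionBaseUDTF {
        public MyRegressionUDTF() {
            super(true); // initialize() will call createNewModel(null)
        }
        // option handling, checkTargetValue(), update() as in the UDTFs above
    }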
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/test/java/hivemall/model/NewSpaceEfficientNewDenseModelTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/model/NewSpaceEfficientNewDenseModelTest.java b/core/src/test/java/hivemall/model/NewSpaceEfficientNewDenseModelTest.java
new file mode 100644
index 0000000..dd9c4ec
--- /dev/null
+++ b/core/src/test/java/hivemall/model/NewSpaceEfficientNewDenseModelTest.java
@@ -0,0 +1,60 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2015 Makoto YUI
+ * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package hivemall.model;
+
+import static org.junit.Assert.assertEquals;
+import hivemall.utils.collections.IMapIterator;
+import hivemall.utils.lang.HalfFloat;
+
+import java.util.Random;
+
+import org.junit.Test;
+
+public class NewSpaceEfficientNewDenseModelTest {
+
+ @Test
+ public void testGetSet() {
+ final int size = 1 << 12;
+
+ final NewSpaceEfficientDenseModel model1 = new NewSpaceEfficientDenseModel(size);
+ //model1.configureClock();
+ final NewDenseModel model2 = new NewDenseModel(size);
+ //model2.configureClock();
+
+ final Random rand = new Random();
+ for (int t = 0; t < 1000; t++) {
+ int i = rand.nextInt(size);
+ float f = HalfFloat.MAX_FLOAT * rand.nextFloat();
+ IWeightValue w = new WeightValue(f);
+ model1.set(i, w);
+ model2.set(i, w);
+ }
+
+ assertEquals(model2.size(), model1.size());
+
+ IMapIterator<Integer, IWeightValue> itor = model1.entries();
+ while (itor.next() != -1) {
+ int k = itor.getKey();
+ float expected = itor.getValue().get();
+ float actual = model2.getWeight(k);
+ assertEquals(expected, actual, 32f);
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/core/src/test/java/hivemall/model/SpaceEfficientDenseModelTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/model/SpaceEfficientDenseModelTest.java b/core/src/test/java/hivemall/model/SpaceEfficientDenseModelTest.java
deleted file mode 100644
index e3a1ed4..0000000
--- a/core/src/test/java/hivemall/model/SpaceEfficientDenseModelTest.java
+++ /dev/null
@@ -1,60 +0,0 @@
-/*
- * Hivemall: Hive scalable Machine Learning Library
- *
- * Copyright (C) 2015 Makoto YUI
- * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package hivemall.model;
-
-import static org.junit.Assert.assertEquals;
-import hivemall.utils.collections.IMapIterator;
-import hivemall.utils.lang.HalfFloat;
-
-import java.util.Random;
-
-import org.junit.Test;
-
-public class SpaceEfficientDenseModelTest {
-
- @Test
- public void testGetSet() {
- final int size = 1 << 12;
-
- final SpaceEfficientDenseModel model1 = new SpaceEfficientDenseModel(size);
- //model1.configureClock();
- final DenseModel model2 = new DenseModel(size);
- //model2.configureClock();
-
- final Random rand = new Random();
- for (int t = 0; t < 1000; t++) {
- int i = rand.nextInt(size);
- float f = HalfFloat.MAX_FLOAT * rand.nextFloat();
- IWeightValue w = new WeightValue(f);
- model1.set(i, w);
- model2.set(i, w);
- }
-
- assertEquals(model2.size(), model1.size());
-
- IMapIterator<Integer, IWeightValue> itor = model1.entries();
- while (itor.next() != -1) {
- int k = itor.getKey();
- float expected = itor.getValue().get();
- float actual = model2.getWeight(k);
- assertEquals(expected, actual, 32f);
- }
- }
-
-}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/mixserv/src/test/java/hivemall/mix/server/MixServerTest.java
----------------------------------------------------------------------
diff --git a/mixserv/src/test/java/hivemall/mix/server/MixServerTest.java b/mixserv/src/test/java/hivemall/mix/server/MixServerTest.java
index 38792d8..ec6d556 100644
--- a/mixserv/src/test/java/hivemall/mix/server/MixServerTest.java
+++ b/mixserv/src/test/java/hivemall/mix/server/MixServerTest.java
@@ -18,9 +18,9 @@
*/
package hivemall.mix.server;
-import hivemall.model.DenseModel;
+import hivemall.model.NewDenseModel;
import hivemall.model.PredictionModel;
-import hivemall.model.SparseModel;
+import hivemall.model.NewSparseModel;
import hivemall.model.WeightValue;
import hivemall.mix.MixMessage.MixEventName;
import hivemall.mix.client.MixClient;
@@ -55,7 +55,7 @@ public class MixServerTest extends HivemallTestBase {
waitForState(server, ServerState.RUNNING);
- PredictionModel model = new DenseModel(16777216);
+ PredictionModel model = new NewDenseModel(16777216);
model.configureClock();
MixClient client = null;
try {
@@ -93,7 +93,7 @@ public class MixServerTest extends HivemallTestBase {
waitForState(server, ServerState.RUNNING);
- PredictionModel model = new DenseModel(16777216);
+ PredictionModel model = new NewDenseModel(16777216);
model.configureClock();
MixClient client = null;
try {
@@ -151,7 +151,7 @@ public class MixServerTest extends HivemallTestBase {
}
private static void invokeClient(String groupId, int serverPort) throws InterruptedException {
- PredictionModel model = new DenseModel(16777216);
+ PredictionModel model = new NewDenseModel(16777216);
model.configureClock();
MixClient client = null;
try {
@@ -298,8 +298,8 @@ public class MixServerTest extends HivemallTestBase {
private static void invokeClient01(String groupId, int serverPort, boolean denseModel, boolean cancelMix)
throws InterruptedException {
- PredictionModel model = denseModel ? new DenseModel(100)
- : new SparseModel(100, false);
+ PredictionModel model = denseModel ? new NewDenseModel(100)
+ : new NewSparseModel(100, false);
model.configureClock();
MixClient client = null;
try {
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3620eb89/spark/spark-2.0/src/test/scala/hivemall/mix/server/MixServerSuite.scala
----------------------------------------------------------------------
diff --git a/spark/spark-2.0/src/test/scala/hivemall/mix/server/MixServerSuite.scala b/spark/spark-2.0/src/test/scala/hivemall/mix/server/MixServerSuite.scala
index 4fb74f1..c0ee72f 100644
--- a/spark/spark-2.0/src/test/scala/hivemall/mix/server/MixServerSuite.scala
+++ b/spark/spark-2.0/src/test/scala/hivemall/mix/server/MixServerSuite.scala
@@ -23,7 +23,7 @@ import java.util.logging.Logger
import org.scalatest.{BeforeAndAfter, FunSuite}
-import hivemall.model.{DenseModel, PredictionModel, WeightValue}
+import hivemall.model.{NewDenseModel, PredictionModel, WeightValue}
import hivemall.mix.MixMessage.MixEventName
import hivemall.mix.client.MixClient
import hivemall.mix.server.MixServer.ServerState
@@ -95,7 +95,7 @@ class MixServerSuite extends FunSuite with BeforeAndAfter {
ignore(testName) {
val clients = Executors.newCachedThreadPool()
val numClients = nclient
- val models = (0 until numClients).map(i => new DenseModel(ndims, false))
+ val models = (0 until numClients).map(i => new NewDenseModel(ndims, false))
(0 until numClients).map { i =>
clients.submit(new Runnable() {
override def run(): Unit = {
[12/50] [abbrv] incubator-hivemall git commit: Add optimizer implementations
Posted by my...@apache.org.
Add optimizer implementations
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/f81948c5
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/f81948c5
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/f81948c5
Branch: refs/heads/JIRA-22/pr-285
Commit: f81948c5c7b83155eb29369a59f1fc65bb607f91
Parents: 5a7df55
Author: Takeshi YAMAMURO <li...@gmail.com>
Authored: Mon May 2 23:43:42 2016 +0900
Committer: Takeshi YAMAMURO <li...@gmail.com>
Committed: Wed Sep 21 00:07:28 2016 +0900
----------------------------------------------------------------------
.../src/main/java/hivemall/LearnerBaseUDTF.java | 22 +
.../hivemall/classifier/AROWClassifierUDTF.java | 2 +-
.../hivemall/classifier/AdaGradRDAUDTF.java | 123 +----
.../classifier/BinaryOnlineClassifierUDTF.java | 3 +
.../classifier/GeneralClassifierUDTF.java | 121 +++++
.../classifier/PassiveAggressiveUDTF.java | 2 +-
.../main/java/hivemall/common/EtaEstimator.java | 160 -------
.../java/hivemall/common/LossFunctions.java | 467 -------------------
.../java/hivemall/fm/FMHyperParameters.java | 2 +-
.../hivemall/fm/FactorizationMachineModel.java | 2 +-
.../hivemall/fm/FactorizationMachineUDTF.java | 8 +-
.../fm/FieldAwareFactorizationMachineModel.java | 1 +
.../hivemall/mf/BPRMatrixFactorizationUDTF.java | 2 +-
.../hivemall/mf/MatrixFactorizationSGDUDTF.java | 2 +-
.../main/java/hivemall/model/DenseModel.java | 87 +---
.../main/java/hivemall/model/IWeightValue.java | 16 +-
.../java/hivemall/model/PredictionModel.java | 5 +-
.../model/SpaceEfficientDenseModel.java | 93 +---
.../main/java/hivemall/model/SparseModel.java | 20 +-
.../model/SynchronizedModelWrapper.java | 16 +-
.../main/java/hivemall/model/WeightValue.java | 162 ++++++-
.../hivemall/model/WeightValueWithClock.java | 167 ++++++-
.../optimizer/DenseOptimizerFactory.java | 215 +++++++++
.../java/hivemall/optimizer/EtaEstimator.java | 191 ++++++++
.../java/hivemall/optimizer/LossFunctions.java | 467 +++++++++++++++++++
.../main/java/hivemall/optimizer/Optimizer.java | 246 ++++++++++
.../java/hivemall/optimizer/Regularization.java | 99 ++++
.../optimizer/SparseOptimizerFactory.java | 171 +++++++
.../hivemall/regression/AROWRegressionUDTF.java | 2 +-
.../java/hivemall/regression/AdaDeltaUDTF.java | 117 +----
.../java/hivemall/regression/AdaGradUDTF.java | 118 +----
.../regression/GeneralRegressionUDTF.java | 125 +++++
.../java/hivemall/regression/LogressUDTF.java | 63 +--
.../PassiveAggressiveRegressionUDTF.java | 2 +-
.../hivemall/regression/RegressionBaseUDTF.java | 14 +-
.../java/hivemall/optimizer/OptimizerTest.java | 172 +++++++
.../java/hivemall/mix/server/MixServerTest.java | 14 +-
resources/ddl/define-all-as-permanent.hive | 13 +-
resources/ddl/define-all.hive | 12 +-
39 files changed, 2301 insertions(+), 1223 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/LearnerBaseUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/LearnerBaseUDTF.java b/core/src/main/java/hivemall/LearnerBaseUDTF.java
index 4518cce..7fd5190 100644
--- a/core/src/main/java/hivemall/LearnerBaseUDTF.java
+++ b/core/src/main/java/hivemall/LearnerBaseUDTF.java
@@ -28,6 +28,9 @@ import hivemall.model.SparseModel;
import hivemall.model.SynchronizedModelWrapper;
import hivemall.model.WeightValue;
import hivemall.model.WeightValue.WeightValueWithCovar;
+import hivemall.optimizer.DenseOptimizerFactory;
+import hivemall.optimizer.Optimizer;
+import hivemall.optimizer.SparseOptimizerFactory;
import hivemall.utils.datetime.StopWatch;
import hivemall.utils.hadoop.HadoopUtils;
import hivemall.utils.hadoop.HiveUtils;
@@ -38,6 +41,7 @@ import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.util.List;
+import java.util.Map;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
@@ -195,6 +199,24 @@ public abstract class LearnerBaseUDTF extends UDTFWithOptions {
return model;
}
+ // If a model implements an optimizer, it must override this
+ protected Map<String, String> getOptimzierOptions() {
+ return null;
+ }
+
+ protected Optimizer createOptimizer() {
+ assert(!useCovariance());
+ final Map<String, String> options = getOptimzierOptions();
+ if(options != null) {
+ if (dense_model) {
+ return DenseOptimizerFactory.create(model_dims, options);
+ } else {
+ return SparseOptimizerFactory.create(model_dims, options);
+ }
+ }
+ return null;
+ }
+
protected MixClient configureMixClient(String connectURIs, String label, PredictionModel model) {
assert (connectURIs != null);
assert (model != null);
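
The new createOptimizer() returns a hivemall.optimizer.Optimizer whose body is not part of this diff; from its call sites (see GeneralClassifierUDTF below) it exposes at least computeUpdatedValue(feature, weight, gradient) and proceedStep(). A toy stand-alone class with the same shape, assuming a plain SGD rule, only to make that contract concrete:

// Toy illustration of the Optimizer call contract inferred from this commit;
// NOT hivemall.optimizer.Optimizer itself (its implementation is not shown here).
public final class SgdLikeOptimizerSketch {
    private final float eta0;
    private long step = 1L;

    public SgdLikeOptimizerSketch(float eta0) {
        this.eta0 = eta0;
    }

    // `feature` is unused here; adaptive optimizers key per-feature state on it
    public float computeUpdatedValue(Object feature, float weight, float gradient) {
        return weight - eta0 * gradient;
    }

    public void proceedStep() {
        step++; // advance the step counter after each training example
    }
}
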
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/classifier/AROWClassifierUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/classifier/AROWClassifierUDTF.java b/core/src/main/java/hivemall/classifier/AROWClassifierUDTF.java
index e5ef975..ac8afcb 100644
--- a/core/src/main/java/hivemall/classifier/AROWClassifierUDTF.java
+++ b/core/src/main/java/hivemall/classifier/AROWClassifierUDTF.java
@@ -18,7 +18,7 @@
*/
package hivemall.classifier;
-import hivemall.common.LossFunctions;
+import hivemall.optimizer.LossFunctions;
import hivemall.model.FeatureValue;
import hivemall.model.IWeightValue;
import hivemall.model.PredictionResult;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/classifier/AdaGradRDAUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/classifier/AdaGradRDAUDTF.java b/core/src/main/java/hivemall/classifier/AdaGradRDAUDTF.java
index 1351bca..a6714f4 100644
--- a/core/src/main/java/hivemall/classifier/AdaGradRDAUDTF.java
+++ b/core/src/main/java/hivemall/classifier/AdaGradRDAUDTF.java
@@ -18,124 +18,13 @@
*/
package hivemall.classifier;
-import hivemall.common.LossFunctions;
-import hivemall.model.FeatureValue;
-import hivemall.model.IWeightValue;
-import hivemall.model.WeightValue.WeightValueParamsF2;
-import hivemall.utils.lang.Primitives;
+@Deprecated
+public final class AdaGradRDAUDTF extends GeneralClassifierUDTF {
-import javax.annotation.Nonnull;
-
-import org.apache.commons.cli.CommandLine;
-import org.apache.commons.cli.Options;
-import org.apache.hadoop.hive.ql.exec.Description;
-import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
-import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
-import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
-
-@Description(name = "train_adagrad_rda",
- value = "_FUNC_(list<string|int|bigint> features, int label [, const string options])"
- + " - Returns a relation consists of <string|int|bigint feature, float weight>",
- extended = "Build a prediction model by Adagrad+RDA regularization binary classifier")
-public final class AdaGradRDAUDTF extends BinaryOnlineClassifierUDTF {
-
- private float eta;
- private float lambda;
- private float scaling;
-
- @Override
- public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
- final int numArgs = argOIs.length;
- if (numArgs != 2 && numArgs != 3) {
- throw new UDFArgumentException(
- "_FUNC_ takes 2 or 3 arguments: List<Text|Int|BitInt> features, int label [, constant string options]");
- }
-
- StructObjectInspector oi = super.initialize(argOIs);
- model.configureParams(true, false, true);
- return oi;
+ public AdaGradRDAUDTF() {
+ optimizerOptions.put("optimizer", "AdaGrad");
+ optimizerOptions.put("regularization", "RDA");
+ optimizerOptions.put("lambda", "1e-6");
}
- @Override
- protected Options getOptions() {
- Options opts = super.getOptions();
- opts.addOption("eta", "eta0", true, "The learning rate \\eta [default 0.1]");
- opts.addOption("lambda", true, "lambda constant of RDA [default: 1E-6f]");
- opts.addOption("scale", true,
- "Internal scaling/descaling factor for cumulative weights [default: 100]");
- return opts;
- }
-
- @Override
- protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
- CommandLine cl = super.processOptions(argOIs);
- if (cl == null) {
- this.eta = 0.1f;
- this.lambda = 1E-6f;
- this.scaling = 100f;
- } else {
- this.eta = Primitives.parseFloat(cl.getOptionValue("eta"), 0.1f);
- this.lambda = Primitives.parseFloat(cl.getOptionValue("lambda"), 1E-6f);
- this.scaling = Primitives.parseFloat(cl.getOptionValue("scale"), 100f);
- }
- return cl;
- }
-
- @Override
- protected void train(@Nonnull final FeatureValue[] features, final int label) {
- final float y = label > 0 ? 1.f : -1.f;
-
- float p = predict(features);
- float loss = LossFunctions.hingeLoss(p, y); // 1.0 - y * p
- if (loss <= 0.f) { // max(0, 1 - y * p)
- return;
- }
- // subgradient => -y * W dot xi
- update(features, y, count);
- }
-
- protected void update(@Nonnull final FeatureValue[] features, final float y, final int t) {
- for (FeatureValue f : features) {// w[f] += y * x[f]
- if (f == null) {
- continue;
- }
- Object x = f.getFeature();
- float xi = f.getValueAsFloat();
-
- updateWeight(x, xi, y, t);
- }
- }
-
- protected void updateWeight(@Nonnull final Object x, final float xi, final float y,
- final float t) {
- final float gradient = -y * xi;
- final float scaled_gradient = gradient * scaling;
-
- float scaled_sum_sqgrad = 0.f;
- float scaled_sum_grad = 0.f;
- IWeightValue old = model.get(x);
- if (old != null) {
- scaled_sum_sqgrad = old.getSumOfSquaredGradients();
- scaled_sum_grad = old.getSumOfGradients();
- }
- scaled_sum_grad += scaled_gradient;
- scaled_sum_sqgrad += (scaled_gradient * scaled_gradient);
-
- float sum_grad = scaled_sum_grad * scaling;
- double sum_sqgrad = scaled_sum_sqgrad * scaling;
-
- // sign(u_{t,i})
- float sign = (sum_grad > 0.f) ? 1.f : -1.f;
- // |u_{t,i}|/t - \lambda
- float meansOfGradients = sign * sum_grad / t - lambda;
- if (meansOfGradients < 0.f) {
- // x_{t,i} = 0
- model.delete(x);
- } else {
- // x_{t,i} = -sign(u_{t,i}) * \frac{\eta t}{\sqrt{G_{t,ii}}}(|u_{t,i}|/t - \lambda)
- float weight = -1.f * sign * eta * t * meansOfGradients / (float) Math.sqrt(sum_sqgrad);
- IWeightValue new_w = new WeightValueParamsF2(weight, scaled_sum_sqgrad, scaled_sum_grad);
- model.set(x, new_w);
- }
- }
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java b/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java
index 43a124d..0ee5d5f 100644
--- a/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java
+++ b/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java
@@ -25,6 +25,7 @@ import hivemall.model.PredictionModel;
import hivemall.model.PredictionResult;
import hivemall.model.WeightValue;
import hivemall.model.WeightValue.WeightValueWithCovar;
+import hivemall.optimizer.Optimizer;
import hivemall.utils.collections.IMapIterator;
import hivemall.utils.hadoop.HiveUtils;
@@ -56,6 +57,7 @@ public abstract class BinaryOnlineClassifierUDTF extends LearnerBaseUDTF {
private boolean parseFeature;
protected PredictionModel model;
+ protected Optimizer optimizerImpl;
protected int count;
@Override
@@ -76,6 +78,7 @@ public abstract class BinaryOnlineClassifierUDTF extends LearnerBaseUDTF {
if (preloadedModelFile != null) {
loadPredictionModel(model, preloadedModelFile, featureOutputOI);
}
+ this.optimizerImpl = createOptimizer();
this.count = 0;
return getReturnOI(featureOutputOI);
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/classifier/GeneralClassifierUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/classifier/GeneralClassifierUDTF.java b/core/src/main/java/hivemall/classifier/GeneralClassifierUDTF.java
new file mode 100644
index 0000000..feebadd
--- /dev/null
+++ b/core/src/main/java/hivemall/classifier/GeneralClassifierUDTF.java
@@ -0,0 +1,121 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2015 Makoto YUI
+ * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package hivemall.classifier;
+
+import java.util.HashMap;
+import java.util.Map;
+import javax.annotation.Nonnull;
+
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.Options;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
+
+import hivemall.optimizer.LossFunctions;
+import hivemall.model.FeatureValue;
+
+/**
+ * A general classifier class with replaceable optimization functions.
+ */
+public class GeneralClassifierUDTF extends BinaryOnlineClassifierUDTF {
+
+ protected final Map<String, String> optimizerOptions;
+
+ public GeneralClassifierUDTF() {
+ this.optimizerOptions = new HashMap<String, String>();
+ // Set default values
+ optimizerOptions.put("optimizer", "adagrad");
+ optimizerOptions.put("eta", "fixed");
+ optimizerOptions.put("eta0", "1.0");
+ optimizerOptions.put("regularization", "RDA");
+ optimizerOptions.put("lambda", "1e-6");
+ optimizerOptions.put("scale", "100.0");
+ optimizerOptions.put("lambda", "1.0");
+ }
+
+ @Override
+ public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
+ if(argOIs.length != 2 && argOIs.length != 3) {
+ throw new UDFArgumentException(
+ this.getClass().getSimpleName()
+ + " takes 2 or 3 arguments: List<Text|Int|BitInt> features, int label "
+ + "[, constant string options]");
+ }
+ return super.initialize(argOIs);
+ }
+
+ @Override
+ protected Options getOptions() {
+ Options opts = super.getOptions();
+ opts.addOption("optimizer", "opt", true, "Optimizer to update weights [default: adagrad+rda]");
+ opts.addOption("eta", "eta0", true, "Initial learning rate [default 1.0]");
+ opts.addOption("lambda", true, "Lambda value of RDA [default: 1e-6f]");
+ opts.addOption("scale", true, "Scaling factor for cumulative weights [100.0]");
+ opts.addOption("regularization", "reg", true, "Regularization type [default not-defined]");
+ opts.addOption("lambda", true, "Regularization term on weights [default 1.0]");
+ return opts;
+ }
+
+ @Override
+ protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
+ final CommandLine cl = super.processOptions(argOIs);
+ assert(cl != null);
+ if(cl != null) {
+ for(final String arg : cl.getArgs()) {
+ optimizerOptions.put(arg, cl.getOptionValue(arg));
+ }
+ }
+ return cl;
+ }
+
+ @Override
+ protected Map<String, String> getOptimzierOptions() {
+ return optimizerOptions;
+ }
+
+ @Override
+ protected void train(@Nonnull final FeatureValue[] features, final int label) {
+ float predicted = predict(features);
+ update(features, label > 0 ? 1.f : -1.f, predicted);
+ }
+
+ @Override
+ protected void update(@Nonnull final FeatureValue[] features, final float label,
+ final float predicted) {
+ if(is_mini_batch) {
+ throw new UnsupportedOperationException(
+ this.getClass().getSimpleName() + " does not support `is_mini_batch` mode");
+ } else {
+ float loss = LossFunctions.hingeLoss(predicted, label);
+ if(loss <= 0.f) {
+ return;
+ }
+ for(FeatureValue f : features) {
+ Object feature = f.getFeature();
+ float xi = f.getValueAsFloat();
+ float weight = model.getWeight(feature);
+ float new_weight = optimizerImpl.computeUpdatedValue(feature, weight, -label * xi);
+ model.setWeight(feature, new_weight);
+ }
+ optimizerImpl.proceedStep();
+ }
+ }
+
+}
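
The non-mini-batch branch above applies the optimizer to the hinge-loss subgradient feature by feature. A stand-alone sketch of one such step for a single feature, assuming a fixed learning rate (the Optimizer implementations added elsewhere in this commit may use adaptive rates instead):

// One hinge-loss update for a single feature, assuming plain SGD; illustrative only.
public final class HingeUpdateSketch {
    public static void main(String[] args) {
        final float eta = 0.1f;   // assumed fixed learning rate
        final float label = 1.f;  // y in {-1, +1}
        final float xi = 0.5f;    // feature value
        float weight = 0.f;       // current weight of this feature

        float predicted = weight * xi;                       // dot product restricted to one feature
        float loss = Math.max(0.f, 1.f - label * predicted); // hinge loss
        if (loss > 0.f) {
            float gradient = -label * xi;  // subgradient of the hinge loss w.r.t. the weight
            weight -= eta * gradient;      // what computeUpdatedValue would return for SGD
        }
        System.out.println(weight);
    }
}
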
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/classifier/PassiveAggressiveUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/classifier/PassiveAggressiveUDTF.java b/core/src/main/java/hivemall/classifier/PassiveAggressiveUDTF.java
index 0213dec..9e404cd 100644
--- a/core/src/main/java/hivemall/classifier/PassiveAggressiveUDTF.java
+++ b/core/src/main/java/hivemall/classifier/PassiveAggressiveUDTF.java
@@ -18,7 +18,7 @@
*/
package hivemall.classifier;
-import hivemall.common.LossFunctions;
+import hivemall.optimizer.LossFunctions;
import hivemall.model.FeatureValue;
import hivemall.model.PredictionResult;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/common/EtaEstimator.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/common/EtaEstimator.java b/core/src/main/java/hivemall/common/EtaEstimator.java
deleted file mode 100644
index 3287641..0000000
--- a/core/src/main/java/hivemall/common/EtaEstimator.java
+++ /dev/null
@@ -1,160 +0,0 @@
-/*
- * Hivemall: Hive scalable Machine Learning Library
- *
- * Copyright (C) 2015 Makoto YUI
- * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package hivemall.common;
-
-import hivemall.utils.lang.NumberUtils;
-import hivemall.utils.lang.Primitives;
-
-import javax.annotation.Nonnegative;
-import javax.annotation.Nonnull;
-import javax.annotation.Nullable;
-
-import org.apache.commons.cli.CommandLine;
-import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
-
-public abstract class EtaEstimator {
-
- protected final float eta0;
-
- public EtaEstimator(float eta0) {
- this.eta0 = eta0;
- }
-
- public float eta0() {
- return eta0;
- }
-
- public abstract float eta(long t);
-
- public void update(@Nonnegative float multipler) {}
-
- public static final class FixedEtaEstimator extends EtaEstimator {
-
- public FixedEtaEstimator(float eta) {
- super(eta);
- }
-
- @Override
- public float eta(long t) {
- return eta0;
- }
-
- }
-
- public static final class SimpleEtaEstimator extends EtaEstimator {
-
- private final float finalEta;
- private final double total_steps;
-
- public SimpleEtaEstimator(float eta0, long total_steps) {
- super(eta0);
- this.finalEta = (float) (eta0 / 2.d);
- this.total_steps = total_steps;
- }
-
- @Override
- public float eta(final long t) {
- if (t > total_steps) {
- return finalEta;
- }
- return (float) (eta0 / (1.d + (t / total_steps)));
- }
-
- }
-
- public static final class InvscalingEtaEstimator extends EtaEstimator {
-
- private final double power_t;
-
- public InvscalingEtaEstimator(float eta0, double power_t) {
- super(eta0);
- this.power_t = power_t;
- }
-
- @Override
- public float eta(final long t) {
- return (float) (eta0 / Math.pow(t, power_t));
- }
-
- }
-
- /**
- * bold driver: Gemulla et al., Large-scale matrix factorization with distributed stochastic
- * gradient descent, KDD 2011.
- */
- public static final class AdjustingEtaEstimator extends EtaEstimator {
-
- private float eta;
-
- public AdjustingEtaEstimator(float eta) {
- super(eta);
- this.eta = eta;
- }
-
- @Override
- public float eta(long t) {
- return eta;
- }
-
- @Override
- public void update(@Nonnegative float multipler) {
- float newEta = eta * multipler;
- if (!NumberUtils.isFinite(newEta)) {
- // avoid NaN or INFINITY
- return;
- }
- this.eta = Math.min(eta0, newEta); // never be larger than eta0
- }
-
- }
-
- @Nonnull
- public static EtaEstimator get(@Nullable CommandLine cl) throws UDFArgumentException {
- return get(cl, 0.1f);
- }
-
- @Nonnull
- public static EtaEstimator get(@Nullable CommandLine cl, float defaultEta0)
- throws UDFArgumentException {
- if (cl == null) {
- return new InvscalingEtaEstimator(defaultEta0, 0.1d);
- }
-
- if (cl.hasOption("boldDriver")) {
- float eta = Primitives.parseFloat(cl.getOptionValue("eta"), 0.3f);
- return new AdjustingEtaEstimator(eta);
- }
-
- String etaValue = cl.getOptionValue("eta");
- if (etaValue != null) {
- float eta = Float.parseFloat(etaValue);
- return new FixedEtaEstimator(eta);
- }
-
- float eta0 = Primitives.parseFloat(cl.getOptionValue("eta0"), defaultEta0);
- if (cl.hasOption("t")) {
- long t = Long.parseLong(cl.getOptionValue("t"));
- return new SimpleEtaEstimator(eta0, t);
- }
-
- double power_t = Primitives.parseDouble(cl.getOptionValue("power_t"), 0.1d);
- return new InvscalingEtaEstimator(eta0, power_t);
- }
-
-}
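
The three learning-rate schedules removed here (and recreated under hivemall.optimizer, per the diffstat) boil down to simple arithmetic; a stand-alone restatement with illustrative values:

// eta schedules from the removed EtaEstimator, as plain arithmetic; values are examples.
public final class EtaScheduleSketch {
    public static void main(String[] args) {
        final float eta0 = 0.1f;
        final long t = 1000L;

        // fixed: eta(t) = eta0
        float fixedEta = eta0;

        // simple: eta(t) = eta0 / (1 + t / total_steps), falling to eta0/2 once t exceeds total_steps
        final double totalSteps = 10000d;
        float simpleEta = (t > totalSteps) ? (float) (eta0 / 2.d)
                : (float) (eta0 / (1.d + (t / totalSteps)));

        // invscaling: eta(t) = eta0 / t^power_t
        final double powerT = 0.1d;
        float invscalingEta = (float) (eta0 / Math.pow(t, powerT));

        System.out.println(fixedEta + " " + simpleEta + " " + invscalingEta);
    }
}
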
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/common/LossFunctions.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/common/LossFunctions.java b/core/src/main/java/hivemall/common/LossFunctions.java
deleted file mode 100644
index 6b403fd..0000000
--- a/core/src/main/java/hivemall/common/LossFunctions.java
+++ /dev/null
@@ -1,467 +0,0 @@
-/*
- * Hivemall: Hive scalable Machine Learning Library
- *
- * Copyright (C) 2015 Makoto YUI
- * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package hivemall.common;
-
-import hivemall.utils.math.MathUtils;
-
-/**
- * @link https://github.com/JohnLangford/vowpal_wabbit/wiki/Loss-functions
- */
-public final class LossFunctions {
-
- public enum LossType {
- SquaredLoss, LogLoss, HingeLoss, SquaredHingeLoss, QuantileLoss, EpsilonInsensitiveLoss
- }
-
- public static LossFunction getLossFunction(String type) {
- if ("SquaredLoss".equalsIgnoreCase(type)) {
- return new SquaredLoss();
- } else if ("LogLoss".equalsIgnoreCase(type)) {
- return new LogLoss();
- } else if ("HingeLoss".equalsIgnoreCase(type)) {
- return new HingeLoss();
- } else if ("SquaredHingeLoss".equalsIgnoreCase(type)) {
- return new SquaredHingeLoss();
- } else if ("QuantileLoss".equalsIgnoreCase(type)) {
- return new QuantileLoss();
- } else if ("EpsilonInsensitiveLoss".equalsIgnoreCase(type)) {
- return new EpsilonInsensitiveLoss();
- }
- throw new IllegalArgumentException("Unsupported type: " + type);
- }
-
- public static LossFunction getLossFunction(LossType type) {
- switch (type) {
- case SquaredLoss:
- return new SquaredLoss();
- case LogLoss:
- return new LogLoss();
- case HingeLoss:
- return new HingeLoss();
- case SquaredHingeLoss:
- return new SquaredHingeLoss();
- case QuantileLoss:
- return new QuantileLoss();
- case EpsilonInsensitiveLoss:
- return new EpsilonInsensitiveLoss();
- default:
- throw new IllegalArgumentException("Unsupported type: " + type);
- }
- }
-
- public interface LossFunction {
-
- /**
- * Evaluate the loss function.
- *
- * @param p The prediction, p = w^T x
- * @param y The true value (aka target)
- * @return The loss evaluated at `p` and `y`.
- */
- public float loss(float p, float y);
-
- public double loss(double p, double y);
-
- /**
- * Evaluate the derivative of the loss function with respect to the prediction `p`.
- *
- * @param p The prediction, p = w^T x
- * @param y The true value (aka target)
- * @return The derivative of the loss function w.r.t. `p`.
- */
- public float dloss(float p, float y);
-
- public boolean forBinaryClassification();
-
- public boolean forRegression();
-
- }
-
- public static abstract class BinaryLoss implements LossFunction {
-
- protected static void checkTarget(float y) {
- if (!(y == 1.f || y == -1.f)) {
- throw new IllegalArgumentException("target must be [+1,-1]: " + y);
- }
- }
-
- protected static void checkTarget(double y) {
- if (!(y == 1.d || y == -1.d)) {
- throw new IllegalArgumentException("target must be [+1,-1]: " + y);
- }
- }
-
- @Override
- public boolean forBinaryClassification() {
- return true;
- }
-
- @Override
- public boolean forRegression() {
- return false;
- }
- }
-
- public static abstract class RegressionLoss implements LossFunction {
-
- @Override
- public boolean forBinaryClassification() {
- return false;
- }
-
- @Override
- public boolean forRegression() {
- return true;
- }
-
- }
-
- /**
- * Squared loss for regression problems.
- *
- * If you're trying to minimize the mean error, use squared-loss.
- */
- public static final class SquaredLoss extends RegressionLoss {
-
- @Override
- public float loss(float p, float y) {
- final float z = p - y;
- return z * z * 0.5f;
- }
-
- @Override
- public double loss(double p, double y) {
- final double z = p - y;
- return z * z * 0.5d;
- }
-
- @Override
- public float dloss(float p, float y) {
- return p - y; // 2 (p - y) / 2
- }
- }
-
- /**
- * Logistic regression loss for binary classification with y in {-1, 1}.
- */
- public static final class LogLoss extends BinaryLoss {
-
- /**
- * <code>logloss(p,y) = log(1+exp(-p*y))</code>
- */
- @Override
- public float loss(float p, float y) {
- checkTarget(y);
-
- final float z = y * p;
- if (z > 18.f) {
- return (float) Math.exp(-z);
- }
- if (z < -18.f) {
- return -z;
- }
- return (float) Math.log(1.d + Math.exp(-z));
- }
-
- @Override
- public double loss(double p, double y) {
- checkTarget(y);
-
- final double z = y * p;
- if (z > 18.d) {
- return Math.exp(-z);
- }
- if (z < -18.d) {
- return -z;
- }
- return Math.log(1.d + Math.exp(-z));
- }
-
- @Override
- public float dloss(float p, float y) {
- checkTarget(y);
-
- float z = y * p;
- if (z > 18.f) {
- return (float) Math.exp(-z) * -y;
- }
- if (z < -18.f) {
- return -y;
- }
- return -y / ((float) Math.exp(z) + 1.f);
- }
- }
-
- /**
- * Hinge loss for binary classification tasks with y in {-1,1}.
- */
- public static final class HingeLoss extends BinaryLoss {
-
- private float threshold;
-
- public HingeLoss() {
- this(1.f);
- }
-
- /**
- * @param threshold Margin threshold. When threshold=1.0, one gets the loss used by SVM.
- * When threshold=0.0, one gets the loss used by the Perceptron.
- */
- public HingeLoss(float threshold) {
- this.threshold = threshold;
- }
-
- public void setThreshold(float threshold) {
- this.threshold = threshold;
- }
-
- @Override
- public float loss(float p, float y) {
- float loss = hingeLoss(p, y, threshold);
- return (loss > 0.f) ? loss : 0.f;
- }
-
- @Override
- public double loss(double p, double y) {
- double loss = hingeLoss(p, y, threshold);
- return (loss > 0.d) ? loss : 0.d;
- }
-
- @Override
- public float dloss(float p, float y) {
- float loss = hingeLoss(p, y, threshold);
- return (loss > 0.f) ? -y : 0.f;
- }
- }
-
- /**
- * Squared Hinge loss for binary classification tasks with y in {-1,1}.
- */
- public static final class SquaredHingeLoss extends BinaryLoss {
-
- @Override
- public float loss(float p, float y) {
- return squaredHingeLoss(p, y);
- }
-
- @Override
- public double loss(double p, double y) {
- return squaredHingeLoss(p, y);
- }
-
- @Override
- public float dloss(float p, float y) {
- checkTarget(y);
-
- float d = 1 - (y * p);
- return (d > 0.f) ? -2.f * d * y : 0.f;
- }
-
- }
-
- /**
- * Quantile loss is useful to predict rank/order and you do not mind the mean error to increase
- * as long as you get the relative order correct.
- *
- * @link http://en.wikipedia.org/wiki/Quantile_regression
- */
- public static final class QuantileLoss extends RegressionLoss {
-
- private float tau;
-
- public QuantileLoss() {
- this.tau = 0.5f;
- }
-
- public QuantileLoss(float tau) {
- setTau(tau);
- }
-
- public void setTau(float tau) {
- if (tau <= 0 || tau >= 1.0) {
- throw new IllegalArgumentException("tau must be in range (0, 1): " + tau);
- }
- this.tau = tau;
- }
-
- @Override
- public float loss(float p, float y) {
- float e = y - p;
- if (e > 0.f) {
- return tau * e;
- } else {
- return -(1.f - tau) * e;
- }
- }
-
- @Override
- public double loss(double p, double y) {
- double e = y - p;
- if (e > 0.d) {
- return tau * e;
- } else {
- return -(1.d - tau) * e;
- }
- }
-
- @Override
- public float dloss(float p, float y) {
- float e = y - p;
- if (e == 0.f) {
- return 0.f;
- }
- return (e > 0.f) ? -tau : (1.f - tau);
- }
-
- }
-
- /**
- * Epsilon-Insensitive loss used by Support Vector Regression (SVR).
- * <code>loss = max(0, |y - p| - epsilon)</code>
- */
- public static final class EpsilonInsensitiveLoss extends RegressionLoss {
-
- private float epsilon;
-
- public EpsilonInsensitiveLoss() {
- this(0.1f);
- }
-
- public EpsilonInsensitiveLoss(float epsilon) {
- this.epsilon = epsilon;
- }
-
- public void setEpsilon(float epsilon) {
- this.epsilon = epsilon;
- }
-
- @Override
- public float loss(float p, float y) {
- float loss = Math.abs(y - p) - epsilon;
- return (loss > 0.f) ? loss : 0.f;
- }
-
- @Override
- public double loss(double p, double y) {
- double loss = Math.abs(y - p) - epsilon;
- return (loss > 0.d) ? loss : 0.d;
- }
-
- @Override
- public float dloss(float p, float y) {
- if ((y - p) > epsilon) {// real value > predicted value - epsilon
- return -1.f;
- }
- if ((p - y) > epsilon) {// real value < predicted value - epsilon
- return 1.f;
- }
- return 0.f;
- }
-
- }
-
- public static float logisticLoss(final float target, final float predicted) {
- if (predicted > -100.d) {
- return target - (float) MathUtils.sigmoid(predicted);
- } else {
- return target;
- }
- }
-
- public static float logLoss(final float p, final float y) {
- BinaryLoss.checkTarget(y);
-
- final float z = y * p;
- if (z > 18.f) {
- return (float) Math.exp(-z);
- }
- if (z < -18.f) {
- return -z;
- }
- return (float) Math.log(1.d + Math.exp(-z));
- }
-
- public static double logLoss(final double p, final double y) {
- BinaryLoss.checkTarget(y);
-
- final double z = y * p;
- if (z > 18.d) {
- return Math.exp(-z);
- }
- if (z < -18.d) {
- return -z;
- }
- return Math.log(1.d + Math.exp(-z));
- }
-
- public static float squaredLoss(float p, float y) {
- final float z = p - y;
- return z * z * 0.5f;
- }
-
- public static double squaredLoss(double p, double y) {
- final double z = p - y;
- return z * z * 0.5d;
- }
-
- public static float hingeLoss(final float p, final float y, final float threshold) {
- BinaryLoss.checkTarget(y);
-
- float z = y * p;
- return threshold - z;
- }
-
- public static double hingeLoss(final double p, final double y, final double threshold) {
- BinaryLoss.checkTarget(y);
-
- double z = y * p;
- return threshold - z;
- }
-
- public static float hingeLoss(float p, float y) {
- return hingeLoss(p, y, 1.f);
- }
-
- public static double hingeLoss(double p, double y) {
- return hingeLoss(p, y, 1.d);
- }
-
- public static float squaredHingeLoss(final float p, final float y) {
- BinaryLoss.checkTarget(y);
-
- float z = y * p;
- float d = 1.f - z;
- return (d > 0.f) ? (d * d) : 0.f;
- }
-
- public static double squaredHingeLoss(final double p, final double y) {
- BinaryLoss.checkTarget(y);
-
- double z = y * p;
- double d = 1.d - z;
- return (d > 0.d) ? d * d : 0.d;
- }
-
- /**
- * Math.abs(target - predicted) - epsilon
- */
- public static float epsilonInsensitiveLoss(float predicted, float target, float epsilon) {
- return Math.abs(target - predicted) - epsilon;
- }
-}
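
The log-loss helpers deleted here (recreated under hivemall.optimizer) guard the exponent at +/-18 to avoid overflow and needless exp() calls; a stand-alone restatement of that arithmetic, with the reason for each branch spelled out (y is assumed to be in {-1, +1}):

// logloss(p, y) = log(1 + exp(-y*p)), restated with the +/-18 guards explained.
public final class LogLossSketch {
    static float logLoss(final float p, final float y) {
        final float z = y * p;
        if (z > 18.f) {
            // log(1 + exp(-z)) ~= exp(-z) when exp(-z) is tiny
            return (float) Math.exp(-z);
        }
        if (z < -18.f) {
            // log(1 + exp(-z)) ~= -z when exp(-z) dominates; avoids overflow
            return -z;
        }
        return (float) Math.log(1.d + Math.exp(-z));
    }

    public static void main(String[] args) {
        System.out.println(logLoss(0.5f, 1.f));
    }
}
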
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/fm/FMHyperParameters.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/fm/FMHyperParameters.java b/core/src/main/java/hivemall/fm/FMHyperParameters.java
index db69db3..512476d 100644
--- a/core/src/main/java/hivemall/fm/FMHyperParameters.java
+++ b/core/src/main/java/hivemall/fm/FMHyperParameters.java
@@ -17,8 +17,8 @@
*/
package hivemall.fm;
-import hivemall.common.EtaEstimator;
import hivemall.fm.FactorizationMachineModel.VInitScheme;
+import hivemall.optimizer.EtaEstimator;
import hivemall.utils.lang.Primitives;
import javax.annotation.Nonnull;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/fm/FactorizationMachineModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/fm/FactorizationMachineModel.java b/core/src/main/java/hivemall/fm/FactorizationMachineModel.java
index 396328a..4b6ece6 100644
--- a/core/src/main/java/hivemall/fm/FactorizationMachineModel.java
+++ b/core/src/main/java/hivemall/fm/FactorizationMachineModel.java
@@ -18,7 +18,7 @@
*/
package hivemall.fm;
-import hivemall.common.EtaEstimator;
+import hivemall.optimizer.EtaEstimator;
import hivemall.utils.lang.NumberUtils;
import hivemall.utils.math.MathUtils;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java b/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java
index 2388689..7739c52 100644
--- a/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java
+++ b/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java
@@ -20,10 +20,10 @@ package hivemall.fm;
import hivemall.UDTFWithOptions;
import hivemall.common.ConversionState;
-import hivemall.common.EtaEstimator;
-import hivemall.common.LossFunctions;
-import hivemall.common.LossFunctions.LossFunction;
-import hivemall.common.LossFunctions.LossType;
+import hivemall.optimizer.EtaEstimator;
+import hivemall.optimizer.LossFunctions;
+import hivemall.optimizer.LossFunctions.LossFunction;
+import hivemall.optimizer.LossFunctions.LossType;
import hivemall.fm.FMStringFeatureMapModel.Entry;
import hivemall.utils.collections.IMapIterator;
import hivemall.utils.hadoop.HiveUtils;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineModel.java b/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineModel.java
index 7e3cc50..fde7701 100644
--- a/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineModel.java
+++ b/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineModel.java
@@ -21,6 +21,7 @@ package hivemall.fm;
import hivemall.fm.FMHyperParameters.FFMHyperParameters;
import hivemall.utils.collections.DoubleArray3D;
import hivemall.utils.collections.IntArrayList;
+import hivemall.optimizer.EtaEstimator;
import hivemall.utils.lang.NumberUtils;
import java.util.Arrays;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/mf/BPRMatrixFactorizationUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/mf/BPRMatrixFactorizationUDTF.java b/core/src/main/java/hivemall/mf/BPRMatrixFactorizationUDTF.java
index d859f29..87d2654 100644
--- a/core/src/main/java/hivemall/mf/BPRMatrixFactorizationUDTF.java
+++ b/core/src/main/java/hivemall/mf/BPRMatrixFactorizationUDTF.java
@@ -20,7 +20,7 @@ package hivemall.mf;
import hivemall.UDTFWithOptions;
import hivemall.common.ConversionState;
-import hivemall.common.EtaEstimator;
+import hivemall.optimizer.EtaEstimator;
import hivemall.mf.FactorizedModel.RankInitScheme;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.io.FileUtils;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/mf/MatrixFactorizationSGDUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/mf/MatrixFactorizationSGDUDTF.java b/core/src/main/java/hivemall/mf/MatrixFactorizationSGDUDTF.java
index 317da85..ab79ce2 100644
--- a/core/src/main/java/hivemall/mf/MatrixFactorizationSGDUDTF.java
+++ b/core/src/main/java/hivemall/mf/MatrixFactorizationSGDUDTF.java
@@ -18,7 +18,7 @@
*/
package hivemall.mf;
-import hivemall.common.EtaEstimator;
+import hivemall.optimizer.EtaEstimator;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/model/DenseModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/model/DenseModel.java b/core/src/main/java/hivemall/model/DenseModel.java
index ee57574..6956875 100644
--- a/core/src/main/java/hivemall/model/DenseModel.java
+++ b/core/src/main/java/hivemall/model/DenseModel.java
@@ -18,21 +18,18 @@
*/
package hivemall.model;
-import hivemall.model.WeightValue.WeightValueParamsF1;
-import hivemall.model.WeightValue.WeightValueParamsF2;
-import hivemall.model.WeightValue.WeightValueWithCovar;
-import hivemall.utils.collections.IMapIterator;
-import hivemall.utils.hadoop.HiveUtils;
-import hivemall.utils.lang.Copyable;
-import hivemall.utils.math.MathUtils;
-
import java.util.Arrays;
-
import javax.annotation.Nonnull;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
+import hivemall.model.WeightValue.WeightValueWithCovar;
+import hivemall.utils.collections.IMapIterator;
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.lang.Copyable;
+import hivemall.utils.math.MathUtils;
+
public final class DenseModel extends AbstractPredictionModel {
private static final Log logger = LogFactory.getLog(DenseModel.class);
@@ -40,13 +37,6 @@ public final class DenseModel extends AbstractPredictionModel {
private float[] weights;
private float[] covars;
- // optional values for adagrad
- private float[] sum_of_squared_gradients;
- // optional value for adadelta
- private float[] sum_of_squared_delta_x;
- // optional value for adagrad+rda
- private float[] sum_of_gradients;
-
// optional value for MIX
private short[] clocks;
private byte[] deltaUpdates;
@@ -67,9 +57,6 @@ public final class DenseModel extends AbstractPredictionModel {
} else {
this.covars = null;
}
- this.sum_of_squared_gradients = null;
- this.sum_of_squared_delta_x = null;
- this.sum_of_gradients = null;
this.clocks = null;
this.deltaUpdates = null;
}
@@ -85,20 +72,6 @@ public final class DenseModel extends AbstractPredictionModel {
}
@Override
- public void configureParams(boolean sum_of_squared_gradients, boolean sum_of_squared_delta_x,
- boolean sum_of_gradients) {
- if (sum_of_squared_gradients) {
- this.sum_of_squared_gradients = new float[size];
- }
- if (sum_of_squared_delta_x) {
- this.sum_of_squared_delta_x = new float[size];
- }
- if (sum_of_gradients) {
- this.sum_of_gradients = new float[size];
- }
- }
-
- @Override
public void configureClock() {
if (clocks == null) {
this.clocks = new short[size];
@@ -129,16 +102,7 @@ public final class DenseModel extends AbstractPredictionModel {
this.covars = Arrays.copyOf(covars, newSize);
Arrays.fill(covars, oldSize, newSize, 1.f);
}
- if (sum_of_squared_gradients != null) {
- this.sum_of_squared_gradients = Arrays.copyOf(sum_of_squared_gradients, newSize);
- }
- if (sum_of_squared_delta_x != null) {
- this.sum_of_squared_delta_x = Arrays.copyOf(sum_of_squared_delta_x, newSize);
- }
- if (sum_of_gradients != null) {
- this.sum_of_gradients = Arrays.copyOf(sum_of_gradients, newSize);
- }
- if (clocks != null) {
+ if(clocks != null) {
this.clocks = Arrays.copyOf(clocks, newSize);
this.deltaUpdates = Arrays.copyOf(deltaUpdates, newSize);
}
@@ -152,17 +116,7 @@ public final class DenseModel extends AbstractPredictionModel {
if (i >= size) {
return null;
}
- if (sum_of_squared_gradients != null) {
- if (sum_of_squared_delta_x != null) {
- return (T) new WeightValueParamsF2(weights[i], sum_of_squared_gradients[i],
- sum_of_squared_delta_x[i]);
- } else if (sum_of_gradients != null) {
- return (T) new WeightValueParamsF2(weights[i], sum_of_squared_gradients[i],
- sum_of_gradients[i]);
- } else {
- return (T) new WeightValueParamsF1(weights[i], sum_of_squared_gradients[i]);
- }
- } else if (covars != null) {
+ if(covars != null) {
return (T) new WeightValueWithCovar(weights[i], covars[i]);
} else {
return (T) new WeightValue(weights[i]);
@@ -181,15 +135,6 @@ public final class DenseModel extends AbstractPredictionModel {
covar = value.getCovariance();
covars[i] = covar;
}
- if (sum_of_squared_gradients != null) {
- sum_of_squared_gradients[i] = value.getSumOfSquaredGradients();
- }
- if (sum_of_squared_delta_x != null) {
- sum_of_squared_delta_x[i] = value.getSumOfSquaredDeltaX();
- }
- if (sum_of_gradients != null) {
- sum_of_gradients[i] = value.getSumOfGradients();
- }
short clock = 0;
int delta = 0;
if (clocks != null && value.isTouched()) {
@@ -213,15 +158,6 @@ public final class DenseModel extends AbstractPredictionModel {
if (covars != null) {
covars[i] = 1.f;
}
- if (sum_of_squared_gradients != null) {
- sum_of_squared_gradients[i] = 0.f;
- }
- if (sum_of_squared_delta_x != null) {
- sum_of_squared_delta_x[i] = 0.f;
- }
- if (sum_of_gradients != null) {
- sum_of_gradients[i] = 0.f;
- }
// avoid clock/delta
}
@@ -235,6 +171,13 @@ public final class DenseModel extends AbstractPredictionModel {
}
@Override
+ public void setWeight(Object feature, float value) {
+ int i = HiveUtils.parseInt(feature);
+ ensureCapacity(i);
+ weights[i] = value;
+ }
+
+ @Override
public float getCovariance(Object feature) {
int i = HiveUtils.parseInt(feature);
if (i >= size) {
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/model/IWeightValue.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/model/IWeightValue.java b/core/src/main/java/hivemall/model/IWeightValue.java
index 988e4a1..259628f 100644
--- a/core/src/main/java/hivemall/model/IWeightValue.java
+++ b/core/src/main/java/hivemall/model/IWeightValue.java
@@ -25,7 +25,7 @@ import javax.annotation.Nonnegative;
public interface IWeightValue extends Copyable<IWeightValue> {
public enum WeightValueType {
- NoParams, ParamsF1, ParamsF2, ParamsCovar;
+ NoParams, ParamsF1, ParamsF2, ParamsF3, ParamsCovar;
}
WeightValueType getType();
@@ -44,10 +44,24 @@ public interface IWeightValue extends Copyable<IWeightValue> {
float getSumOfSquaredGradients();
+ void setSumOfSquaredGradients(float value);
+
float getSumOfSquaredDeltaX();
+ void setSumOfSquaredDeltaX(float value);
+
float getSumOfGradients();
+ void setSumOfGradients(float value);
+
+ float getM();
+
+ void setM(float value);
+
+ float getV();
+
+ void setV(float value);
+
/**
* @return whether touched in training or not
*/
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/model/PredictionModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/model/PredictionModel.java b/core/src/main/java/hivemall/model/PredictionModel.java
index a8efee0..8d8dd2b 100644
--- a/core/src/main/java/hivemall/model/PredictionModel.java
+++ b/core/src/main/java/hivemall/model/PredictionModel.java
@@ -34,9 +34,6 @@ public interface PredictionModel extends MixedModel {
boolean hasCovariance();
- void configureParams(boolean sum_of_squared_gradients, boolean sum_of_squared_delta_x,
- boolean sum_of_gradients);
-
void configureClock();
boolean hasClock();
@@ -56,6 +53,8 @@ public interface PredictionModel extends MixedModel {
float getWeight(@Nonnull Object feature);
+ void setWeight(@Nonnull Object feature, float value);
+
float getCovariance(@Nonnull Object feature);
<K, V extends IWeightValue> IMapIterator<K, V> entries();
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/model/SpaceEfficientDenseModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/model/SpaceEfficientDenseModel.java b/core/src/main/java/hivemall/model/SpaceEfficientDenseModel.java
index b3cd3ff..8b668e7 100644
--- a/core/src/main/java/hivemall/model/SpaceEfficientDenseModel.java
+++ b/core/src/main/java/hivemall/model/SpaceEfficientDenseModel.java
@@ -18,8 +18,6 @@
*/
package hivemall.model;
-import hivemall.model.WeightValue.WeightValueParamsF1;
-import hivemall.model.WeightValue.WeightValueParamsF2;
import hivemall.model.WeightValue.WeightValueWithCovar;
import hivemall.utils.collections.IMapIterator;
import hivemall.utils.hadoop.HiveUtils;
@@ -28,7 +26,6 @@ import hivemall.utils.lang.HalfFloat;
import hivemall.utils.math.MathUtils;
import java.util.Arrays;
-
import javax.annotation.Nonnull;
import org.apache.commons.logging.Log;
@@ -41,13 +38,6 @@ public final class SpaceEfficientDenseModel extends AbstractPredictionModel {
private short[] weights;
private short[] covars;
- // optional value for adagrad
- private float[] sum_of_squared_gradients;
- // optional value for adadelta
- private float[] sum_of_squared_delta_x;
- // optional value for adagrad+rda
- private float[] sum_of_gradients;
-
// optional value for MIX
private short[] clocks;
private byte[] deltaUpdates;
@@ -68,9 +58,6 @@ public final class SpaceEfficientDenseModel extends AbstractPredictionModel {
} else {
this.covars = null;
}
- this.sum_of_squared_gradients = null;
- this.sum_of_squared_delta_x = null;
- this.sum_of_gradients = null;
this.clocks = null;
this.deltaUpdates = null;
}
@@ -86,20 +73,6 @@ public final class SpaceEfficientDenseModel extends AbstractPredictionModel {
}
@Override
- public void configureParams(boolean sum_of_squared_gradients, boolean sum_of_squared_delta_x,
- boolean sum_of_gradients) {
- if (sum_of_squared_gradients) {
- this.sum_of_squared_gradients = new float[size];
- }
- if (sum_of_squared_delta_x) {
- this.sum_of_squared_delta_x = new float[size];
- }
- if (sum_of_gradients) {
- this.sum_of_gradients = new float[size];
- }
- }
-
- @Override
public void configureClock() {
if (clocks == null) {
this.clocks = new short[size];
@@ -126,8 +99,11 @@ public final class SpaceEfficientDenseModel extends AbstractPredictionModel {
return HalfFloat.halfFloatToFloat(covars[i]);
}
- private void setWeight(final int i, final float v) {
- HalfFloat.checkRange(v);
+ private void _setWeight(final int i, final float v) {
+ if(Math.abs(v) >= HalfFloat.MAX_FLOAT) {
+ throw new IllegalArgumentException("Acceptable maximum weight is "
+ + HalfFloat.MAX_FLOAT + ": " + v);
+ }
weights[i] = HalfFloat.floatToHalfFloat(v);
}
@@ -149,16 +125,7 @@ public final class SpaceEfficientDenseModel extends AbstractPredictionModel {
this.covars = Arrays.copyOf(covars, newSize);
Arrays.fill(covars, oldSize, newSize, HalfFloat.ONE);
}
- if (sum_of_squared_gradients != null) {
- this.sum_of_squared_gradients = Arrays.copyOf(sum_of_squared_gradients, newSize);
- }
- if (sum_of_squared_delta_x != null) {
- this.sum_of_squared_delta_x = Arrays.copyOf(sum_of_squared_delta_x, newSize);
- }
- if (sum_of_gradients != null) {
- this.sum_of_gradients = Arrays.copyOf(sum_of_gradients, newSize);
- }
- if (clocks != null) {
+ if(clocks != null) {
this.clocks = Arrays.copyOf(clocks, newSize);
this.deltaUpdates = Arrays.copyOf(deltaUpdates, newSize);
}
@@ -172,17 +139,8 @@ public final class SpaceEfficientDenseModel extends AbstractPredictionModel {
if (i >= size) {
return null;
}
- if (sum_of_squared_gradients != null) {
- if (sum_of_squared_delta_x != null) {
- return (T) new WeightValueParamsF2(getWeight(i), sum_of_squared_gradients[i],
- sum_of_squared_delta_x[i]);
- } else if (sum_of_gradients != null) {
- return (T) new WeightValueParamsF2(getWeight(i), sum_of_squared_gradients[i],
- sum_of_gradients[i]);
- } else {
- return (T) new WeightValueParamsF1(getWeight(i), sum_of_squared_gradients[i]);
- }
- } else if (covars != null) {
+
+ if(covars != null) {
return (T) new WeightValueWithCovar(getWeight(i), getCovar(i));
} else {
return (T) new WeightValue(getWeight(i));
@@ -194,22 +152,13 @@ public final class SpaceEfficientDenseModel extends AbstractPredictionModel {
int i = HiveUtils.parseInt(feature);
ensureCapacity(i);
float weight = value.get();
- setWeight(i, weight);
+ _setWeight(i, weight);
float covar = 1.f;
boolean hasCovar = value.hasCovariance();
if (hasCovar) {
covar = value.getCovariance();
setCovar(i, covar);
}
- if (sum_of_squared_gradients != null) {
- sum_of_squared_gradients[i] = value.getSumOfSquaredGradients();
- }
- if (sum_of_squared_delta_x != null) {
- sum_of_squared_delta_x[i] = value.getSumOfSquaredDeltaX();
- }
- if (sum_of_gradients != null) {
- sum_of_gradients[i] = value.getSumOfGradients();
- }
short clock = 0;
int delta = 0;
if (clocks != null && value.isTouched()) {
@@ -229,19 +178,10 @@ public final class SpaceEfficientDenseModel extends AbstractPredictionModel {
if (i >= size) {
return;
}
- setWeight(i, 0.f);
- if (covars != null) {
+ _setWeight(i, 0.f);
+ if(covars != null) {
setCovar(i, 1.f);
}
- if (sum_of_squared_gradients != null) {
- sum_of_squared_gradients[i] = 0.f;
- }
- if (sum_of_squared_delta_x != null) {
- sum_of_squared_delta_x[i] = 0.f;
- }
- if (sum_of_gradients != null) {
- sum_of_gradients[i] = 0.f;
- }
// avoid clock/delta
}
@@ -255,6 +195,13 @@ public final class SpaceEfficientDenseModel extends AbstractPredictionModel {
}
@Override
+ public void setWeight(Object feature, float value) {
+ int i = HiveUtils.parseInt(feature);
+ ensureCapacity(i);
+ _setWeight(i, value);
+ }
+
+ @Override
public float getCovariance(Object feature) {
int i = HiveUtils.parseInt(feature);
if (i >= size) {
@@ -267,7 +214,7 @@ public final class SpaceEfficientDenseModel extends AbstractPredictionModel {
protected void _set(Object feature, float weight, short clock) {
int i = ((Integer) feature).intValue();
ensureCapacity(i);
- setWeight(i, weight);
+ _setWeight(i, weight);
clocks[i] = clock;
deltaUpdates[i] = 0;
}
@@ -276,7 +223,7 @@ public final class SpaceEfficientDenseModel extends AbstractPredictionModel {
protected void _set(Object feature, float weight, float covar, short clock) {
int i = ((Integer) feature).intValue();
ensureCapacity(i);
- setWeight(i, weight);
+ _setWeight(i, weight);
setCovar(i, covar);
clocks[i] = clock;
deltaUpdates[i] = 0;
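
The renamed _setWeight() now rejects values at or beyond the half-float range inline instead of calling HalfFloat.checkRange(). A stand-alone sketch of that guard and the 16-bit round-trip it protects, using only the HalfFloat helpers already referenced in this file:

import hivemall.utils.lang.HalfFloat;

// Half-float guard and round-trip as used by SpaceEfficientDenseModel; illustrative values.
public final class HalfFloatWeightSketch {
    public static void main(String[] args) {
        float v = 1234.5f;
        if (Math.abs(v) >= HalfFloat.MAX_FLOAT) {
            throw new IllegalArgumentException("Acceptable maximum weight is "
                    + HalfFloat.MAX_FLOAT + ": " + v);
        }
        short packed = HalfFloat.floatToHalfFloat(v);        // 16-bit storage
        float restored = HalfFloat.halfFloatToFloat(packed); // back to float
        // restored only approximates v: half floats carry ~11 bits of mantissa,
        // which is why comparisons against a full-precision DenseModel need a loose delta
        System.out.println(v + " -> " + restored);
    }
}
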
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/model/SparseModel.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/model/SparseModel.java b/core/src/main/java/hivemall/model/SparseModel.java
index aaab869..bab982f 100644
--- a/core/src/main/java/hivemall/model/SparseModel.java
+++ b/core/src/main/java/hivemall/model/SparseModel.java
@@ -36,6 +36,10 @@ public final class SparseModel extends AbstractPredictionModel {
private final boolean hasCovar;
private boolean clockEnabled;
+ public SparseModel(int size) {
+ this(size, false);
+ }
+
public SparseModel(int size, boolean hasCovar) {
super();
this.weights = new OpenHashMap<Object, IWeightValue>(size);
@@ -54,10 +58,6 @@ public final class SparseModel extends AbstractPredictionModel {
}
@Override
- public void configureParams(boolean sum_of_squared_gradients, boolean sum_of_squared_delta_x,
- boolean sum_of_gradients) {}
-
- @Override
public void configureClock() {
this.clockEnabled = true;
}
@@ -131,6 +131,18 @@ public final class SparseModel extends AbstractPredictionModel {
}
@Override
+ public void setWeight(Object feature, float value) {
+ if(weights.containsKey(feature)) {
+ IWeightValue weight = weights.get(feature);
+ weight.set(value);
+ } else {
+ IWeightValue weight = new WeightValue(value);
+ weight.set(value);
+ weights.put(feature, weight);
+ }
+ }
+
+ @Override
public float getCovariance(final Object feature) {
IWeightValue v = weights.get(feature);
return v == null ? 1.f : v.getCovariance();
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/model/SynchronizedModelWrapper.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/model/SynchronizedModelWrapper.java b/core/src/main/java/hivemall/model/SynchronizedModelWrapper.java
index 99ee69c..87e89b6 100644
--- a/core/src/main/java/hivemall/model/SynchronizedModelWrapper.java
+++ b/core/src/main/java/hivemall/model/SynchronizedModelWrapper.java
@@ -63,12 +63,6 @@ public final class SynchronizedModelWrapper implements PredictionModel {
}
@Override
- public void configureParams(boolean sum_of_squared_gradients, boolean sum_of_squared_delta_x,
- boolean sum_of_gradients) {
- model.configureParams(sum_of_squared_gradients, sum_of_squared_delta_x, sum_of_gradients);
- }
-
- @Override
public void configureClock() {
model.configureClock();
}
@@ -157,6 +151,16 @@ public final class SynchronizedModelWrapper implements PredictionModel {
}
@Override
+ public void setWeight(Object feature, float value) {
+ try {
+ lock.lock();
+ model.setWeight(feature, value);
+ } finally {
+ lock.unlock();
+ }
+ }
+
+ @Override
public float getCovariance(Object feature) {
try {
lock.lock();
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/model/WeightValue.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/model/WeightValue.java b/core/src/main/java/hivemall/model/WeightValue.java
index e6d98c6..b329374 100644
--- a/core/src/main/java/hivemall/model/WeightValue.java
+++ b/core/src/main/java/hivemall/model/WeightValue.java
@@ -77,15 +77,50 @@ public class WeightValue implements IWeightValue {
}
@Override
+ public void setSumOfSquaredGradients(float value) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
public float getSumOfSquaredDeltaX() {
return 0.f;
}
@Override
+ public void setSumOfSquaredDeltaX(float value) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
public float getSumOfGradients() {
return 0.f;
}
+ @Override
+ public void setSumOfGradients(float value) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public float getM() {
+ return 0.f;
+ }
+
+ @Override
+ public void setM(float value) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public float getV() {
+ return 0.f;
+ }
+
+ @Override
+ public void setV(float value) {
+ throw new UnsupportedOperationException();
+ }
+
/**
* @return whether touched in training or not
*/
@@ -137,7 +172,7 @@ public class WeightValue implements IWeightValue {
}
public static final class WeightValueParamsF1 extends WeightValue {
- private final float f1;
+ private float f1;
public WeightValueParamsF1(float weight, float f1) {
super(weight);
@@ -162,14 +197,19 @@ public class WeightValue implements IWeightValue {
return f1;
}
+ @Override
+ public void setSumOfSquaredGradients(float value) {
+ this.f1 = value;
+ }
+
}
/**
* WeightValue with Sum of Squared Gradients
*/
public static final class WeightValueParamsF2 extends WeightValue {
- private final float f1;
- private final float f2;
+ private float f1;
+ private float f2;
public WeightValueParamsF2(float weight, float f1, float f2) {
super(weight);
@@ -198,15 +238,131 @@ public class WeightValue implements IWeightValue {
}
@Override
+ public void setSumOfSquaredGradients(float value) {
+ this.f1 = value;
+ }
+
+ @Override
public final float getSumOfSquaredDeltaX() {
return f2;
}
@Override
+ public void setSumOfSquaredDeltaX(float value) {
+ this.f2 = value;
+ }
+
+ @Override
public float getSumOfGradients() {
return f2;
}
+ @Override
+ public void setSumOfGradients(float value) {
+ this.f2 = value;
+ }
+
+ @Override
+ public float getM() {
+ return f1;
+ }
+
+ @Override
+ public void setM(float value) {
+ this.f1 = value;
+ }
+
+ @Override
+ public float getV() {
+ return f2;
+ }
+
+ @Override
+ public void setV(float value) {
+ this.f2 = value;
+ }
+
+ }
+
+ public static final class WeightValueParamsF3 extends WeightValue {
+ private float f1;
+ private float f2;
+ private float f3;
+
+ public WeightValueParamsF3(float weight, float f1, float f2, float f3) {
+ super(weight);
+ this.f1 = f1;
+ this.f2 = f2;
+ this.f3 = f3;
+ }
+
+ @Override
+ public WeightValueType getType() {
+ return WeightValueType.ParamsF3;
+ }
+
+ @Override
+ public float getFloatParams(@Nonnegative final int i) {
+ if(i == 1) {
+ return f1;
+ } else if(i == 2) {
+ return f2;
+ } else if (i == 3) {
+ return f3;
+ }
+ throw new IllegalArgumentException("getFloatParams(" + i + ") should not be called");
+ }
+
+ @Override
+ public final float getSumOfSquaredGradients() {
+ return f1;
+ }
+
+ @Override
+ public void setSumOfSquaredGradients(float value) {
+ this.f1 = value;
+ }
+
+ @Override
+ public final float getSumOfSquaredDeltaX() {
+ return f2;
+ }
+
+ @Override
+ public void setSumOfSquaredDeltaX(float value) {
+ this.f2 = value;
+ }
+
+ @Override
+ public float getSumOfGradients() {
+ return f3;
+ }
+
+ @Override
+ public void setSumOfGradients(float value) {
+ this.f3 = value;
+ }
+
+ @Override
+ public float getM() {
+ return f1;
+ }
+
+ @Override
+ public void setM(float value) {
+ this.f1 = value;
+ }
+
+ @Override
+ public float getV() {
+ return f2;
+ }
+
+ @Override
+ public void setV(float value) {
+ this.f2 = value;
+ }
+
}
public static final class WeightValueWithCovar extends WeightValue {
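To make the slot sharing in the new WeightValueParamsF3 concrete, a throwaway sketch (using only the accessors added in this diff; values invented): f1 backs both the squared-gradient accumulator and Adam's first moment, f2 backs sum-of-squared-delta-x and Adam's second moment, and f3 is dedicated to the RDA gradient sum.
import hivemall.model.WeightValue.WeightValueParamsF3;
public class WeightValueParamsF3Sketch {
    public static void main(String[] args) {
        // weight = 0.1, f1 = 1.0, f2 = 2.0, f3 = 3.0
        WeightValueParamsF3 w = new WeightValueParamsF3(0.1f, 1.f, 2.f, 3.f);
        System.out.println(w.getSumOfSquaredGradients()); // 1.0 (f1)
        System.out.println(w.getM());                     // 1.0 (same slot as f1)
        w.setV(4.f);                                      // writes f2
        System.out.println(w.getSumOfSquaredDeltaX());    // 4.0 (same slot as f2)
        System.out.println(w.getSumOfGradients());        // 3.0 (f3)
    }
}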
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/model/WeightValueWithClock.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/model/WeightValueWithClock.java b/core/src/main/java/hivemall/model/WeightValueWithClock.java
index 249650a..9b31361 100644
--- a/core/src/main/java/hivemall/model/WeightValueWithClock.java
+++ b/core/src/main/java/hivemall/model/WeightValueWithClock.java
@@ -79,15 +79,50 @@ public class WeightValueWithClock implements IWeightValue {
}
@Override
+ public void setSumOfSquaredGradients(float value) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
public float getSumOfSquaredDeltaX() {
return 0.f;
}
@Override
+ public void setSumOfSquaredDeltaX(float value) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
public float getSumOfGradients() {
return 0.f;
}
+ @Override
+ public void setSumOfGradients(float value) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public float getM() {
+ return 0.f;
+ }
+
+ @Override
+ public void setM(float value) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public float getV() {
+ return 0.f;
+ }
+
+ @Override
+ public void setV(float value) {
+ throw new UnsupportedOperationException();
+ }
+
/**
* @return whether touched in training or not
*/
@@ -144,7 +179,7 @@ public class WeightValueWithClock implements IWeightValue {
* WeightValue with Sum of Squared Gradients
*/
public static final class WeightValueParamsF1Clock extends WeightValueWithClock {
- private final float f1;
+ private float f1;
public WeightValueParamsF1Clock(float value, float f1) {
super(value);
@@ -174,11 +209,16 @@ public class WeightValueWithClock implements IWeightValue {
return f1;
}
+ @Override
+ public void setSumOfSquaredGradients(float value) {
+ this.f1 = value;
+ }
+
}
public static final class WeightValueParamsF2Clock extends WeightValueWithClock {
- private final float f1;
- private final float f2;
+ private float f1;
+ private float f2;
public WeightValueParamsF2Clock(float value, float f1, float f2) {
super(value);
@@ -213,15 +253,136 @@ public class WeightValueWithClock implements IWeightValue {
}
@Override
+ public void setSumOfSquaredGradients(float value) {
+ this.f1 = value;
+ }
+
+ @Override
+ public float getSumOfSquaredDeltaX() {
+ return f2;
+ }
+
+ @Override
+ public void setSumOfSquaredDeltaX(float value) {
+ this.f2 = value;
+ }
+
+ @Override
+ public float getSumOfGradients() {
+ return f2;
+ }
+
+ @Override
+ public void setSumOfGradients(float value) {
+ this.f2 = value;
+ }
+ @Override
+ public float getM() {
+ return f1;
+ }
+
+ @Override
+ public void setM(float value) {
+ this.f1 = value;
+ }
+
+ @Override
+ public float getV() {
+ return f2;
+ }
+
+ @Override
+ public void setV(float value) {
+ this.f2 = value;
+ }
+
+ }
+
+ public static final class WeightValueParamsF3Clock extends WeightValueWithClock {
+ private float f1;
+ private float f2;
+ private float f3;
+
+ public WeightValueParamsF3Clock(float value, float f1, float f2, float f3) {
+ super(value);
+ this.f1 = f1;
+ this.f2 = f2;
+ this.f3 = f3;
+ }
+
+ public WeightValueParamsF3Clock(IWeightValue src) {
+ super(src);
+ this.f1 = src.getFloatParams(1);
+ this.f2 = src.getFloatParams(2);
+ this.f3 = src.getFloatParams(3);
+ }
+
+ @Override
+ public WeightValueType getType() {
+ return WeightValueType.ParamsF3;
+ }
+
+ @Override
+ public float getFloatParams(@Nonnegative final int i) {
+ if(i == 1) {
+ return f1;
+ } else if(i == 2) {
+ return f2;
+ } else if(i == 3) {
+ return f3;
+ }
+ throw new IllegalArgumentException("getFloatParams(" + i + ") should not be called");
+ }
+
+ @Override
+ public float getSumOfSquaredGradients() {
+ return f1;
+ }
+
+ @Override
+ public void setSumOfSquaredGradients(float value) {
+ this.f1 = value;
+ }
+
+ @Override
public float getSumOfSquaredDeltaX() {
return f2;
}
@Override
+ public void setSumOfSquaredDeltaX(float value) {
+ this.f2 = value;
+ }
+
+ @Override
public float getSumOfGradients() {
+ return f3;
+ }
+
+ @Override
+ public void setSumOfGradients(float value) {
+ this.f3 = value;
+ }
+ @Override
+ public float getM() {
+ return f1;
+ }
+
+ @Override
+ public void setM(float value) {
+ this.f1 = value;
+ }
+
+ @Override
+ public float getV() {
return f2;
}
+ @Override
+ public void setV(float value) {
+ this.f2 = value;
+ }
+
}
public static final class WeightValueWithCovarClock extends WeightValueWithClock {
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/optimizer/DenseOptimizerFactory.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/optimizer/DenseOptimizerFactory.java b/core/src/main/java/hivemall/optimizer/DenseOptimizerFactory.java
new file mode 100644
index 0000000..e2c5a10
--- /dev/null
+++ b/core/src/main/java/hivemall/optimizer/DenseOptimizerFactory.java
@@ -0,0 +1,215 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2015 Makoto YUI
+ * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package hivemall.optimizer;
+
+import javax.annotation.Nonnull;
+import javax.annotation.concurrent.NotThreadSafe;
+import java.util.Arrays;
+import java.util.Map;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+
+import hivemall.optimizer.Optimizer.OptimizerBase;
+import hivemall.model.IWeightValue;
+import hivemall.model.WeightValue;
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.math.MathUtils;
+
+public final class DenseOptimizerFactory {
+ private static final Log logger = LogFactory.getLog(DenseOptimizerFactory.class);
+
+ @Nonnull
+ public static Optimizer create(int ndims, @Nonnull Map<String, String> options) {
+ final String optimizerName = options.get("optimizer");
+ if(optimizerName != null) {
+ OptimizerBase optimizerImpl;
+ if(optimizerName.toLowerCase().equals("sgd")) {
+ optimizerImpl = new Optimizer.SGD(options);
+ } else if(optimizerName.toLowerCase().equals("adadelta")) {
+ optimizerImpl = new AdaDelta(ndims, options);
+ } else if(optimizerName.toLowerCase().equals("adagrad")) {
+ optimizerImpl = new AdaGrad(ndims, options);
+ } else if(optimizerName.toLowerCase().equals("adam")) {
+ optimizerImpl = new Adam(ndims, options);
+ } else {
+ throw new IllegalArgumentException("Unsupported optimizer name: " + optimizerName);
+ }
+
+ logger.info("set " + optimizerImpl.getClass().getSimpleName()
+ + " as an optimizer: " + options);
+
+ // If a regularization type is "RDA", wrap the optimizer with `Optimizer#RDA`.
+ if(options.get("regularization") != null
+ && options.get("regularization").toLowerCase().equals("rda")) {
+ optimizerImpl = new RDA(ndims, optimizerImpl, options);
+ }
+
+ return optimizerImpl;
+ }
+ throw new IllegalArgumentException("`optimizer` not defined");
+ }
+
+ @NotThreadSafe
+ static final class AdaDelta extends Optimizer.AdaDelta {
+
+ private final IWeightValue weightValueReused;
+
+ private float[] sum_of_squared_gradients;
+ private float[] sum_of_squared_delta_x;
+
+ public AdaDelta(int ndims, Map<String, String> options) {
+ super(options);
+ this.weightValueReused = new WeightValue.WeightValueParamsF2(0.f, 0.f, 0.f);
+ this.sum_of_squared_gradients = new float[ndims];
+ this.sum_of_squared_delta_x = new float[ndims];
+ }
+
+ @Override
+ public float computeUpdatedValue(@Nonnull Object feature, float weight, float gradient) {
+ int i = HiveUtils.parseInt(feature);
+ ensureCapacity(i);
+ weightValueReused.set(weight);
+ weightValueReused.setSumOfSquaredGradients(sum_of_squared_gradients[i]);
+ weightValueReused.setSumOfSquaredDeltaX(sum_of_squared_delta_x[i]);
+ computeUpdateValue(weightValueReused, gradient);
+ sum_of_squared_gradients[i] = weightValueReused.getSumOfSquaredGradients();
+ sum_of_squared_delta_x[i] = weightValueReused.getSumOfSquaredDeltaX();
+ return weightValueReused.get();
+ }
+
+ private void ensureCapacity(final int index) {
+ if(index >= sum_of_squared_gradients.length) {
+ int bits = MathUtils.bitsRequired(index);
+ int newSize = (1 << bits) + 1;
+ this.sum_of_squared_gradients = Arrays.copyOf(sum_of_squared_gradients, newSize);
+ this.sum_of_squared_delta_x = Arrays.copyOf(sum_of_squared_delta_x, newSize);
+ }
+ }
+
+ }
+
+ @NotThreadSafe
+ static final class AdaGrad extends Optimizer.AdaGrad {
+
+ private final IWeightValue weightValueReused;
+
+ private float[] sum_of_squared_gradients;
+
+ public AdaGrad(int ndims, Map<String, String> options) {
+ super(options);
+ this.weightValueReused = new WeightValue.WeightValueParamsF1(0.f, 0.f);
+ this.sum_of_squared_gradients = new float[ndims];
+ }
+
+ @Override
+ public float computeUpdatedValue(@Nonnull Object feature, float weight, float gradient) {
+ int i = HiveUtils.parseInt(feature);
+ ensureCapacity(i);
+ weightValueReused.set(weight);
+ weightValueReused.setSumOfSquaredGradients(sum_of_squared_gradients[i]);
+ computeUpdateValue(weightValueReused, gradient);
+ sum_of_squared_gradients[i] = weightValueReused.getSumOfSquaredGradients();
+ return weightValueReused.get();
+ }
+
+ private void ensureCapacity(final int index) {
+ if(index >= sum_of_squared_gradients.length) {
+ int bits = MathUtils.bitsRequired(index);
+ int newSize = (1 << bits) + 1;
+ this.sum_of_squared_gradients = Arrays.copyOf(sum_of_squared_gradients, newSize);
+ }
+ }
+
+ }
+
+ @NotThreadSafe
+ static final class Adam extends Optimizer.Adam {
+
+ private final IWeightValue weightValueReused;
+
+ private float[] val_m;
+ private float[] val_v;
+
+ public Adam(int ndims, Map<String, String> options) {
+ super(options);
+ this.weightValueReused = new WeightValue.WeightValueParamsF2(0.f, 0.f, 0.f);
+ this.val_m = new float[ndims];
+ this.val_v = new float[ndims];
+ }
+
+ @Override
+ public float computeUpdatedValue(@Nonnull Object feature, float weight, float gradient) {
+ int i = HiveUtils.parseInt(feature);
+ ensureCapacity(i);
+ weightValueReused.set(weight);
+ weightValueReused.setM(val_m[i]);
+ weightValueReused.setV(val_v[i]);
+ computeUpdateValue(weightValueReused, gradient);
+ val_m[i] = weightValueReused.getM();
+ val_v[i] = weightValueReused.getV();
+ return weightValueReused.get();
+ }
+
+ private void ensureCapacity(final int index) {
+ if(index >= val_m.length) {
+ int bits = MathUtils.bitsRequired(index);
+ int newSize = (1 << bits) + 1;
+ this.val_m = Arrays.copyOf(val_m, newSize);
+ this.val_v = Arrays.copyOf(val_v, newSize);
+ }
+ }
+
+ }
+
+ @NotThreadSafe
+ static final class RDA extends Optimizer.RDA {
+
+ private final IWeightValue weightValueReused;
+
+ private float[] sum_of_gradients;
+
+ public RDA(int ndims, final OptimizerBase optimizerImpl, Map<String, String> options) {
+ super(optimizerImpl, options);
+ this.weightValueReused = new WeightValue.WeightValueParamsF3(0.f, 0.f, 0.f, 0.f);
+ this.sum_of_gradients = new float[ndims];
+ }
+
+ @Override
+ public float computeUpdatedValue(@Nonnull Object feature, float weight, float gradient) {
+ int i = HiveUtils.parseInt(feature);
+ ensureCapacity(i);
+ weightValueReused.set(weight);
+ weightValueReused.setSumOfGradients(sum_of_gradients[i]);
+ computeUpdateValue(weightValueReused, gradient);
+ sum_of_gradients[i] = weightValueReused.getSumOfGradients();
+ return weightValueReused.get();
+ }
+
+ private void ensureCapacity(final int index) {
+ if(index >= sum_of_gradients.length) {
+ int bits = MathUtils.bitsRequired(index);
+ int newSize = (1 << bits) + 1;
+ this.sum_of_gradients = Arrays.copyOf(sum_of_gradients, newSize);
+ }
+ }
+
+ }
+
+}
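As a usage sketch of the factory above (the option keys come from the code; the dimensionality, feature index, and hyper-parameter values are made up), the optimizer is chosen entirely from a string map, and an "rda" regularization value wraps the base optimizer with Optimizer.RDA:
import java.util.HashMap;
import java.util.Map;
import hivemall.optimizer.DenseOptimizerFactory;
import hivemall.optimizer.Optimizer;
public class DenseOptimizerSketch {
    public static void main(String[] args) {
        Map<String, String> options = new HashMap<String, String>();
        options.put("optimizer", "adagrad");   // sgd | adadelta | adagrad | adam
        options.put("regularization", "rda");  // wraps AdaGrad with Optimizer.RDA
        options.put("eta", "fixed");
        options.put("eta0", "0.1");
        // the dense variant keeps per-dimension state in primitive float arrays of length ndims
        Optimizer optimizer = DenseOptimizerFactory.create(1024, options);
        // features are parsed into int indices by HiveUtils.parseInt(feature)
        float newWeight = optimizer.computeUpdatedValue("3", 0.0f, 0.5f);
        optimizer.proceedStep();
        System.out.println(newWeight);
    }
}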
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/optimizer/EtaEstimator.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/optimizer/EtaEstimator.java b/core/src/main/java/hivemall/optimizer/EtaEstimator.java
new file mode 100644
index 0000000..ac1d112
--- /dev/null
+++ b/core/src/main/java/hivemall/optimizer/EtaEstimator.java
@@ -0,0 +1,191 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2015 Makoto YUI
+ * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package hivemall.optimizer;
+
+import hivemall.utils.lang.NumberUtils;
+import hivemall.utils.lang.Primitives;
+
+import java.util.Map;
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+
+import org.apache.commons.cli.CommandLine;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+
+public abstract class EtaEstimator {
+
+ protected final float eta0;
+
+ public EtaEstimator(float eta0) {
+ this.eta0 = eta0;
+ }
+
+ public float eta0() {
+ return eta0;
+ }
+
+ public abstract float eta(long t);
+
+ public void update(@Nonnegative float multiplier) {}
+
+ public static final class FixedEtaEstimator extends EtaEstimator {
+
+ public FixedEtaEstimator(float eta) {
+ super(eta);
+ }
+
+ @Override
+ public float eta(long t) {
+ return eta0;
+ }
+
+ }
+
+ public static final class SimpleEtaEstimator extends EtaEstimator {
+
+ private final float finalEta;
+ private final double total_steps;
+
+ public SimpleEtaEstimator(float eta0, long total_steps) {
+ super(eta0);
+ this.finalEta = (float) (eta0 / 2.d);
+ this.total_steps = total_steps;
+ }
+
+ @Override
+ public float eta(final long t) {
+ if (t > total_steps) {
+ return finalEta;
+ }
+ return (float) (eta0 / (1.d + (t / total_steps)));
+ }
+
+ }
+
+ public static final class InvscalingEtaEstimator extends EtaEstimator {
+
+ private final double power_t;
+
+ public InvscalingEtaEstimator(float eta0, double power_t) {
+ super(eta0);
+ this.power_t = power_t;
+ }
+
+ @Override
+ public float eta(final long t) {
+ return (float) (eta0 / Math.pow(t, power_t));
+ }
+
+ }
+
+ /**
+ * bold driver: Gemulla et al., Large-scale matrix factorization with distributed stochastic
+ * gradient descent, KDD 2011.
+ */
+ public static final class AdjustingEtaEstimator extends EtaEstimator {
+
+ private float eta;
+
+ public AdjustingEtaEstimator(float eta) {
+ super(eta);
+ this.eta = eta;
+ }
+
+ @Override
+ public float eta(long t) {
+ return eta;
+ }
+
+ @Override
+ public void update(@Nonnegative float multiplier) {
+ float newEta = eta * multiplier;
+ if (!NumberUtils.isFinite(newEta)) {
+ // avoid NaN or INFINITY
+ return;
+ }
+ this.eta = Math.min(eta0, newEta); // never be larger than eta0
+ }
+
+ }
+
+ @Nonnull
+ public static EtaEstimator get(@Nullable CommandLine cl) throws UDFArgumentException {
+ return get(cl, 0.1f);
+ }
+
+ @Nonnull
+ public static EtaEstimator get(@Nullable CommandLine cl, float defaultEta0)
+ throws UDFArgumentException {
+ if (cl == null) {
+ return new InvscalingEtaEstimator(defaultEta0, 0.1d);
+ }
+
+ if (cl.hasOption("boldDriver")) {
+ float eta = Primitives.parseFloat(cl.getOptionValue("eta"), 0.3f);
+ return new AdjustingEtaEstimator(eta);
+ }
+
+ String etaValue = cl.getOptionValue("eta");
+ if (etaValue != null) {
+ float eta = Float.parseFloat(etaValue);
+ return new FixedEtaEstimator(eta);
+ }
+
+ float eta0 = Primitives.parseFloat(cl.getOptionValue("eta0"), defaultEta0);
+ if (cl.hasOption("t")) {
+ long t = Long.parseLong(cl.getOptionValue("t"));
+ return new SimpleEtaEstimator(eta0, t);
+ }
+
+ double power_t = Primitives.parseDouble(cl.getOptionValue("power_t"), 0.1d);
+ return new InvscalingEtaEstimator(eta0, power_t);
+ }
+
+ @Nonnull
+ public static EtaEstimator get(@Nonnull final Map<String, String> options)
+ throws IllegalArgumentException {
+ final String etaName = options.get("eta");
+ if(etaName == null) {
+ return new FixedEtaEstimator(1.f);
+ }
+ float eta0 = 0.1f;
+ if(options.containsKey("eta0")) {
+ eta0 = Float.parseFloat(options.get("eta0"));
+ }
+ if(etaName.toLowerCase().equals("fixed")) {
+ return new FixedEtaEstimator(eta0);
+ } else if(etaName.toLowerCase().equals("simple")) {
+ long t = 10000;
+ if(options.containsKey("t")) {
+ t = Long.parseLong(options.get("t"));
+ }
+ return new SimpleEtaEstimator(eta0, t);
+ } else if(etaName.toLowerCase().equals("inverse")) {
+ double power_t = 0.1;
+ if(options.containsKey("power_t")) {
+ power_t = Double.parseDouble(options.get("power_t"));
+ }
+ return new InvscalingEtaEstimator(eta0, power_t);
+ } else {
+ throw new IllegalArgumentException("Unsupported ETA name: " + etaName);
+ }
+ }
+
+}
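A brief sketch of the map-based lookup added above ("fixed", "simple", "inverse" for the eta option); the concrete values are arbitrary:
import java.util.HashMap;
import java.util.Map;
import hivemall.optimizer.EtaEstimator;
public class EtaEstimatorSketch {
    public static void main(String[] args) {
        Map<String, String> options = new HashMap<String, String>();
        options.put("eta", "inverse");   // selects InvscalingEtaEstimator
        options.put("eta0", "0.2");
        options.put("power_t", "0.5");
        EtaEstimator eta = EtaEstimator.get(options);
        // eta(t) = eta0 / t^power_t, so eta(100) = 0.2 / 10 = 0.02
        System.out.println(eta.eta(100L));
    }
}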
[16/50] [abbrv] incubator-hivemall git commit: refine chi2
Posted by my...@apache.org.
refine chi2
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/a16a3fde
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/a16a3fde
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/a16a3fde
Branch: refs/heads/JIRA-22/pr-385
Commit: a16a3fde844ba381dee7eb1e9608ddc2dcfb96fc
Parents: 6dc2344
Author: amaya <gi...@sapphire.in.net>
Authored: Wed Sep 21 13:10:18 2016 +0900
Committer: amaya <gi...@sapphire.in.net>
Committed: Wed Sep 21 13:35:33 2016 +0900
----------------------------------------------------------------------
.../hivemall/ftvec/selection/ChiSquareUDF.java | 40 +++++++------
.../java/hivemall/utils/math/StatsUtils.java | 62 +++++++++++---------
2 files changed, 58 insertions(+), 44 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/a16a3fde/core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java b/core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java
index e2b7494..951aeeb 100644
--- a/core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java
+++ b/core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java
@@ -50,6 +50,12 @@ public class ChiSquareUDF extends GenericUDF {
private ListObjectInspector expectedRowOI;
private PrimitiveObjectInspector expectedElOI;
+ private int nFeatures = -1;
+ private double[] observedRow = null; // to reuse
+ private double[] expectedRow = null; // to reuse
+ private double[][] observed = null; // shape = (#features, #classes)
+ private double[][] expected = null; // shape = (#features, #classes)
+
@Override
public ObjectInspector initialize(ObjectInspector[] OIs) throws UDFArgumentException {
if (OIs.length != 2) {
@@ -75,12 +81,12 @@ public class ChiSquareUDF extends GenericUDF {
expectedRowOI = HiveUtils.asListOI(expectedOI.getListElementObjectInspector());
expectedElOI = HiveUtils.asDoubleCompatibleOI(expectedRowOI.getListElementObjectInspector());
- List<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
+ final List<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
return ObjectInspectorFactory.getStandardStructObjectInspector(
- Arrays.asList("chi2_vals", "p_vals"), fieldOIs);
+ Arrays.asList("chi2", "pvalue"), fieldOIs);
}
@Override
@@ -93,28 +99,28 @@ public class ChiSquareUDF extends GenericUDF {
final int nClasses = observedObj.size();
Preconditions.checkArgument(nClasses == expectedObj.size()); // same #rows
- int nFeatures = -1;
- double[] observedRow = null; // to reuse
- double[] expectedRow = null; // to reuse
- double[][] observed = null; // shape = (#features, #classes)
- double[][] expected = null; // shape = (#features, #classes)
-
// explode and transpose matrix
for (int i = 0; i < nClasses; i++) {
- if (i == 0) {
+ final Object observedObjRow = observedObj.get(i);
+ final Object expectedObjRow = expectedObj.get(i);
+
+ Preconditions.checkNotNull(observedObjRow);
+ Preconditions.checkNotNull(expectedObjRow);
+
+ if (observedRow == null) {
// init
- observedRow = HiveUtils.asDoubleArray(observedObj.get(i), observedRowOI,
- observedElOI, false);
- expectedRow = HiveUtils.asDoubleArray(expectedObj.get(i), expectedRowOI,
- expectedElOI, false);
+ observedRow = HiveUtils.asDoubleArray(observedObjRow, observedRowOI, observedElOI,
+ false);
+ expectedRow = HiveUtils.asDoubleArray(expectedObjRow, expectedRowOI, expectedElOI,
+ false);
nFeatures = observedRow.length;
observed = new double[nFeatures][nClasses];
expected = new double[nFeatures][nClasses];
} else {
- HiveUtils.toDoubleArray(observedObj.get(i), observedRowOI, observedElOI,
- observedRow, false);
- HiveUtils.toDoubleArray(expectedObj.get(i), expectedRowOI, expectedElOI,
- expectedRow, false);
+ HiveUtils.toDoubleArray(observedObjRow, observedRowOI, observedElOI, observedRow,
+ false);
+ HiveUtils.toDoubleArray(expectedObjRow, expectedRowOI, expectedElOI, expectedRow,
+ false);
}
for (int j = 0; j < nFeatures; j++) {
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/a16a3fde/core/src/main/java/hivemall/utils/math/StatsUtils.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/math/StatsUtils.java b/core/src/main/java/hivemall/utils/math/StatsUtils.java
index d3b25c7..e255b84 100644
--- a/core/src/main/java/hivemall/utils/math/StatsUtils.java
+++ b/core/src/main/java/hivemall/utils/math/StatsUtils.java
@@ -23,11 +23,15 @@ import hivemall.utils.lang.Preconditions;
import javax.annotation.Nonnull;
import org.apache.commons.math3.distribution.ChiSquaredDistribution;
+import org.apache.commons.math3.exception.DimensionMismatchException;
+import org.apache.commons.math3.exception.NotPositiveException;
import org.apache.commons.math3.linear.DecompositionSolver;
import org.apache.commons.math3.linear.LUDecomposition;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.linear.SingularValueDecomposition;
+import org.apache.commons.math3.util.FastMath;
+import org.apache.commons.math3.util.MathArrays;
import java.util.AbstractMap;
import java.util.Map;
@@ -194,54 +198,59 @@ public final class StatsUtils {
}
/**
- * @param observed mean vector whose value is observed
- * @param expected mean vector whose value is expected
+ * @param observed observed values; must be a non-negative vector
+ * @param expected expected values; must be a strictly positive vector
* @return chi2 value
*/
public static double chiSquare(@Nonnull final double[] observed,
@Nonnull final double[] expected) {
- Preconditions.checkArgument(observed.length == expected.length);
+ if (observed.length < 2) {
+ throw new DimensionMismatchException(observed.length, 2);
+ }
+ if (expected.length != observed.length) {
+ throw new DimensionMismatchException(observed.length, expected.length);
+ }
+ MathArrays.checkPositive(expected);
+ for (double d : observed) {
+ if (d < 0.d) {
+ throw new NotPositiveException(d);
+ }
+ }
double sumObserved = 0.d;
double sumExpected = 0.d;
-
- for (int ratio = 0; ratio < observed.length; ++ratio) {
- sumObserved += observed[ratio];
- sumExpected += expected[ratio];
+ for (int i = 0; i < observed.length; i++) {
+ sumObserved += observed[i];
+ sumExpected += expected[i];
}
-
- double var15 = 1.d;
+ double ratio = 1.d;
boolean rescale = false;
- if (Math.abs(sumObserved - sumExpected) > 1.e-5) {
- var15 = sumObserved / sumExpected;
+ if (FastMath.abs(sumObserved - sumExpected) > 10e-6) {
+ ratio = sumObserved / sumExpected;
rescale = true;
}
-
double sumSq = 0.d;
-
- for (int i = 0; i < observed.length; ++i) {
- double dev;
+ for (int i = 0; i < observed.length; i++) {
if (rescale) {
- dev = observed[i] - var15 * expected[i];
- sumSq += dev * dev / (var15 * expected[i]);
+ final double dev = observed[i] - ratio * expected[i];
+ sumSq += dev * dev / (ratio * expected[i]);
} else {
- dev = observed[i] - expected[i];
+ final double dev = observed[i] - expected[i];
sumSq += dev * dev / expected[i];
}
}
-
return sumSq;
}
/**
- * @param observed means vector whose value is observed
- * @param expected means vector whose value is expected
+ * @param observed observed values; must be a non-negative vector
+ * @param expected expected values; must be a strictly positive vector
* @return p value
*/
public static double chiSquareTest(@Nonnull final double[] observed,
@Nonnull final double[] expected) {
- ChiSquaredDistribution distribution = new ChiSquaredDistribution(null,
- (double) expected.length - 1.d);
+ final ChiSquaredDistribution distribution = new ChiSquaredDistribution(
+ expected.length - 1.d);
return 1.d - distribution.cumulativeProbability(chiSquare(observed, expected));
}
@@ -249,8 +258,8 @@ public final class StatsUtils {
* This method offers effective calculation for multiple entries rather than calculation
* individually
*
- * @param observeds means matrix whose values are observed
- * @param expecteds means matrix
+ * @param observeds observed values; must be a non-negative matrix
+ * @param expecteds expected values; must be a strictly positive matrix
* @return (chi2 value[], p value[])
*/
public static Map.Entry<double[], double[]> chiSquares(@Nonnull final double[][] observeds,
@@ -260,8 +269,7 @@ public final class StatsUtils {
final int len = expecteds.length;
final int lenOfEach = expecteds[0].length;
- final ChiSquaredDistribution distribution = new ChiSquaredDistribution(null,
- (double) lenOfEach - 1.d);
+ final ChiSquaredDistribution distribution = new ChiSquaredDistribution(lenOfEach - 1.d);
final double[] chi2s = new double[len];
final double[] ps = new double[len];
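A small worked example (values invented) of the refactored chi-square helpers above; the observed and expected totals match, so no rescaling is applied:
import hivemall.utils.math.StatsUtils;
public class ChiSquareSketch {
    public static void main(String[] args) {
        // observed counts must be non-negative, expected counts strictly positive
        double[] observed = {20.d, 30.d, 50.d};
        double[] expected = {25.d, 25.d, 50.d};
        // chi2 = (20-25)^2/25 + (30-25)^2/25 + (50-50)^2/50 = 2.0
        double chi2 = StatsUtils.chiSquare(observed, expected);
        // p-value from a chi-squared distribution with (length - 1) = 2 degrees of freedom
        double p = StatsUtils.chiSquareTest(observed, expected);
        System.out.println(chi2 + "\t" + p);
    }
}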
[11/50] [abbrv] incubator-hivemall git commit: Add optimizer
implementations
Posted by my...@apache.org.
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/optimizer/LossFunctions.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/optimizer/LossFunctions.java b/core/src/main/java/hivemall/optimizer/LossFunctions.java
new file mode 100644
index 0000000..d11be9b
--- /dev/null
+++ b/core/src/main/java/hivemall/optimizer/LossFunctions.java
@@ -0,0 +1,467 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2015 Makoto YUI
+ * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package hivemall.optimizer;
+
+import hivemall.utils.math.MathUtils;
+
+/**
+ * @link https://github.com/JohnLangford/vowpal_wabbit/wiki/Loss-functions
+ */
+public final class LossFunctions {
+
+ public enum LossType {
+ SquaredLoss, LogLoss, HingeLoss, SquaredHingeLoss, QuantileLoss, EpsilonInsensitiveLoss
+ }
+
+ public static LossFunction getLossFunction(String type) {
+ if ("SquaredLoss".equalsIgnoreCase(type)) {
+ return new SquaredLoss();
+ } else if ("LogLoss".equalsIgnoreCase(type)) {
+ return new LogLoss();
+ } else if ("HingeLoss".equalsIgnoreCase(type)) {
+ return new HingeLoss();
+ } else if ("SquaredHingeLoss".equalsIgnoreCase(type)) {
+ return new SquaredHingeLoss();
+ } else if ("QuantileLoss".equalsIgnoreCase(type)) {
+ return new QuantileLoss();
+ } else if ("EpsilonInsensitiveLoss".equalsIgnoreCase(type)) {
+ return new EpsilonInsensitiveLoss();
+ }
+ throw new IllegalArgumentException("Unsupported type: " + type);
+ }
+
+ public static LossFunction getLossFunction(LossType type) {
+ switch (type) {
+ case SquaredLoss:
+ return new SquaredLoss();
+ case LogLoss:
+ return new LogLoss();
+ case HingeLoss:
+ return new HingeLoss();
+ case SquaredHingeLoss:
+ return new SquaredHingeLoss();
+ case QuantileLoss:
+ return new QuantileLoss();
+ case EpsilonInsensitiveLoss:
+ return new EpsilonInsensitiveLoss();
+ default:
+ throw new IllegalArgumentException("Unsupported type: " + type);
+ }
+ }
+
+ public interface LossFunction {
+
+ /**
+ * Evaluate the loss function.
+ *
+ * @param p The prediction, p = w^T x
+ * @param y The true value (aka target)
+ * @return The loss evaluated at `p` and `y`.
+ */
+ public float loss(float p, float y);
+
+ public double loss(double p, double y);
+
+ /**
+ * Evaluate the derivative of the loss function with respect to the prediction `p`.
+ *
+ * @param p The prediction, p = w^T x
+ * @param y The true value (aka target)
+ * @return The derivative of the loss function w.r.t. `p`.
+ */
+ public float dloss(float p, float y);
+
+ public boolean forBinaryClassification();
+
+ public boolean forRegression();
+
+ }
+
+ public static abstract class BinaryLoss implements LossFunction {
+
+ protected static void checkTarget(float y) {
+ if (!(y == 1.f || y == -1.f)) {
+ throw new IllegalArgumentException("target must be [+1,-1]: " + y);
+ }
+ }
+
+ protected static void checkTarget(double y) {
+ if (!(y == 1.d || y == -1.d)) {
+ throw new IllegalArgumentException("target must be [+1,-1]: " + y);
+ }
+ }
+
+ @Override
+ public boolean forBinaryClassification() {
+ return true;
+ }
+
+ @Override
+ public boolean forRegression() {
+ return false;
+ }
+ }
+
+ public static abstract class RegressionLoss implements LossFunction {
+
+ @Override
+ public boolean forBinaryClassification() {
+ return false;
+ }
+
+ @Override
+ public boolean forRegression() {
+ return true;
+ }
+
+ }
+
+ /**
+ * Squared loss for regression problems.
+ *
+ * If you're trying to minimize the mean error, use squared-loss.
+ */
+ public static final class SquaredLoss extends RegressionLoss {
+
+ @Override
+ public float loss(float p, float y) {
+ final float z = p - y;
+ return z * z * 0.5f;
+ }
+
+ @Override
+ public double loss(double p, double y) {
+ final double z = p - y;
+ return z * z * 0.5d;
+ }
+
+ @Override
+ public float dloss(float p, float y) {
+ return p - y; // 2 (p - y) / 2
+ }
+ }
+
+ /**
+ * Logistic regression loss for binary classification with y in {-1, 1}.
+ */
+ public static final class LogLoss extends BinaryLoss {
+
+ /**
+ * <code>logloss(p,y) = log(1+exp(-p*y))</code>
+ */
+ @Override
+ public float loss(float p, float y) {
+ checkTarget(y);
+
+ final float z = y * p;
+ if (z > 18.f) {
+ return (float) Math.exp(-z);
+ }
+ if (z < -18.f) {
+ return -z;
+ }
+ return (float) Math.log(1.d + Math.exp(-z));
+ }
+
+ @Override
+ public double loss(double p, double y) {
+ checkTarget(y);
+
+ final double z = y * p;
+ if (z > 18.d) {
+ return Math.exp(-z);
+ }
+ if (z < -18.d) {
+ return -z;
+ }
+ return Math.log(1.d + Math.exp(-z));
+ }
+
+ @Override
+ public float dloss(float p, float y) {
+ checkTarget(y);
+
+ float z = y * p;
+ if (z > 18.f) {
+ return (float) Math.exp(-z) * -y;
+ }
+ if (z < -18.f) {
+ return -y;
+ }
+ return -y / ((float) Math.exp(z) + 1.f);
+ }
+ }
+
+ /**
+ * Hinge loss for binary classification tasks with y in {-1,1}.
+ */
+ public static final class HingeLoss extends BinaryLoss {
+
+ private float threshold;
+
+ public HingeLoss() {
+ this(1.f);
+ }
+
+ /**
+ * @param threshold Margin threshold. When threshold=1.0, one gets the loss used by SVM.
+ * When threshold=0.0, one gets the loss used by the Perceptron.
+ */
+ public HingeLoss(float threshold) {
+ this.threshold = threshold;
+ }
+
+ public void setThreshold(float threshold) {
+ this.threshold = threshold;
+ }
+
+ @Override
+ public float loss(float p, float y) {
+ float loss = hingeLoss(p, y, threshold);
+ return (loss > 0.f) ? loss : 0.f;
+ }
+
+ @Override
+ public double loss(double p, double y) {
+ double loss = hingeLoss(p, y, threshold);
+ return (loss > 0.d) ? loss : 0.d;
+ }
+
+ @Override
+ public float dloss(float p, float y) {
+ float loss = hingeLoss(p, y, threshold);
+ return (loss > 0.f) ? -y : 0.f;
+ }
+ }
+
+ /**
+ * Squared Hinge loss for binary classification tasks with y in {-1,1}.
+ */
+ public static final class SquaredHingeLoss extends BinaryLoss {
+
+ @Override
+ public float loss(float p, float y) {
+ return squaredHingeLoss(p, y);
+ }
+
+ @Override
+ public double loss(double p, double y) {
+ return squaredHingeLoss(p, y);
+ }
+
+ @Override
+ public float dloss(float p, float y) {
+ checkTarget(y);
+
+ float d = 1 - (y * p);
+ return (d > 0.f) ? -2.f * d * y : 0.f;
+ }
+
+ }
+
+ /**
+ * Quantile loss is useful for predicting rank/order when you do not mind the mean error increasing
+ * as long as the relative order is correct.
+ *
+ * @link http://en.wikipedia.org/wiki/Quantile_regression
+ */
+ public static final class QuantileLoss extends RegressionLoss {
+
+ private float tau;
+
+ public QuantileLoss() {
+ this.tau = 0.5f;
+ }
+
+ public QuantileLoss(float tau) {
+ setTau(tau);
+ }
+
+ public void setTau(float tau) {
+ if (tau <= 0 || tau >= 1.0) {
+ throw new IllegalArgumentException("tau must be in range (0, 1): " + tau);
+ }
+ this.tau = tau;
+ }
+
+ @Override
+ public float loss(float p, float y) {
+ float e = y - p;
+ if (e > 0.f) {
+ return tau * e;
+ } else {
+ return -(1.f - tau) * e;
+ }
+ }
+
+ @Override
+ public double loss(double p, double y) {
+ double e = y - p;
+ if (e > 0.d) {
+ return tau * e;
+ } else {
+ return -(1.d - tau) * e;
+ }
+ }
+
+ @Override
+ public float dloss(float p, float y) {
+ float e = y - p;
+ if (e == 0.f) {
+ return 0.f;
+ }
+ return (e > 0.f) ? -tau : (1.f - tau);
+ }
+
+ }
+
+ /**
+ * Epsilon-Insensitive loss used by Support Vector Regression (SVR).
+ * <code>loss = max(0, |y - p| - epsilon)</code>
+ */
+ public static final class EpsilonInsensitiveLoss extends RegressionLoss {
+
+ private float epsilon;
+
+ public EpsilonInsensitiveLoss() {
+ this(0.1f);
+ }
+
+ public EpsilonInsensitiveLoss(float epsilon) {
+ this.epsilon = epsilon;
+ }
+
+ public void setEpsilon(float epsilon) {
+ this.epsilon = epsilon;
+ }
+
+ @Override
+ public float loss(float p, float y) {
+ float loss = Math.abs(y - p) - epsilon;
+ return (loss > 0.f) ? loss : 0.f;
+ }
+
+ @Override
+ public double loss(double p, double y) {
+ double loss = Math.abs(y - p) - epsilon;
+ return (loss > 0.d) ? loss : 0.d;
+ }
+
+ @Override
+ public float dloss(float p, float y) {
+ if ((y - p) > epsilon) {// real value > predicted value + epsilon
+ return -1.f;
+ }
+ if ((p - y) > epsilon) {// real value < predicted value - epsilon
+ return 1.f;
+ }
+ return 0.f;
+ }
+
+ }
+
+ public static float logisticLoss(final float target, final float predicted) {
+ if (predicted > -100.d) {
+ return target - (float) MathUtils.sigmoid(predicted);
+ } else {
+ return target;
+ }
+ }
+
+ public static float logLoss(final float p, final float y) {
+ BinaryLoss.checkTarget(y);
+
+ final float z = y * p;
+ if (z > 18.f) {
+ return (float) Math.exp(-z);
+ }
+ if (z < -18.f) {
+ return -z;
+ }
+ return (float) Math.log(1.d + Math.exp(-z));
+ }
+
+ public static double logLoss(final double p, final double y) {
+ BinaryLoss.checkTarget(y);
+
+ final double z = y * p;
+ if (z > 18.d) {
+ return Math.exp(-z);
+ }
+ if (z < -18.d) {
+ return -z;
+ }
+ return Math.log(1.d + Math.exp(-z));
+ }
+
+ public static float squaredLoss(float p, float y) {
+ final float z = p - y;
+ return z * z * 0.5f;
+ }
+
+ public static double squaredLoss(double p, double y) {
+ final double z = p - y;
+ return z * z * 0.5d;
+ }
+
+ public static float hingeLoss(final float p, final float y, final float threshold) {
+ BinaryLoss.checkTarget(y);
+
+ float z = y * p;
+ return threshold - z;
+ }
+
+ public static double hingeLoss(final double p, final double y, final double threshold) {
+ BinaryLoss.checkTarget(y);
+
+ double z = y * p;
+ return threshold - z;
+ }
+
+ public static float hingeLoss(float p, float y) {
+ return hingeLoss(p, y, 1.f);
+ }
+
+ public static double hingeLoss(double p, double y) {
+ return hingeLoss(p, y, 1.d);
+ }
+
+ public static float squaredHingeLoss(final float p, final float y) {
+ BinaryLoss.checkTarget(y);
+
+ float z = y * p;
+ float d = 1.f - z;
+ return (d > 0.f) ? (d * d) : 0.f;
+ }
+
+ public static double squaredHingeLoss(final double p, final double y) {
+ BinaryLoss.checkTarget(y);
+
+ double z = y * p;
+ double d = 1.d - z;
+ return (d > 0.d) ? d * d : 0.d;
+ }
+
+ /**
+ * Math.abs(target - predicted) - epsilon
+ */
+ public static float epsilonInsensitiveLoss(float predicted, float target, float epsilon) {
+ return Math.abs(target - predicted) - epsilon;
+ }
+}
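A usage sketch of the loss registry above (not part of the commit); binary losses expect targets in {-1, +1} and the numbers below are only for illustration:
import hivemall.optimizer.LossFunctions;
import hivemall.optimizer.LossFunctions.LossFunction;
public class LossFunctionsSketch {
    public static void main(String[] args) {
        // name-based lookup as implemented above
        LossFunction logloss = LossFunctions.getLossFunction("LogLoss");
        float p = 0.5f;  // prediction, p = w^T x
        float y = 1.f;   // true label
        System.out.println(logloss.loss(p, y));   // log(1 + exp(-0.5)) ~= 0.474
        System.out.println(logloss.dloss(p, y));  // -1 / (exp(0.5) + 1) ~= -0.378
        // enum-based lookup; hinge loss uses the default threshold of 1.0 (SVM-style)
        LossFunction hinge = LossFunctions.getLossFunction(LossFunctions.LossType.HingeLoss);
        System.out.println(hinge.loss(0.2f, -1.f)); // max(0, 1 - (-1 * 0.2)) = 1.2
    }
}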
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/optimizer/Optimizer.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/optimizer/Optimizer.java b/core/src/main/java/hivemall/optimizer/Optimizer.java
new file mode 100644
index 0000000..863536c
--- /dev/null
+++ b/core/src/main/java/hivemall/optimizer/Optimizer.java
@@ -0,0 +1,246 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2015 Makoto YUI
+ * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package hivemall.optimizer;
+
+import java.util.Map;
+import javax.annotation.Nonnull;
+import javax.annotation.concurrent.NotThreadSafe;
+
+import hivemall.model.WeightValue;
+import hivemall.model.IWeightValue;
+
+public interface Optimizer {
+
+ /**
+ * Update the weights of models through this interface.
+ */
+ float computeUpdatedValue(@Nonnull Object feature, float weight, float gradient);
+
+ // Count up #step to tune learning rate
+ void proceedStep();
+
+ static abstract class OptimizerBase implements Optimizer {
+
+ protected final EtaEstimator etaImpl;
+ protected final Regularization regImpl;
+
+ protected int numStep = 1;
+
+ public OptimizerBase(final Map<String, String> options) {
+ this.etaImpl = EtaEstimator.get(options);
+ this.regImpl = Regularization.get(options);
+ }
+
+ @Override
+ public void proceedStep() {
+ numStep++;
+ }
+
+ // Directly update a given `weight` in terms of performance
+ protected void computeUpdateValue(
+ @Nonnull final IWeightValue weight, float gradient) {
+ float delta = computeUpdateValueImpl(weight, regImpl.regularize(weight.get(), gradient));
+ weight.set(weight.get() - etaImpl.eta(numStep) * delta);
+ }
+
+ // Compute a delta to update
+ protected float computeUpdateValueImpl(
+ @Nonnull final IWeightValue weight, float gradient) {
+ return gradient;
+ }
+
+ }
+
+ @NotThreadSafe
+ static final class SGD extends OptimizerBase {
+
+ private final IWeightValue weightValueReused;
+
+ public SGD(final Map<String, String> options) {
+ super(options);
+ this.weightValueReused = new WeightValue(0.f);
+ }
+
+ @Override
+ public float computeUpdatedValue(
+ @Nonnull Object feature, float weight, float gradient) {
+ computeUpdateValue(weightValueReused, gradient);
+ return weightValueReused.get();
+ }
+
+ }
+
+ static abstract class AdaDelta extends OptimizerBase {
+
+ private final float decay;
+ private final float eps;
+ private final float scale;
+
+ public AdaDelta(Map<String, String> options) {
+ super(options);
+ float decay = 0.95f;
+ float eps = 1e-6f;
+ float scale = 100.0f;
+ if(options.containsKey("decay")) {
+ decay = Float.parseFloat(options.get("decay"));
+ }
+ if(options.containsKey("eps")) {
+ eps = Float.parseFloat(options.get("eps"));
+ }
+ if(options.containsKey("scale")) {
+ scale = Float.parseFloat(options.get("scale"));
+ }
+ this.decay = decay;
+ this.eps = eps;
+ this.scale = scale;
+ }
+
+ @Override
+ protected float computeUpdateValueImpl(@Nonnull final IWeightValue weight, float gradient) {
+ float old_scaled_sum_sqgrad = weight.getSumOfSquaredGradients();
+ float old_sum_squared_delta_x = weight.getSumOfSquaredDeltaX();
+ float new_scaled_sum_sqgrad = (decay * old_scaled_sum_sqgrad) + ((1.f - decay) * gradient * (gradient / scale));
+ float delta = (float) Math.sqrt((old_sum_squared_delta_x + eps) / (new_scaled_sum_sqgrad * scale + eps)) * gradient;
+ float new_sum_squared_delta_x = (decay * old_sum_squared_delta_x) + ((1.f - decay) * delta * delta);
+ weight.setSumOfSquaredGradients(new_scaled_sum_sqgrad);
+ weight.setSumOfSquaredDeltaX(new_sum_squared_delta_x);
+ return delta;
+ }
+
+ }
+
+ static abstract class AdaGrad extends OptimizerBase {
+
+ private final float eps;
+ private final float scale;
+
+ public AdaGrad(Map<String, String> options) {
+ super(options);
+ float eps = 1.0f;
+ float scale = 100.0f;
+ if(options.containsKey("eps")) {
+ eps = Float.parseFloat(options.get("eps"));
+ }
+ if(options.containsKey("scale")) {
+ scale = Float.parseFloat(options.get("scale"));
+ }
+ this.eps = eps;
+ this.scale = scale;
+ }
+
+ @Override
+ protected float computeUpdateValueImpl(@Nonnull final IWeightValue weight, float gradient) {
+ float new_scaled_sum_sqgrad = weight.getSumOfSquaredGradients() + gradient * (gradient / scale);
+ float delta = gradient / ((float) Math.sqrt(new_scaled_sum_sqgrad * scale) + eps);
+ weight.setSumOfSquaredGradients(new_scaled_sum_sqgrad);
+ return delta;
+ }
+
+ }
+
+ /**
+ * Adam, an algorithm for first-order gradient-based optimization of stochastic objective
+ * functions, based on adaptive estimates of lower-order moments.
+ *
+ * - D. P. Kingma and J. L. Ba: "ADAM: A Method for Stochastic Optimization." arXiv preprint arXiv:1412.6980v8, 2014.
+ */
+ static abstract class Adam extends OptimizerBase {
+
+ private final float beta;
+ private final float gamma;
+ private final float eps_hat;
+
+ public Adam(Map<String, String> options) {
+ super(options);
+ float beta = 0.9f;
+ float gamma = 0.999f;
+ float eps_hat = 1e-8f;
+ if(options.containsKey("beta")) {
+ beta = Float.parseFloat(options.get("beta"));
+ }
+ if(options.containsKey("gamma")) {
+ gamma = Float.parseFloat(options.get("gamma"));
+ }
+ if(options.containsKey("eps_hat")) {
+ eps_hat = Float.parseFloat(options.get("eps_hat"));
+ }
+ this.beta = beta;
+ this.gamma = gamma;
+ this.eps_hat = eps_hat;
+ }
+
+ @Override
+ protected float computeUpdateValueImpl(@Nonnull final IWeightValue weight, float gradient) {
+ float val_m = beta * weight.getM() + (1.f - beta) * gradient;
+ float val_v = gamma * weight.getV() + (float) ((1.f - gamma) * Math.pow(gradient, 2.0));
+ float val_m_hat = val_m / (float) (1.f - Math.pow(beta, numStep));
+ float val_v_hat = val_v / (float) (1.f - Math.pow(gamma, numStep));
+ float delta = val_m_hat / (float) (Math.sqrt(val_v_hat) + eps_hat);
+ weight.setM(val_m);
+ weight.setV(val_v);
+ return delta;
+ }
+
+ }
+
+ static abstract class RDA extends OptimizerBase {
+
+ private final OptimizerBase optimizerImpl;
+
+ private final float lambda;
+
+ public RDA(final OptimizerBase optimizerImpl, Map<String, String> options) {
+ super(options);
+ // `optimizerImpl` must be an `AdaGrad` instance; RDA currently supports no other base optimizer
+ if(!(optimizerImpl instanceof AdaGrad)) {
+ throw new IllegalArgumentException(
+ optimizerImpl.getClass().getSimpleName()
+ + " currently does not support RDA regularization");
+ }
+ float lambda = 1e-6f;
+ if(options.containsKey("lambda")) {
+ lambda = Float.parseFloat(options.get("lambda"));
+ }
+ this.optimizerImpl = optimizerImpl;
+ this.lambda = lambda;
+ }
+
+ @Override
+ protected void computeUpdateValue(@Nonnull final IWeightValue weight, float gradient) {
+ float new_sum_grad = weight.getSumOfGradients() + gradient;
+ // sign(u_{t,i})
+ float sign = (new_sum_grad > 0.f)? 1.f : -1.f;
+ // |u_{t,i}|/t - \lambda
+ float meansOfGradients = (sign * new_sum_grad / numStep) - lambda;
+ if(meansOfGradients < 0.f) {
+ // x_{t,i} = 0
+ weight.set(0.f);
+ weight.setSumOfSquaredGradients(0.f);
+ weight.setSumOfGradients(0.f);
+ } else {
+ // x_{t,i} = -sign(u_{t,i}) * \frac{\eta t}{\sqrt{G_{t,ii}}}(|u_{t,i}|/t - \lambda)
+ float new_weight = -1.f * sign * etaImpl.eta(numStep) * numStep * optimizerImpl.computeUpdateValueImpl(weight, meansOfGradients);
+ weight.set(new_weight);
+ weight.setSumOfGradients(new_sum_grad);
+ }
+ }
+
+ }
+
+}
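To make the scaled accumulator in the AdaGrad update above concrete, a throwaway arithmetic sketch (not library code) that replays computeUpdateValueImpl with its default eps and scale:
public class AdaGradUpdateSketch {
    public static void main(String[] args) {
        float eps = 1.0f;     // default in Optimizer.AdaGrad
        float scale = 100.0f; // default in Optimizer.AdaGrad
        float sumOfSquaredGradients = 0.f;  // accumulated state kept in the IWeightValue
        float gradient = 2.0f;
        float newScaledSumSqGrad = sumOfSquaredGradients + gradient * (gradient / scale); // 0.04
        float delta = gradient / ((float) Math.sqrt(newScaledSumSqGrad * scale) + eps);   // 2 / (2 + 1)
        System.out.println(delta); // ~0.667; OptimizerBase then applies weight -= eta(t) * delta
    }
}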
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/optimizer/Regularization.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/optimizer/Regularization.java b/core/src/main/java/hivemall/optimizer/Regularization.java
new file mode 100644
index 0000000..ce1ef7f
--- /dev/null
+++ b/core/src/main/java/hivemall/optimizer/Regularization.java
@@ -0,0 +1,99 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2015 Makoto YUI
+ * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package hivemall.optimizer;
+
+import javax.annotation.Nonnull;
+import java.util.Map;
+
+public abstract class Regularization {
+
+ protected final float lambda;
+
+ public Regularization(final Map<String, String> options) {
+ float lambda = 1e-6f;
+ if(options.containsKey("lambda")) {
+ lambda = Float.parseFloat(options.get("lambda"));
+ }
+ this.lambda = lambda;
+ }
+
+ abstract float regularize(float weight, float gradient);
+
+ public static final class PassThrough extends Regularization {
+
+ public PassThrough(final Map<String, String> options) {
+ super(options);
+ }
+
+ @Override
+ public float regularize(float weight, float gradient) {
+ return gradient;
+ }
+
+ }
+
+ public static final class L1 extends Regularization {
+
+ public L1(Map<String, String> options) {
+ super(options);
+ }
+
+ @Override
+ public float regularize(float weight, float gradient) {
+ return gradient + lambda * (weight > 0.f? 1.f : -1.f);
+ }
+
+ }
+
+ public static final class L2 extends Regularization {
+
+ public L2(final Map<String, String> options) {
+ super(options);
+ }
+
+ @Override
+ public float regularize(float weight, float gradient) {
+ return gradient + lambda * weight;
+ }
+
+ }
+
+ @Nonnull
+ public static Regularization get(@Nonnull final Map<String, String> options)
+ throws IllegalArgumentException {
+ final String regName = options.get("regularization");
+ if (regName == null) {
+ return new PassThrough(options);
+ }
+ if(regName.toLowerCase().equals("no")) {
+ return new PassThrough(options);
+ } else if(regName.toLowerCase().equals("l1")) {
+ return new L1(options);
+ } else if(regName.toLowerCase().equals("l2")) {
+ return new L2(options);
+ } else if(regName.toLowerCase().equals("rda")) {
+ // Return `PassThrough` because we need special handling for RDA.
+ // See an implementation of `Optimizer#RDA`.
+ return new PassThrough(options);
+ } else {
+ throw new IllegalArgumentException("Unsupported regularization name: " + regName);
+ }
+ }
+
+}
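A sketch of the option-driven lookup above; the sketch class sits in the hivemall.optimizer package because regularize(float, float) is package-private on the abstract base, and the lambda value is arbitrary:
package hivemall.optimizer;
import java.util.HashMap;
import java.util.Map;
public class RegularizationSketch {
    public static void main(String[] args) {
        Map<String, String> options = new HashMap<String, String>();
        options.put("regularization", "l2");
        options.put("lambda", "0.01");
        Regularization reg = Regularization.get(options);
        // L2 adds lambda * weight to the raw gradient: 0.5 + 0.01 * 2.0 = 0.52
        System.out.println(reg.regularize(2.0f, 0.5f));
    }
}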
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/optimizer/SparseOptimizerFactory.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/optimizer/SparseOptimizerFactory.java b/core/src/main/java/hivemall/optimizer/SparseOptimizerFactory.java
new file mode 100644
index 0000000..a74d0da
--- /dev/null
+++ b/core/src/main/java/hivemall/optimizer/SparseOptimizerFactory.java
@@ -0,0 +1,171 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2015 Makoto YUI
+ * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package hivemall.optimizer;
+
+import javax.annotation.Nonnull;
+import javax.annotation.concurrent.NotThreadSafe;
+import java.util.Map;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+
+import hivemall.optimizer.Optimizer.OptimizerBase;
+import hivemall.model.IWeightValue;
+import hivemall.model.WeightValue;
+import hivemall.utils.collections.OpenHashMap;
+
+public final class SparseOptimizerFactory {
+ private static final Log logger = LogFactory.getLog(SparseOptimizerFactory.class);
+
+ @Nonnull
+ public static Optimizer create(int ndims, @Nonnull Map<String, String> options) {
+ final String optimizerName = options.get("optimizer");
+ if(optimizerName != null) {
+ OptimizerBase optimizerImpl;
+ if(optimizerName.toLowerCase().equals("sgd")) {
+ optimizerImpl = new Optimizer.SGD(options);
+ } else if(optimizerName.toLowerCase().equals("adadelta")) {
+ optimizerImpl = new AdaDelta(ndims, options);
+ } else if(optimizerName.toLowerCase().equals("adagrad")) {
+ optimizerImpl = new AdaGrad(ndims, options);
+ } else if(optimizerName.toLowerCase().equals("adam")) {
+ optimizerImpl = new Adam(ndims, options);
+ } else {
+ throw new IllegalArgumentException("Unsupported optimizer name: " + optimizerName);
+ }
+
+ logger.info("set " + optimizerImpl.getClass().getSimpleName()
+ + " as an optimizer: " + options);
+
+ // If a regularization type is "RDA", wrap the optimizer with `Optimizer#RDA`.
+ if(options.get("regularization") != null
+ && options.get("regularization").toLowerCase().equals("rda")) {
+ optimizerImpl = new RDA(ndims, optimizerImpl, options);
+ }
+
+ return optimizerImpl;
+ }
+ throw new IllegalArgumentException("`optimizer` not defined");
+ }
+
+ @NotThreadSafe
+ static final class AdaDelta extends Optimizer.AdaDelta {
+
+ private final OpenHashMap<Object, IWeightValue> auxWeights;
+
+ public AdaDelta(int size, Map<String, String> options) {
+ super(options);
+ this.auxWeights = new OpenHashMap<Object, IWeightValue>(size);
+ }
+
+ @Override
+ public float computeUpdatedValue(@Nonnull Object feature, float weight, float gradient) {
+ IWeightValue auxWeight;
+ if(auxWeights.containsKey(feature)) {
+ auxWeight = auxWeights.get(feature);
+ auxWeight.set(weight);
+ } else {
+ auxWeight = new WeightValue.WeightValueParamsF2(weight, 0.f, 0.f);
+ auxWeights.put(feature, auxWeight);
+ }
+ computeUpdateValue(auxWeight, gradient);
+ return auxWeight.get();
+ }
+
+ }
+
+ @NotThreadSafe
+ static final class AdaGrad extends Optimizer.AdaGrad {
+
+ private final OpenHashMap<Object, IWeightValue> auxWeights;
+
+ public AdaGrad(int size, Map<String, String> options) {
+ super(options);
+ this.auxWeights = new OpenHashMap<Object, IWeightValue>(size);
+ }
+
+ @Override
+ public float computeUpdatedValue(@Nonnull Object feature, float weight, float gradient) {
+ IWeightValue auxWeight;
+ if(auxWeights.containsKey(feature)) {
+ auxWeight = auxWeights.get(feature);
+ auxWeight.set(weight);
+ } else {
+ auxWeight = new WeightValue.WeightValueParamsF2(weight, 0.f, 0.f);
+ auxWeights.put(feature, auxWeight);
+ }
+ computeUpdateValue(auxWeight, gradient);
+ return auxWeight.get();
+ }
+
+ }
+
+ @NotThreadSafe
+ static final class Adam extends Optimizer.Adam {
+
+ private final OpenHashMap<Object, IWeightValue> auxWeights;
+
+ public Adam(int size, Map<String, String> options) {
+ super(options);
+ this.auxWeights = new OpenHashMap<Object, IWeightValue>(size);
+ }
+
+ @Override
+ public float computeUpdatedValue(@Nonnull Object feature, float weight, float gradient) {
+ IWeightValue auxWeight;
+ if(auxWeights.containsKey(feature)) {
+ auxWeight = auxWeights.get(feature);
+ auxWeight.set(weight);
+ } else {
+ auxWeight = new WeightValue.WeightValueParamsF2(weight, 0.f, 0.f);
+ auxWeights.put(feature, auxWeight);
+ }
+ computeUpdateValue(auxWeight, gradient);
+ return auxWeight.get();
+ }
+
+ }
+
+ @NotThreadSafe
+ static final class RDA extends Optimizer.RDA {
+
+ private final OpenHashMap<Object, IWeightValue> auxWeights;
+
+ public RDA(int size, OptimizerBase optimizerImpl, Map<String, String> options) {
+ super(optimizerImpl, options);
+ this.auxWeights = new OpenHashMap<Object, IWeightValue>(size);
+ }
+
+ @Override
+ public float computeUpdatedValue(@Nonnull Object feature, float weight, float gradient) {
+ IWeightValue auxWeight;
+ if(auxWeights.containsKey(feature)) {
+ auxWeight = auxWeights.get(feature);
+ auxWeight.set(weight);
+ } else {
+ auxWeight = new WeightValue.WeightValueParamsF2(weight, 0.f, 0.f);
+ auxWeights.put(feature, auxWeight);
+ }
+ computeUpdateValue(auxWeight, gradient);
+ return auxWeight.get();
+ }
+
+ }
+
+}
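For reference, a minimal sketch of how these optimizer factories are meant to be driven, assuming only the `create`, `computeUpdatedValue`, and `proceedStep` methods that appear in this patch (the option values and feature key are illustrative, not taken from the codebase):

    import java.util.HashMap;
    import java.util.Map;

    import hivemall.optimizer.Optimizer;
    import hivemall.optimizer.SparseOptimizerFactory;

    public final class OptimizerFactorySketch {
        public static void main(String[] args) {
            final Map<String, String> options = new HashMap<String, String>();
            options.put("optimizer", "adagrad");   // matched case-insensitively in create()
            options.put("regularization", "rda");  // wraps the optimizer with Optimizer.RDA

            final Optimizer optimizer = SparseOptimizerFactory.create(1024, options);

            // one gradient step for a single feature; the caller stores the returned weight
            float weight = 0.f;
            weight = optimizer.computeUpdatedValue("feat1", weight, 0.1f);
            optimizer.proceedStep(); // advance the optimizer's internal time step
        }
    }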
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/regression/AROWRegressionUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/regression/AROWRegressionUDTF.java b/core/src/main/java/hivemall/regression/AROWRegressionUDTF.java
index b81a4bf..0c964c8 100644
--- a/core/src/main/java/hivemall/regression/AROWRegressionUDTF.java
+++ b/core/src/main/java/hivemall/regression/AROWRegressionUDTF.java
@@ -18,7 +18,7 @@
*/
package hivemall.regression;
-import hivemall.common.LossFunctions;
+import hivemall.optimizer.LossFunctions;
import hivemall.common.OnlineVariance;
import hivemall.model.FeatureValue;
import hivemall.model.IWeightValue;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/regression/AdaDeltaUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/regression/AdaDeltaUDTF.java b/core/src/main/java/hivemall/regression/AdaDeltaUDTF.java
index e807340..50dc9b5 100644
--- a/core/src/main/java/hivemall/regression/AdaDeltaUDTF.java
+++ b/core/src/main/java/hivemall/regression/AdaDeltaUDTF.java
@@ -18,123 +18,14 @@
*/
package hivemall.regression;
-import hivemall.common.LossFunctions;
-import hivemall.model.FeatureValue;
-import hivemall.model.IWeightValue;
-import hivemall.model.WeightValue.WeightValueParamsF2;
-import hivemall.utils.lang.Primitives;
-
-import javax.annotation.Nonnull;
-import javax.annotation.Nullable;
-
-import org.apache.commons.cli.CommandLine;
-import org.apache.commons.cli.Options;
-import org.apache.hadoop.hive.ql.exec.Description;
-import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
-import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
-import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
-
/**
* ADADELTA: AN ADAPTIVE LEARNING RATE METHOD.
*/
-@Description(
- name = "train_adadelta_regr",
- value = "_FUNC_(array<int|bigint|string> features, float target [, constant string options])"
- + " - Returns a relation consists of <{int|bigint|string} feature, float weight>")
-public final class AdaDeltaUDTF extends RegressionBaseUDTF {
-
- private float decay;
- private float eps;
- private float scaling;
-
- @Override
- public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
- final int numArgs = argOIs.length;
- if (numArgs != 2 && numArgs != 3) {
- throw new UDFArgumentException(
- "AdaDeltaUDTF takes 2 or 3 arguments: List<Text|Int|BitInt> features, float target [, constant string options]");
- }
-
- StructObjectInspector oi = super.initialize(argOIs);
- model.configureParams(true, true, false);
- return oi;
- }
-
- @Override
- protected Options getOptions() {
- Options opts = super.getOptions();
- opts.addOption("rho", "decay", true, "Decay rate [default 0.95]");
- opts.addOption("eps", true, "A constant used in the denominator of AdaGrad [default 1e-6]");
- opts.addOption("scale", true,
- "Internal scaling/descaling factor for cumulative weights [100]");
- return opts;
- }
-
- @Override
- protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
- CommandLine cl = super.processOptions(argOIs);
- if (cl == null) {
- this.decay = 0.95f;
- this.eps = 1e-6f;
- this.scaling = 100f;
- } else {
- this.decay = Primitives.parseFloat(cl.getOptionValue("decay"), 0.95f);
- this.eps = Primitives.parseFloat(cl.getOptionValue("eps"), 1E-6f);
- this.scaling = Primitives.parseFloat(cl.getOptionValue("scale"), 100f);
- }
- return cl;
- }
-
- @Override
- protected final void checkTargetValue(final float target) throws UDFArgumentException {
- if (target < 0.f || target > 1.f) {
- throw new UDFArgumentException("target must be in range 0 to 1: " + target);
- }
- }
-
- @Override
- protected void update(@Nonnull final FeatureValue[] features, float target, float predicted) {
- float gradient = LossFunctions.logisticLoss(target, predicted);
- onlineUpdate(features, gradient);
- }
-
- @Override
- protected void onlineUpdate(@Nonnull final FeatureValue[] features, float gradient) {
- final float g_g = gradient * (gradient / scaling);
-
- for (FeatureValue f : features) {// w[i] += y * x[i]
- if (f == null) {
- continue;
- }
- Object x = f.getFeature();
- float xi = f.getValueAsFloat();
-
- IWeightValue old_w = model.get(x);
- IWeightValue new_w = getNewWeight(old_w, xi, gradient, g_g);
- model.set(x, new_w);
- }
- }
-
- @Nonnull
- protected IWeightValue getNewWeight(@Nullable final IWeightValue old, final float xi,
- final float gradient, final float g_g) {
- float old_w = 0.f;
- float old_scaled_sum_sqgrad = 0.f;
- float old_sum_squared_delta_x = 0.f;
- if (old != null) {
- old_w = old.get();
- old_scaled_sum_sqgrad = old.getSumOfSquaredGradients();
- old_sum_squared_delta_x = old.getSumOfSquaredDeltaX();
- }
+@Deprecated
+public final class AdaDeltaUDTF extends GeneralRegressionUDTF {
- float new_scaled_sum_sq_grad = (decay * old_scaled_sum_sqgrad) + ((1.f - decay) * g_g);
- float dx = (float) Math.sqrt((old_sum_squared_delta_x + eps)
- / (old_scaled_sum_sqgrad * scaling + eps))
- * gradient;
- float new_sum_squared_delta_x = (decay * old_sum_squared_delta_x)
- + ((1.f - decay) * dx * dx);
- float new_w = old_w + (dx * xi);
- return new WeightValueParamsF2(new_w, new_scaled_sum_sq_grad, new_sum_squared_delta_x);
+ public AdaDeltaUDTF() {
+ optimizerOptions.put("optimizer", "AdaDelta");
}
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/regression/AdaGradUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/regression/AdaGradUDTF.java b/core/src/main/java/hivemall/regression/AdaGradUDTF.java
index de48d97..4b5f019 100644
--- a/core/src/main/java/hivemall/regression/AdaGradUDTF.java
+++ b/core/src/main/java/hivemall/regression/AdaGradUDTF.java
@@ -18,124 +18,14 @@
*/
package hivemall.regression;
-import hivemall.common.LossFunctions;
-import hivemall.model.FeatureValue;
-import hivemall.model.IWeightValue;
-import hivemall.model.WeightValue.WeightValueParamsF1;
-import hivemall.utils.lang.Primitives;
-
-import javax.annotation.Nonnull;
-import javax.annotation.Nullable;
-
-import org.apache.commons.cli.CommandLine;
-import org.apache.commons.cli.Options;
-import org.apache.hadoop.hive.ql.exec.Description;
-import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
-import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
-import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
-
/**
* ADAGRAD algorithm with element-wise adaptive learning rates.
*/
-@Description(
- name = "train_adagrad_regr",
- value = "_FUNC_(array<int|bigint|string> features, float target [, constant string options])"
- + " - Returns a relation consists of <{int|bigint|string} feature, float weight>")
-public final class AdaGradUDTF extends RegressionBaseUDTF {
-
- private float eta;
- private float eps;
- private float scaling;
-
- @Override
- public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
- final int numArgs = argOIs.length;
- if (numArgs != 2 && numArgs != 3) {
- throw new UDFArgumentException(
- "_FUNC_ takes 2 or 3 arguments: List<Text|Int|BitInt> features, float target [, constant string options]");
- }
-
- StructObjectInspector oi = super.initialize(argOIs);
- model.configureParams(true, false, false);
- return oi;
- }
-
- @Override
- protected Options getOptions() {
- Options opts = super.getOptions();
- opts.addOption("eta", "eta0", true, "The initial learning rate [default 1.0]");
- opts.addOption("eps", true, "A constant used in the denominator of AdaGrad [default 1.0]");
- opts.addOption("scale", true,
- "Internal scaling/descaling factor for cumulative weights [100]");
- return opts;
- }
-
- @Override
- protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
- CommandLine cl = super.processOptions(argOIs);
- if (cl == null) {
- this.eta = 1.f;
- this.eps = 1.f;
- this.scaling = 100f;
- } else {
- this.eta = Primitives.parseFloat(cl.getOptionValue("eta"), 1.f);
- this.eps = Primitives.parseFloat(cl.getOptionValue("eps"), 1.f);
- this.scaling = Primitives.parseFloat(cl.getOptionValue("scale"), 100f);
- }
- return cl;
- }
-
- @Override
- protected final void checkTargetValue(final float target) throws UDFArgumentException {
- if (target < 0.f || target > 1.f) {
- throw new UDFArgumentException("target must be in range 0 to 1: " + target);
- }
- }
-
- @Override
- protected void update(@Nonnull final FeatureValue[] features, float target, float predicted) {
- float gradient = LossFunctions.logisticLoss(target, predicted);
- onlineUpdate(features, gradient);
- }
-
- @Override
- protected void onlineUpdate(@Nonnull final FeatureValue[] features, float gradient) {
- final float g_g = gradient * (gradient / scaling);
-
- for (FeatureValue f : features) {// w[i] += y * x[i]
- if (f == null) {
- continue;
- }
- Object x = f.getFeature();
- float xi = f.getValueAsFloat();
-
- IWeightValue old_w = model.get(x);
- IWeightValue new_w = getNewWeight(old_w, xi, gradient, g_g);
- model.set(x, new_w);
- }
- }
-
- @Nonnull
- protected IWeightValue getNewWeight(@Nullable final IWeightValue old, final float xi,
- final float gradient, final float g_g) {
- float old_w = 0.f;
- float scaled_sum_sqgrad = 0.f;
-
- if (old != null) {
- old_w = old.get();
- scaled_sum_sqgrad = old.getSumOfSquaredGradients();
- }
- scaled_sum_sqgrad += g_g;
-
- float coeff = eta(scaled_sum_sqgrad) * gradient;
- float new_w = old_w + (coeff * xi);
- return new WeightValueParamsF1(new_w, scaled_sum_sqgrad);
- }
+@Deprecated
+public final class AdaGradUDTF extends GeneralRegressionUDTF {
- protected float eta(final double scaledSumOfSquaredGradients) {
- double sumOfSquaredGradients = scaledSumOfSquaredGradients * scaling;
- //return eta / (float) Math.sqrt(sumOfSquaredGradients);
- return eta / (float) Math.sqrt(eps + sumOfSquaredGradients); // always less than eta0
+ public AdaGradUDTF() {
+ optimizerOptions.put("optimizer", "AdaGrad");
}
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/regression/GeneralRegressionUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/regression/GeneralRegressionUDTF.java b/core/src/main/java/hivemall/regression/GeneralRegressionUDTF.java
new file mode 100644
index 0000000..2a8b543
--- /dev/null
+++ b/core/src/main/java/hivemall/regression/GeneralRegressionUDTF.java
@@ -0,0 +1,125 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2015 Makoto YUI
+ * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package hivemall.regression;
+
+import java.util.HashMap;
+import java.util.Map;
+import javax.annotation.Nonnull;
+
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.Option;
+import org.apache.commons.cli.Options;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
+
+import hivemall.optimizer.LossFunctions;
+import hivemall.model.FeatureValue;
+
+/**
+ * A general regression class with replaceable optimization functions.
+ */
+public class GeneralRegressionUDTF extends RegressionBaseUDTF {
+
+ protected final Map<String, String> optimizerOptions;
+
+ public GeneralRegressionUDTF() {
+ this.optimizerOptions = new HashMap<String, String>();
+ // Set default values
+ optimizerOptions.put("optimizer", "adadelta");
+ optimizerOptions.put("eta", "fixed");
+ optimizerOptions.put("eta0", "1.0");
+ optimizerOptions.put("t", "10000");
+ optimizerOptions.put("power_t", "0.1");
+ optimizerOptions.put("eps", "1e-6");
+ optimizerOptions.put("rho", "0.95");
+ optimizerOptions.put("scale", "100.0");
+ optimizerOptions.put("lambda", "1.0");
+ }
+
+ @Override
+ public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
+ if(argOIs.length != 2 && argOIs.length != 3) {
+ throw new UDFArgumentException(
+ this.getClass().getSimpleName()
+ + " takes 2 or 3 arguments: List<Text|Int|BitInt> features, float target "
+ + "[, constant string options]");
+ }
+ return super.initialize(argOIs);
+ }
+
+ @Override
+ protected Options getOptions() {
+ Options opts = super.getOptions();
+ opts.addOption("optimizer", "opt", true, "Optimizer to update weights [default: adadelta]");
+ opts.addOption("eta", true, " ETA estimator to compute delta [default: fixed]");
+ opts.addOption("eta0", true, "Initial learning rate [default 1.0]");
+ opts.addOption("t", "total_steps", true, "Total of n_samples * epochs time steps [default: 10000]");
+ opts.addOption("power_t", true, "Exponent for inverse scaling learning rate [default 0.1]");
+ opts.addOption("eps", true, "Denominator value of AdaDelta/AdaGrad [default 1e-6]");
+ opts.addOption("rho", "decay", true, "Decay rate [default 0.95]");
+ opts.addOption("scale", true, "Scaling factor for cumulative weights [100.0]");
+ opts.addOption("regularization", "reg", true, "Regularization type [default not-defined]");
+ opts.addOption("lambda", true, "Regularization term on weights [default 1.0]");
+ return opts;
+ }
+
+ @Override
+ protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
+ final CommandLine cl = super.processOptions(argOIs);
+ if(cl != null) {
+ for(final Option opt: cl.getOptions()) {
+ optimizerOptions.put(opt.getOpt(), opt.getValue());
+ }
+ }
+ return cl;
+ }
+
+ @Override
+ protected Map<String, String> getOptimzierOptions() {
+ return optimizerOptions;
+ }
+
+ @Override
+ protected final void checkTargetValue(final float target) throws UDFArgumentException {
+ if(target < 0.f || target > 1.f) {
+ throw new UDFArgumentException("target must be in range 0 to 1: " + target);
+ }
+ }
+
+ @Override
+ protected void update(@Nonnull final FeatureValue[] features, final float target,
+ final float predicted) {
+ if(is_mini_batch) {
+ throw new UnsupportedOperationException(
+ this.getClass().getSimpleName() + " does not support `is_mini_batch` mode");
+ } else {
+ float loss = LossFunctions.logisticLoss(target, predicted);
+ for(FeatureValue f : features) {
+ Object feature = f.getFeature();
+ float xi = f.getValueAsFloat();
+ float weight = model.getWeight(feature);
+ float new_weight = optimizerImpl.computeUpdatedValue(feature, weight, -loss * xi);
+ model.setWeight(feature, new_weight);
+ }
+ optimizerImpl.proceedStep();
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/regression/LogressUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/regression/LogressUDTF.java b/core/src/main/java/hivemall/regression/LogressUDTF.java
index ca3da71..ea05da3 100644
--- a/core/src/main/java/hivemall/regression/LogressUDTF.java
+++ b/core/src/main/java/hivemall/regression/LogressUDTF.java
@@ -18,65 +18,12 @@
*/
package hivemall.regression;
-import hivemall.common.EtaEstimator;
-import hivemall.common.LossFunctions;
+@Deprecated
+public final class LogressUDTF extends GeneralRegressionUDTF {
-import org.apache.commons.cli.CommandLine;
-import org.apache.commons.cli.Options;
-import org.apache.hadoop.hive.ql.exec.Description;
-import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
-import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
-import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
-
-@Description(
- name = "logress",
- value = "_FUNC_(array<int|bigint|string> features, float target [, constant string options])"
- + " - Returns a relation consists of <{int|bigint|string} feature, float weight>")
-public final class LogressUDTF extends RegressionBaseUDTF {
-
- private EtaEstimator etaEstimator;
-
- @Override
- public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
- final int numArgs = argOIs.length;
- if (numArgs != 2 && numArgs != 3) {
- throw new UDFArgumentException(
- "LogressUDTF takes 2 or 3 arguments: List<Text|Int|BitInt> features, float target [, constant string options]");
- }
-
- return super.initialize(argOIs);
- }
-
- @Override
- protected Options getOptions() {
- Options opts = super.getOptions();
- opts.addOption("t", "total_steps", true, "a total of n_samples * epochs time steps");
- opts.addOption("power_t", true,
- "The exponent for inverse scaling learning rate [default 0.1]");
- opts.addOption("eta0", true, "The initial learning rate [default 0.1]");
- return opts;
- }
-
- @Override
- protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
- CommandLine cl = super.processOptions(argOIs);
-
- this.etaEstimator = EtaEstimator.get(cl);
- return cl;
- }
-
- @Override
- protected void checkTargetValue(final float target) throws UDFArgumentException {
- if (target < 0.f || target > 1.f) {
- throw new UDFArgumentException("target must be in range 0 to 1: " + target);
- }
- }
-
- @Override
- protected float computeUpdate(final float target, final float predicted) {
- float eta = etaEstimator.eta(count);
- float gradient = LossFunctions.logisticLoss(target, predicted);
- return eta * gradient;
+ public LogressUDTF() {
+ optimizerOptions.put("optimizer", "SGD");
+ optimizerOptions.put("eta", "fixed");
}
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/regression/PassiveAggressiveRegressionUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/regression/PassiveAggressiveRegressionUDTF.java b/core/src/main/java/hivemall/regression/PassiveAggressiveRegressionUDTF.java
index c089946..e1afe2f 100644
--- a/core/src/main/java/hivemall/regression/PassiveAggressiveRegressionUDTF.java
+++ b/core/src/main/java/hivemall/regression/PassiveAggressiveRegressionUDTF.java
@@ -18,7 +18,7 @@
*/
package hivemall.regression;
-import hivemall.common.LossFunctions;
+import hivemall.optimizer.LossFunctions;
import hivemall.common.OnlineVariance;
import hivemall.model.FeatureValue;
import hivemall.model.PredictionResult;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/main/java/hivemall/regression/RegressionBaseUDTF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/regression/RegressionBaseUDTF.java b/core/src/main/java/hivemall/regression/RegressionBaseUDTF.java
index 561d4f7..7dc8538 100644
--- a/core/src/main/java/hivemall/regression/RegressionBaseUDTF.java
+++ b/core/src/main/java/hivemall/regression/RegressionBaseUDTF.java
@@ -25,6 +25,7 @@ import hivemall.model.PredictionModel;
import hivemall.model.PredictionResult;
import hivemall.model.WeightValue;
import hivemall.model.WeightValue.WeightValueWithCovar;
+import hivemall.optimizer.Optimizer;
import hivemall.utils.collections.IMapIterator;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.lang.FloatAccumulator;
@@ -64,6 +65,7 @@ public abstract class RegressionBaseUDTF extends LearnerBaseUDTF {
private boolean parseFeature;
protected PredictionModel model;
+ protected Optimizer optimizerImpl;
protected int count;
// The accumulated delta of each weight values.
@@ -87,6 +89,7 @@ public abstract class RegressionBaseUDTF extends LearnerBaseUDTF {
if (preloadedModelFile != null) {
loadPredictionModel(model, preloadedModelFile, featureOutputOI);
}
+ this.optimizerImpl = createOptimizer();
this.count = 0;
this.sampled = 0;
@@ -235,7 +238,7 @@ public abstract class RegressionBaseUDTF extends LearnerBaseUDTF {
protected void update(@Nonnull final FeatureValue[] features, final float target,
final float predicted) {
- final float grad = computeUpdate(target, predicted);
+ final float grad = computeGradient(target, predicted);
if (is_mini_batch) {
accumulateUpdate(features, grad);
@@ -247,12 +250,9 @@ public abstract class RegressionBaseUDTF extends LearnerBaseUDTF {
}
}
- protected float computeUpdate(float target, float predicted) {
- throw new IllegalStateException();
- }
-
- protected IWeightValue getNewWeight(IWeightValue old_w, float delta) {
- throw new IllegalStateException();
+ // Compute a gradient by using a loss function in derived classes
+ protected float computeGradient(float target, float predicted) {
+ throw new UnsupportedOperationException();
}
protected final void accumulateUpdate(@Nonnull final FeatureValue[] features, final float coeff) {
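Since `computeGradient` now returns only a raw gradient (step sizes are handled by the injected `Optimizer`), a derived trainer is expected to override it roughly as follows; this is a sketch with a made-up class name, using the logistic loss that the regressors in this patch rely on:

    import hivemall.optimizer.LossFunctions;

    // Hypothetical subclass, for illustration only.
    public final class MyLogisticRegressionUDTF extends RegressionBaseUDTF {
        @Override
        protected float computeGradient(final float target, final float predicted) {
            // gradient term derived from the logistic loss, as in the trainers above
            return LossFunctions.logisticLoss(target, predicted);
        }
    }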
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/core/src/test/java/hivemall/optimizer/OptimizerTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/optimizer/OptimizerTest.java b/core/src/test/java/hivemall/optimizer/OptimizerTest.java
new file mode 100644
index 0000000..cfcfa79
--- /dev/null
+++ b/core/src/test/java/hivemall/optimizer/OptimizerTest.java
@@ -0,0 +1,172 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2015 Makoto YUI
+ * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package hivemall.optimizer;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Random;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+public final class OptimizerTest {
+
+ @Test
+ public void testIllegalOptimizer() {
+ try {
+ final Map<String, String> emptyOptions = new HashMap<String, String>();
+ DenseOptimizerFactory.create(1024, emptyOptions);
+ Assert.fail();
+ } catch (IllegalArgumentException e) {
+ // tests passed
+ }
+ try {
+ final Map<String, String> options = new HashMap<String, String>();
+ options.put("optimizer", "illegal");
+ DenseOptimizerFactory.create(1024, options);
+ Assert.fail();
+ } catch (IllegalArgumentException e) {
+ // tests passed
+ }
+ try {
+ final Map<String, String> emptyOptions = new HashMap<String, String>();
+ SparseOptimizerFactory.create(1024, emptyOptions);
+ Assert.fail();
+ } catch (IllegalArgumentException e) {
+ // tests passed
+ }
+ try {
+ final Map<String, String> options = new HashMap<String, String>();
+ options.put("optimizer", "illegal");
+ SparseOptimizerFactory.create(1024, options);
+ Assert.fail();
+ } catch (IllegalArgumentException e) {
+ // tests passed
+ }
+ }
+
+ @Test
+ public void testOptimizerFactory() {
+ final Map<String, String> options = new HashMap<String, String>();
+ final String[] regTypes = new String[] {"NO", "L1", "L2"};
+ for(final String regType : regTypes) {
+ options.put("optimizer", "SGD");
+ options.put("regularization", regType);
+ Assert.assertTrue(DenseOptimizerFactory.create(8, options) instanceof Optimizer.SGD);
+ Assert.assertTrue(SparseOptimizerFactory.create(8, options) instanceof Optimizer.SGD);
+ }
+ for(final String regType : regTypes) {
+ options.put("optimizer", "AdaDelta");
+ options.put("regularization", regType);
+ Assert.assertTrue(DenseOptimizerFactory.create(8, options) instanceof DenseOptimizerFactory.AdaDelta);
+ Assert.assertTrue(SparseOptimizerFactory.create(8, options) instanceof SparseOptimizerFactory.AdaDelta);
+ }
+ for(final String regType : regTypes) {
+ options.put("optimizer", "AdaGrad");
+ options.put("regularization", regType);
+ Assert.assertTrue(DenseOptimizerFactory.create(8, options) instanceof DenseOptimizerFactory.AdaGrad);
+ Assert.assertTrue(SparseOptimizerFactory.create(8, options) instanceof SparseOptimizerFactory.AdaGrad);
+ }
+ for(final String regType : regTypes) {
+ options.put("optimizer", "Adam");
+ options.put("regularization", regType);
+ Assert.assertTrue(DenseOptimizerFactory.create(8, options) instanceof DenseOptimizerFactory.Adam);
+ Assert.assertTrue(SparseOptimizerFactory.create(8, options) instanceof SparseOptimizerFactory.Adam);
+ }
+
+ // We need special handling for `Optimizer#RDA`
+ options.put("optimizer", "AdaGrad");
+ options.put("regularization", "RDA");
+ Assert.assertTrue(DenseOptimizerFactory.create(8, options) instanceof DenseOptimizerFactory.RDA);
+ Assert.assertTrue(SparseOptimizerFactory.create(8, options) instanceof SparseOptimizerFactory.RDA);
+
+ // `SGD`, `AdaDelta`, and `Adam` currently do not support `RDA`
+ for(final String optimizerType : new String[] {"SGD", "AdaDelta", "Adam"}) {
+ options.put("optimizer", optimizerType);
+ try {
+ DenseOptimizerFactory.create(8, options);
+ Assert.fail();
+ } catch (IllegalArgumentException e) {
+ // tests passed
+ }
+ try {
+ SparseOptimizerFactory.create(8, options);
+ Assert.fail();
+ } catch (IllegalArgumentException e) {
+ // tests passed
+ }
+ }
+ }
+
+ private void testUpdateWeights(Optimizer optimizer, int numUpdates, int initSize) {
+ final float[] weights = new float[initSize * 2];
+ final Random rnd = new Random();
+ try {
+ for(int i = 0; i < numUpdates; i++) {
+ int index = rnd.nextInt(initSize);
+ weights[index] = optimizer.computeUpdatedValue(index, weights[index], 0.1f);
+ }
+ for(int i = 0; i < numUpdates; i++) {
+ int index = rnd.nextInt(initSize * 2);
+ weights[index] = optimizer.computeUpdatedValue(index, weights[index], 0.1f);
+ }
+ } catch(Exception e) {
+ Assert.fail("failed to update weights: " + e.getMessage());
+ }
+ }
+
+ private void testOptimizer(final Map<String, String> options, int numUpdates, int initSize) {
+ final Map<String, String> testOptions = new HashMap<String, String>(options);
+ // RDA is only supported on top of AdaGrad and is covered in testOptimizerFactory
+ final String[] regTypes = new String[] {"NO", "L1", "L2"};
+ for(final String regType : regTypes) {
+ testOptions.put("regularization", regType);
+ testUpdateWeights(DenseOptimizerFactory.create(initSize, testOptions), numUpdates, initSize);
+ testUpdateWeights(SparseOptimizerFactory.create(initSize, testOptions), numUpdates, initSize);
+ }
+ }
+
+ @Test
+ public void testSGDOptimizer() {
+ final Map<String, String> options = new HashMap<String, String>();
+ options.put("optimizer", "SGD");
+ testOptimizer(options, 65536, 1024);
+ }
+
+ @Test
+ public void testAdaDeltaOptimizer() {
+ final Map<String, String> options = new HashMap<String, String>();
+ options.put("optimizer", "AdaDelta");
+ testOptimizer(options, 65536, 1024);
+ }
+
+ @Test
+ public void testAdaGradOptimizer() {
+ final Map<String, String> options = new HashMap<String, String>();
+ options.put("optimizer", "AdaGrad");
+ testOptimizer(options, 65536, 1024);
+ }
+
+ @Test
+ public void testAdamOptimizer() {
+ final Map<String, String> options = new HashMap<String, String>();
+ options.put("optimizer", "Adam");
+ testOptimizer(options, 65536, 1024);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/mixserv/src/test/java/hivemall/mix/server/MixServerTest.java
----------------------------------------------------------------------
diff --git a/mixserv/src/test/java/hivemall/mix/server/MixServerTest.java b/mixserv/src/test/java/hivemall/mix/server/MixServerTest.java
index 0b1455c..38792d8 100644
--- a/mixserv/src/test/java/hivemall/mix/server/MixServerTest.java
+++ b/mixserv/src/test/java/hivemall/mix/server/MixServerTest.java
@@ -55,7 +55,7 @@ public class MixServerTest extends HivemallTestBase {
waitForState(server, ServerState.RUNNING);
- PredictionModel model = new DenseModel(16777216, false);
+ PredictionModel model = new DenseModel(16777216);
model.configureClock();
MixClient client = null;
try {
@@ -93,7 +93,7 @@ public class MixServerTest extends HivemallTestBase {
waitForState(server, ServerState.RUNNING);
- PredictionModel model = new DenseModel(16777216, false);
+ PredictionModel model = new DenseModel(16777216);
model.configureClock();
MixClient client = null;
try {
@@ -151,7 +151,7 @@ public class MixServerTest extends HivemallTestBase {
}
private static void invokeClient(String groupId, int serverPort) throws InterruptedException {
- PredictionModel model = new DenseModel(16777216, false);
+ PredictionModel model = new DenseModel(16777216);
model.configureClock();
MixClient client = null;
try {
@@ -296,10 +296,10 @@ public class MixServerTest extends HivemallTestBase {
serverExec.shutdown();
}
- private static void invokeClient01(String groupId, int serverPort, boolean denseModel,
- boolean cancelMix) throws InterruptedException {
- PredictionModel model = denseModel ? new DenseModel(100, false) : new SparseModel(100,
- false);
+ private static void invokeClient01(String groupId, int serverPort, boolean denseModel, boolean cancelMix)
+ throws InterruptedException {
+ PredictionModel model = denseModel ? new DenseModel(100)
+ : new SparseModel(100, false);
model.configureClock();
MixClient client = null;
try {
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/resources/ddl/define-all-as-permanent.hive
----------------------------------------------------------------------
diff --git a/resources/ddl/define-all-as-permanent.hive b/resources/ddl/define-all-as-permanent.hive
index bab5a29..ccdace0 100644
--- a/resources/ddl/define-all-as-permanent.hive
+++ b/resources/ddl/define-all-as-permanent.hive
@@ -13,6 +13,9 @@ CREATE FUNCTION hivemall_version as 'hivemall.HivemallVersionUDF' USING JAR '${h
-- binary classification --
---------------------------
+DROP FUNCTION IF EXISTS train_classifier;
+CREATE FUNCTION train_classifier as 'hivemall.classifier.GeneralClassifierUDTF' USING JAR '${hivemall_jar}';
+
DROP FUNCTION IF EXISTS train_perceptron;
CREATE FUNCTION train_perceptron as 'hivemall.classifier.PerceptronUDTF' USING JAR '${hivemall_jar}';
@@ -45,7 +48,7 @@ CREATE FUNCTION train_adagrad_rda as 'hivemall.classifier.AdaGradRDAUDTF' USING
--------------------------------
-- Multiclass classification --
---------------------------------
+--------------------------------
DROP FUNCTION IF EXISTS train_multiclass_perceptron;
CREATE FUNCTION train_multiclass_perceptron as 'hivemall.classifier.multiclass.MulticlassPerceptronUDTF' USING JAR '${hivemall_jar}';
@@ -312,6 +315,13 @@ CREATE FUNCTION tf as 'hivemall.ftvec.text.TermFrequencyUDAF' USING JAR '${hivem
-- Regression functions --
--------------------------
+DROP FUNCTION IF EXISTS train_regression;
+CREATE FUNCTION train_regression as 'hivemall.regression.GeneralRegressionUDTF' USING JAR '${hivemall_jar}';
+
+DROP FUNCTION IF EXISTS train_logregr;
+CREATE FUNCTION train_logregr as 'hivemall.regression.LogressUDTF' USING JAR '${hivemall_jar}';
+
+-- alias for backward compatibility
DROP FUNCTION IF EXISTS logress;
CREATE FUNCTION logress as 'hivemall.regression.LogressUDTF' USING JAR '${hivemall_jar}';
@@ -599,3 +609,4 @@ CREATE FUNCTION xgboost_predict AS 'hivemall.xgboost.tools.XGBoostPredictUDTF' U
DROP FUNCTION xgboost_multiclass_predict;
CREATE FUNCTION xgboost_multiclass_predict AS 'hivemall.xgboost.tools.XGBoostMulticlassPredictUDTF' USING JAR '${hivemall_jar}';
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f81948c5/resources/ddl/define-all.hive
----------------------------------------------------------------------
diff --git a/resources/ddl/define-all.hive b/resources/ddl/define-all.hive
index 315b4d2..d60fd7f 100644
--- a/resources/ddl/define-all.hive
+++ b/resources/ddl/define-all.hive
@@ -9,6 +9,9 @@ create temporary function hivemall_version as 'hivemall.HivemallVersionUDF';
-- binary classification --
---------------------------
+drop temporary function train_classifier;
+create temporary function train_classifier as 'hivemall.classifier.GeneralClassifierUDTF';
+
drop temporary function train_perceptron;
create temporary function train_perceptron as 'hivemall.classifier.PerceptronUDTF';
@@ -308,6 +311,13 @@ create temporary function tf as 'hivemall.ftvec.text.TermFrequencyUDAF';
-- Regression functions --
--------------------------
+drop temporary function train_regression;
+create temporary function train_regression as 'hivemall.regression.GeneralRegressionUDTF';
+
+drop temporary function train_logregr;
+create temporary function train_logregr as 'hivemall.regression.LogressUDTF';
+
+-- alias for backward compatibility
drop temporary function logress;
create temporary function logress as 'hivemall.regression.LogressUDTF';
@@ -628,5 +638,3 @@ log(10, n_docs / max2(1,df_t)) + 1.0;
create temporary macro tfidf(tf FLOAT, df_t DOUBLE, n_docs DOUBLE)
tf * (log(10, n_docs / max2(1,df_t)) + 1.0);
-
-
[06/50] [abbrv] incubator-hivemall git commit: change interface of
chi2
Posted by my...@apache.org.
change interface of chi2
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/7b07e4a6
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/7b07e4a6
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/7b07e4a6
Branch: refs/heads/JIRA-22/pr-385
Commit: 7b07e4a6e1f700ba0a6e5b68659a040a3d89aa2f
Parents: d0e97e6
Author: amaya <gi...@sapphire.in.net>
Authored: Tue Sep 20 12:03:44 2016 +0900
Committer: amaya <gi...@sapphire.in.net>
Committed: Tue Sep 20 12:11:42 2016 +0900
----------------------------------------------------------------------
.../ftvec/selection/ChiSquareTestUDF.java | 21 ----
.../hivemall/ftvec/selection/ChiSquareUDF.java | 124 +++++++++++++++++--
.../ftvec/selection/DissociationDegreeUDF.java | 88 -------------
.../java/hivemall/utils/math/StatsUtils.java | 49 ++++++--
4 files changed, 155 insertions(+), 127 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/7b07e4a6/core/src/main/java/hivemall/ftvec/selection/ChiSquareTestUDF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/ftvec/selection/ChiSquareTestUDF.java b/core/src/main/java/hivemall/ftvec/selection/ChiSquareTestUDF.java
deleted file mode 100644
index d367085..0000000
--- a/core/src/main/java/hivemall/ftvec/selection/ChiSquareTestUDF.java
+++ /dev/null
@@ -1,21 +0,0 @@
-package hivemall.ftvec.selection;
-
-import hivemall.utils.math.StatsUtils;
-import org.apache.hadoop.hive.ql.exec.Description;
-
-import javax.annotation.Nonnull;
-
-@Description(name = "chi2_test",
- value = "_FUNC_(array<number> expected, array<number> observed) - Returns p-value as double")
-public class ChiSquareTestUDF extends DissociationDegreeUDF {
- @Override
- double calcDissociation(@Nonnull final double[] expected,@Nonnull final double[] observed) {
- return StatsUtils.chiSquareTest(expected, observed);
- }
-
- @Override
- @Nonnull
- String getFuncName() {
- return "chi2_test";
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/7b07e4a6/core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java b/core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java
index 937b1bd..1954e33 100644
--- a/core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java
+++ b/core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java
@@ -1,21 +1,131 @@
package hivemall.ftvec.selection;
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.hadoop.WritableUtils;
+import hivemall.utils.lang.Preconditions;
import hivemall.utils.math.StatsUtils;
import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
+import org.apache.hadoop.hive.serde2.io.DoubleWritable;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StructField;
+import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import javax.annotation.Nonnull;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
@Description(name = "chi2",
- value = "_FUNC_(array<number> expected, array<number> observed) - Returns chi2-value as double")
-public class ChiSquareUDF extends DissociationDegreeUDF {
+ value = "_FUNC_(array<array<number>> observed, array<array<number>> expected)" +
+ " - Returns chi2_val and p_val of each columns as <array<double>, array<double>>")
+public class ChiSquareUDF extends GenericUDF {
+ private ListObjectInspector observedOI;
+ private ListObjectInspector observedRowOI;
+ private PrimitiveObjectInspector observedElOI;
+ private ListObjectInspector expectedOI;
+ private ListObjectInspector expectedRowOI;
+ private PrimitiveObjectInspector expectedElOI;
+
@Override
- double calcDissociation(@Nonnull final double[] expected,@Nonnull final double[] observed) {
- return StatsUtils.chiSquare(expected, observed);
+ public ObjectInspector initialize(ObjectInspector[] OIs) throws UDFArgumentException {
+ if (OIs.length != 2) {
+ throw new UDFArgumentLengthException("Specify two arguments.");
+ }
+
+ if (!HiveUtils.isNumberListListOI(OIs[0])){
+ throw new UDFArgumentTypeException(0, "Only array<array<number>> type argument is acceptable but "
+ + OIs[0].getTypeName() + " was passed as `observed`");
+ }
+
+ if (!HiveUtils.isNumberListListOI(OIs[1])){
+ throw new UDFArgumentTypeException(1, "Only array<array<number>> type argument is acceptable but "
+ + OIs[1].getTypeName() + " was passed as `expected`");
+ }
+
+ observedOI = HiveUtils.asListOI(OIs[0]);
+ observedRowOI = HiveUtils.asListOI(observedOI.getListElementObjectInspector());
+ observedElOI = HiveUtils.asDoubleCompatibleOI(observedRowOI.getListElementObjectInspector());
+ expectedOI = HiveUtils.asListOI(OIs[1]);
+ expectedRowOI = HiveUtils.asListOI(expectedOI.getListElementObjectInspector());
+ expectedElOI = HiveUtils.asDoubleCompatibleOI(expectedRowOI.getListElementObjectInspector());
+
+ List<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
+ fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(
+ PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
+ fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(
+ PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
+
+ return ObjectInspectorFactory.getStandardStructObjectInspector(
+ Arrays.asList("chi2_vals", "p_vals"), fieldOIs);
+ }
+
+ @Override
+ public Object evaluate(GenericUDF.DeferredObject[] dObj) throws HiveException {
+ List observedObj = observedOI.getList(dObj[0].get()); // shape = (#classes, #features)
+ List expectedObj = expectedOI.getList(dObj[1].get()); // shape = (#classes, #features)
+
+ Preconditions.checkNotNull(observedObj);
+ Preconditions.checkNotNull(expectedObj);
+ final int nClasses = observedObj.size();
+ Preconditions.checkArgument(nClasses == expectedObj.size()); // same #rows
+
+ int nFeatures=-1;
+ double[] observedRow=null; // to reuse
+ double[] expectedRow=null; // to reuse
+ double[][] observed =null; // shape = (#features, #classes)
+ double[][] expected = null; // shape = (#features, #classes)
+
+ // explode and transpose matrix
+ for(int i=0;i<nClasses;i++){
+ if(i==0){
+ // init
+ observedRow=HiveUtils.asDoubleArray(observedObj.get(i),observedRowOI,observedElOI,false);
+ expectedRow=HiveUtils.asDoubleArray(expectedObj.get(i),expectedRowOI,expectedElOI, false);
+ nFeatures = observedRow.length;
+ observed=new double[nFeatures][nClasses];
+ expected = new double[nFeatures][nClasses];
+ }else{
+ HiveUtils.toDoubleArray(observedObj.get(i),observedRowOI,observedElOI,observedRow,false);
+ HiveUtils.toDoubleArray(expectedObj.get(i),expectedRowOI,expectedElOI,expectedRow, false);
+ }
+
+ for(int j=0;j<nFeatures;j++){
+ observed[j][i] = observedRow[j];
+ expected[j][i] = expectedRow[j];
+ }
+ }
+
+ final Map.Entry<double[],double[]> chi2 = StatsUtils.chiSquares(observed,expected);
+
+ final Object[] result = new Object[2];
+ result[0] = WritableUtils.toWritableList(chi2.getKey());
+ result[1]=WritableUtils.toWritableList(chi2.getValue());
+ return result;
}
@Override
- @Nonnull
- String getFuncName() {
- return "chi2";
+ public String getDisplayString(String[] children) {
+ final StringBuilder sb = new StringBuilder();
+ sb.append("chi2");
+ sb.append("(");
+ if (children.length > 0) {
+ sb.append(children[0]);
+ for (int i = 1; i < children.length; i++) {
+ sb.append(", ");
+ sb.append(children[i]);
+ }
+ }
+ sb.append(")");
+ return sb.toString();
}
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/7b07e4a6/core/src/main/java/hivemall/ftvec/selection/DissociationDegreeUDF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/ftvec/selection/DissociationDegreeUDF.java b/core/src/main/java/hivemall/ftvec/selection/DissociationDegreeUDF.java
deleted file mode 100644
index 0acae82..0000000
--- a/core/src/main/java/hivemall/ftvec/selection/DissociationDegreeUDF.java
+++ /dev/null
@@ -1,88 +0,0 @@
-package hivemall.ftvec.selection;
-
-import hivemall.utils.hadoop.HiveUtils;
-import hivemall.utils.lang.Preconditions;
-import hivemall.utils.math.StatsUtils;
-import org.apache.hadoop.hive.ql.exec.Description;
-import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
-import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
-import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
-import org.apache.hadoop.hive.ql.metadata.HiveException;
-import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
-import org.apache.hadoop.hive.serde2.io.DoubleWritable;
-import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
-import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector;
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
-
-import javax.annotation.Nonnull;
-
-@Description(name = "",
- value = "_FUNC_(array<number> expected, array<number> observed) - Returns dissociation degree as double")
-public abstract class DissociationDegreeUDF extends GenericUDF {
- private ListObjectInspector expectedOI;
- private DoubleObjectInspector expectedElOI;
- private ListObjectInspector observedOI;
- private DoubleObjectInspector observedElOI;
-
- @Override
- public ObjectInspector initialize(ObjectInspector[] OIs) throws UDFArgumentException {
- if (OIs.length != 2) {
- throw new UDFArgumentLengthException("Specify two arguments.");
- }
-
- if (!HiveUtils.isListOI(OIs[0])
- || !HiveUtils.isNumberOI(((ListObjectInspector) OIs[0]).getListElementObjectInspector())){
- throw new UDFArgumentTypeException(0, "Only array<number> type argument is acceptable but "
- + OIs[0].getTypeName() + " was passed as `expected`");
- }
-
- if (!HiveUtils.isListOI(OIs[1])
- || !HiveUtils.isNumberOI(((ListObjectInspector) OIs[1]).getListElementObjectInspector())){
- throw new UDFArgumentTypeException(1, "Only array<number> type argument is acceptable but "
- + OIs[1].getTypeName() + " was passed as `observed`");
- }
-
- expectedOI = (ListObjectInspector) OIs[0];
- expectedElOI = (DoubleObjectInspector) expectedOI.getListElementObjectInspector();
- observedOI = (ListObjectInspector) OIs[1];
- observedElOI = (DoubleObjectInspector) observedOI.getListElementObjectInspector();
-
- return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
- }
-
- @Override
- public Object evaluate(GenericUDF.DeferredObject[] dObj) throws HiveException {
- final double[] expected = HiveUtils.asDoubleArray(dObj[0].get(),expectedOI,expectedElOI);
- final double[] observed = HiveUtils.asDoubleArray(dObj[1].get(),observedOI,observedElOI);
-
- Preconditions.checkNotNull(expected);
- Preconditions.checkNotNull(observed);
- Preconditions.checkArgument(expected.length == observed.length);
-
- final double dissociation = calcDissociation(expected,observed);
-
- return new DoubleWritable(dissociation);
- }
-
- @Override
- public String getDisplayString(String[] children) {
- final StringBuilder sb = new StringBuilder();
- sb.append(getFuncName());
- sb.append("(");
- if (children.length > 0) {
- sb.append(children[0]);
- for (int i = 1; i < children.length; i++) {
- sb.append(", ");
- sb.append(children[i]);
- }
- }
- sb.append(")");
- return sb.toString();
- }
-
- abstract double calcDissociation(@Nonnull final double[] expected,@Nonnull final double[] observed);
-
- @Nonnull
- abstract String getFuncName();
-}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/7b07e4a6/core/src/main/java/hivemall/utils/math/StatsUtils.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/math/StatsUtils.java b/core/src/main/java/hivemall/utils/math/StatsUtils.java
index 7633419..f9d0f30 100644
--- a/core/src/main/java/hivemall/utils/math/StatsUtils.java
+++ b/core/src/main/java/hivemall/utils/math/StatsUtils.java
@@ -29,6 +29,9 @@ import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.linear.SingularValueDecomposition;
+import java.util.AbstractMap;
+import java.util.Map;
+
public final class StatsUtils {
private StatsUtils() {}
@@ -191,24 +194,24 @@ public final class StatsUtils {
}
/**
- * @param expected mean vector whose value is expected
* @param observed mean vector whose value is observed
- * @return chi2-value
+ * @param expected mean vector whose value is expected
+ * @return chi2 value
*/
- public static double chiSquare(@Nonnull final double[] expected, @Nonnull final double[] observed) {
- Preconditions.checkArgument(expected.length == observed.length);
+ public static double chiSquare(@Nonnull final double[] observed, @Nonnull final double[] expected) {
+ Preconditions.checkArgument(observed.length == expected.length);
- double sumExpected = 0.d;
double sumObserved = 0.d;
+ double sumExpected = 0.d;
for (int ratio = 0; ratio < observed.length; ++ratio) {
- sumExpected += expected[ratio];
sumObserved += observed[ratio];
+ sumExpected += expected[ratio];
}
double var15 = 1.d;
boolean rescale = false;
- if (Math.abs(sumExpected - sumObserved) > 1.e-5) {
+ if (Math.abs(sumObserved - sumExpected) > 1.e-5) {
var15 = sumObserved / sumExpected;
rescale = true;
}
@@ -230,12 +233,36 @@ public final class StatsUtils {
}
/**
- * @param expected means vector whose value is expected
* @param observed means vector whose value is observed
- * @return p-value
+ * @param expected means vector whose value is expected
+ * @return p value
*/
- public static double chiSquareTest(@Nonnull final double[] expected,@Nonnull final double[] observed) {
+ public static double chiSquareTest(@Nonnull final double[] observed, @Nonnull final double[] expected) {
ChiSquaredDistribution distribution = new ChiSquaredDistribution(null, (double)expected.length - 1.d);
- return 1.d - distribution.cumulativeProbability(chiSquare(expected, observed));
+ return 1.d - distribution.cumulativeProbability(chiSquare(observed,expected));
+ }
+
+ /**
+ * This method efficiently computes chi-square statistics for multiple entries at once rather than calculating them individually
+ * @param observeds means matrix whose values are observed
+ * @param expecteds means matrix whose values are expected
+ * @return (chi2 value[], p value[])
+ */
+ public static Map.Entry<double[],double[]> chiSquares(@Nonnull final double[][] observeds, @Nonnull final double[][] expecteds){
+ Preconditions.checkArgument(observeds.length == expecteds.length);
+
+ final int len = expecteds.length;
+ final int lenOfEach = expecteds[0].length;
+
+ final ChiSquaredDistribution distribution = new ChiSquaredDistribution(null, (double)lenOfEach - 1.d);
+
+ final double[] chi2s = new double[len];
+ final double[] ps = new double[len];
+ for(int i=0;i<len;i++){
+ chi2s[i] = chiSquare(observeds[i],expecteds[i]);
+ ps[i] = 1.d - distribution.cumulativeProbability(chi2s[i]);
+ }
+
+ return new AbstractMap.SimpleEntry<double[], double[]>(chi2s,ps);
}
}
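For background on the statistic computed above: per column this is the standard Pearson chi-square with the usual rescaling when the observed and expected totals differ (a sketch of the formulas, assuming Commons Math-style rescaling):

    \chi^2 = \sum_{i=1}^{k} \frac{(O_i - r E_i)^2}{r E_i},
    \qquad r = \frac{\sum_i O_i}{\sum_i E_i} \quad (r = 1 \text{ when the totals agree}),
    \qquad p = 1 - F_{\chi^2_{k-1}}(\chi^2)

where F_{\chi^2_{k-1}} is the cumulative distribution function of the chi-squared distribution with k - 1 degrees of freedom, matching `new ChiSquaredDistribution(null, expected.length - 1)` in the code above.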
[45/50] [abbrv] incubator-hivemall git commit: Add feature selection
gitbook (#386)
Posted by my...@apache.org.
Add feature selection gitbook (#386)
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/6549ef51
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/6549ef51
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/6549ef51
Branch: refs/heads/JIRA-22/pr-385
Commit: 6549ef5104883a9529dfd9fc52b2b24843076fbb
Parents: e44a413
Author: amaya <am...@users.noreply.github.com>
Authored: Wed Nov 23 21:16:10 2016 +0900
Committer: Makoto YUI <yu...@gmail.com>
Committed: Wed Nov 23 21:16:10 2016 +0900
----------------------------------------------------------------------
docs/gitbook/SUMMARY.md | 2 +
.../gitbook/ft_engineering/feature_selection.md | 151 +++++++++++++++++++
2 files changed, 153 insertions(+)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/6549ef51/docs/gitbook/SUMMARY.md
----------------------------------------------------------------------
diff --git a/docs/gitbook/SUMMARY.md b/docs/gitbook/SUMMARY.md
index c333c98..33bb46c 100644
--- a/docs/gitbook/SUMMARY.md
+++ b/docs/gitbook/SUMMARY.md
@@ -61,6 +61,8 @@
* [Vectorize Features](ft_engineering/vectorizer.md)
* [Quantify non-number features](ft_engineering/quantify.md)
+* [Feature selection](ft_engineering/feature_selection.md)
+
## Part IV - Evaluation
* [Statistical evaluation of a prediction model](eval/stat_eval.md)
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/6549ef51/docs/gitbook/ft_engineering/feature_selection.md
----------------------------------------------------------------------
diff --git a/docs/gitbook/ft_engineering/feature_selection.md b/docs/gitbook/ft_engineering/feature_selection.md
new file mode 100644
index 0000000..8b522c6
--- /dev/null
+++ b/docs/gitbook/ft_engineering/feature_selection.md
@@ -0,0 +1,151 @@
+<!--
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+-->
+
+Feature selection is the process of selecting a subset of influential features from the full set of candidate features.
+It is an important technique to **enhance results**, **shorten training time** and **make features human-understandable**.
+
+## Selection methods supported by Hivemall
+* Chi-square (Chi2)
+ * For non-negative data only
+* Signal Noise Ratio (SNR)
+* ~~Minimum Redundancy Maximum Relevance (mRMR)~~
+ * Contributions are welcome!
+
+## Usage
+1. Create an importance list for feature selection
+ * chi2/SNR
+2. Filter features
+ * Select the top-k features based on the importance list
+
+
+## Example - Chi2
+``` sql
+CREATE TABLE input (
+ X array<double>, -- features
+ Y array<int> -- binarized label
+);
+
+WITH stats AS (
+ SELECT
+ -- [UDAF] transpose_and_dot(Y::array<number>, X::array<number>)::array<array<double>>
+ transpose_and_dot(Y, X) AS observed, -- array<array<double>>, shape = (n_classes, n_features)
+ array_sum(X) AS feature_count, -- n_features col vector, shape = (1, array<double>)
+ array_avg(Y) AS class_prob -- n_class col vector, shape = (1, array<double>)
+ FROM
+ input
+),
+test AS (
+ SELECT
+ transpose_and_dot(class_prob, feature_count) AS expected -- array<array<double>>, shape = (n_class, n_features)
+ FROM
+ stats
+),
+chi2 AS (
+ SELECT
+ -- [UDAF] chi2(observed::array<array<double>>, expected::array<array<double>>)::struct<array<double>, array<double>>
+ chi2(observed, expected) AS chi2s -- struct<array<double>, array<double>>, each shape = (1, n_features)
+ FROM
+ test JOIN stats
+)
+SELECT
+ -- [UDF] select_k_best(X::array<number>, importance_list::array<int>, k::int)::array<double>
+ select_k_best(X, chi2s.chi2, ${k}) -- top-k feature selection based on chi2 score
+FROM
+ input JOIN chi2;
+```
+
+
+## Example - SNR
+``` sql
+CREATE TABLE input (
+ X array<double>, -- features
+ Y array<int> -- binarized label
+);
+
+WITH snr AS (
+ -- [UDAF] snr(features::array<number>, labels::array<int>)::array<double>
+ SELECT snr(X, Y) AS snr FROM input -- aggregated SNR as array<double>, shape = (1, #features)
+)
+SELECT select_k_best(X, snr, ${k}) FROM input JOIN snr;
+```
+
+
+## UDF details
+### Common
+#### [UDAF] `transpose_and_dot(X::array<number>, Y::array<number>)::array<array<double>>`
+##### Input
+
+| array<number> X | array<number> Y |
+| :-: | :-: |
+| a row of matrix | a row of matrix |
+##### Output
+
+| array<array<double>> dotted |
+| :-: |
+| `dot(X.T, Y)`, shape = (X.#cols, Y.#cols) |
+#### [UDF] `select_k_best(X::array<number>, importance_list::array<int>, k::int)::array<double>`
+##### Input
+
+| array<number> X | array<int> importance list | int k |
+| :-: | :-: | :-: |
+| array | the larger, the more important | top-k |
+##### Output
+
+| array<double> k-best elements |
+| :-: |
+| the top-k elements of X, selected by the indices of the k largest values in the importance list |
+
+#### Note
+- The current implementation expects that **every row passes an identical `importance_list` and `k`**, which can be confusing.
+ - Future work: add an option to declare that a common `importance_list` and `k` are used
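+
+A minimal sketch of `select_k_best` over inline toy data follows (the literal values are illustrative, the same importance list and `k` are passed for every row as noted above, and it assumes FROM-less SELECT is accepted in a subquery):
+
+``` sql
+SELECT
+  -- importance list [0.1, 0.5, 0.3] with k = 2 keeps the 2nd and 3rd elements of X, i.e. [2.0, 3.0]
+  -- (the order of the returned elements may depend on the implementation)
+  select_k_best(X, array(0.1, 0.5, 0.3), 2)
+FROM (
+  SELECT array(1.0, 2.0, 3.0) AS X
+) t;
+```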
+
+
+### Chi2
+#### [UDF] `chi2(observed::array<array<number>>, expected::array<array<number>>)::struct<array<double>, array<double>>`
+##### Input
+
+Both `observed` and `expected` have shape = (#classes, #features).
+
+| array<array<number>> observed | array<array<number>> expected |
+| :-: | :-: |
+| observed features | expected features, `dot(class_prob.T, feature_count)` |
+
+##### Output
+
+| struct<array<double>, array<double>> importance lists |
+| :-: |
+| chi2-values and p-values for each feature; each has shape = (1, #features) |
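+
+As a minimal sketch, `chi2` can also be called directly on inline toy matrices (2 classes x 2 features; all values are illustrative, and a Hive version supporting FROM-less SELECT is assumed):
+
+``` sql
+SELECT
+  chi2(
+    array(array(10.0, 20.0), array(30.0, 40.0)), -- observed
+    array(array(12.0, 18.0), array(28.0, 42.0))  -- expected
+  ) AS chi2s -- struct of chi2-values and p-values, each of length #features
+;
+```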
+
+
+### SNR
+#### [UDAF] `snr(X::array<number>, Y::array<int>)::array<double>`
+##### Input
+
+| array<number> X | array<int> Y |
+| :-: | :-: |
+| a row of matrix, overall shape = (#samples, #features) | a row of one-hot matrix, overall shape = (#samples, #classes) |
+
+##### Output
+
+| array<double> importance list |
+| :-: |
+| snr values of each feature, shape = (1, #features) |
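+
+For a quick check, `snr` can also be fed inline toy data (three samples, two features, two one-hot classes; all values are illustrative, and FROM-less SELECT in subqueries is assumed):
+
+``` sql
+SELECT
+  snr(X, Y) -- array<double> of length #features
+FROM (
+  SELECT array(5.1, 3.5) AS X, array(1, 0) AS Y
+  UNION ALL
+  SELECT array(4.9, 3.0) AS X, array(1, 0) AS Y
+  UNION ALL
+  SELECT array(7.0, 3.2) AS X, array(0, 1) AS Y
+) t;
+```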
+
+#### Note
+* Essentially, one-hot vectorization of the labels is not strictly necessary; it is only used to keep the interface consistent with chi2's
\ No newline at end of file
[08/50] [abbrv] incubator-hivemall git commit: add subarray_by_indices
Posted by my...@apache.org.
add subarray_by_indices
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/1ab9b097
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/1ab9b097
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/1ab9b097
Branch: refs/heads/JIRA-22/pr-385
Commit: 1ab9b0974ca4203c00175469b7b75d5b65209547
Parents: e9d1a94
Author: amaya <gi...@sapphire.in.net>
Authored: Tue Sep 20 16:56:15 2016 +0900
Committer: amaya <gi...@sapphire.in.net>
Committed: Tue Sep 20 18:37:46 2016 +0900
----------------------------------------------------------------------
.../tools/array/SubarrayByIndicesUDF.java | 93 ++++++++++++++++++++
1 file changed, 93 insertions(+)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/1ab9b097/core/src/main/java/hivemall/tools/array/SubarrayByIndicesUDF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/tools/array/SubarrayByIndicesUDF.java b/core/src/main/java/hivemall/tools/array/SubarrayByIndicesUDF.java
new file mode 100644
index 0000000..f476589
--- /dev/null
+++ b/core/src/main/java/hivemall/tools/array/SubarrayByIndicesUDF.java
@@ -0,0 +1,93 @@
+package hivemall.tools.array;
+
+
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.lang.Preconditions;
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
+import org.apache.hadoop.hive.serde2.io.DoubleWritable;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
+
+import java.util.ArrayList;
+import java.util.List;
+
+@Description(name = "subarray_by_indices",
+ value = "_FUNC_(array<number> input, array<int> indices)" +
+ " - Returns subarray selected by given indices as array<number>")
+public class SubarrayByIndicesUDF extends GenericUDF {
+ private ListObjectInspector inputOI;
+ private PrimitiveObjectInspector elementOI;
+ private ListObjectInspector indicesOI;
+ private PrimitiveObjectInspector indexOI;
+
+ @Override
+ public ObjectInspector initialize(ObjectInspector[] OIs) throws UDFArgumentException {
+ if (OIs.length != 2) {
+ throw new UDFArgumentLengthException("Specify two arguments.");
+ }
+
+ if (!HiveUtils.isListOI(OIs[0])) {
+ throw new UDFArgumentTypeException(0, "Only array<number> type argument is acceptable but "
+ + OIs[0].getTypeName() + " was passed as `input`");
+ }
+ if (!HiveUtils.isListOI(OIs[1])
+ || !HiveUtils.isIntegerOI(((ListObjectInspector) OIs[1]).getListElementObjectInspector())) {
+ throw new UDFArgumentTypeException(1, "Only array<int> type argument is acceptable but "
+ + OIs[1].getTypeName() + " was passed as `indices`");
+ }
+
+ inputOI = HiveUtils.asListOI(OIs[0]);
+ elementOI = HiveUtils.asDoubleCompatibleOI(inputOI.getListElementObjectInspector());
+ indicesOI = HiveUtils.asListOI(OIs[1]);
+ indexOI = HiveUtils.asIntegerOI(indicesOI.getListElementObjectInspector());
+
+ return ObjectInspectorFactory.getStandardListObjectInspector(
+ PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
+ }
+
+ @Override
+ public Object evaluate(GenericUDF.DeferredObject[] dObj) throws HiveException {
+ final double[] input = HiveUtils.asDoubleArray(dObj[0].get(), inputOI, elementOI);
+ final List indices = indicesOI.getList(dObj[1].get());
+
+ Preconditions.checkNotNull(input);
+ Preconditions.checkNotNull(indices);
+
+ List<DoubleWritable> result = new ArrayList<DoubleWritable>();
+ for (Object indexObj : indices) {
+ int index = PrimitiveObjectInspectorUtils.getInt(indexObj, indexOI);
+ if (index > input.length - 1) {
+ throw new ArrayIndexOutOfBoundsException(index);
+ }
+
+ result.add(new DoubleWritable(input[index]));
+ }
+
+ return result;
+ }
+
+ @Override
+ public String getDisplayString(String[] children) {
+ StringBuilder sb = new StringBuilder();
+ sb.append("subarray_by_indices");
+ sb.append("(");
+ if (children.length > 0) {
+ sb.append(children[0]);
+ for (int i = 1; i < children.length; i++) {
+ sb.append(", ");
+ sb.append(children[i]);
+ }
+ }
+ sb.append(")");
+ return sb.toString();
+ }
+}
[21/50] [abbrv] incubator-hivemall git commit: Implement initial
SST-based change-point detector
Posted by my...@apache.org.
Implement initial SST-based change-point detector
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/3ebd771e
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/3ebd771e
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/3ebd771e
Branch: refs/heads/JIRA-22/pr-356
Commit: 3ebd771ee4bebf14769b7c240f8b28b9d5d10e86
Parents: 89ec56e
Author: Takuya Kitazawa <k....@gmail.com>
Authored: Mon Sep 26 17:12:01 2016 +0900
Committer: Takuya Kitazawa <k....@gmail.com>
Committed: Mon Sep 26 17:12:01 2016 +0900
----------------------------------------------------------------------
.../java/hivemall/anomaly/SSTChangePoint.java | 118 +++++++++++
.../hivemall/anomaly/SSTChangePointUDF.java | 197 +++++++++++++++++++
.../hivemall/anomaly/SSTChangePointTest.java | 111 +++++++++++
3 files changed, 426 insertions(+)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3ebd771e/core/src/main/java/hivemall/anomaly/SSTChangePoint.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/anomaly/SSTChangePoint.java b/core/src/main/java/hivemall/anomaly/SSTChangePoint.java
new file mode 100644
index 0000000..e693bd4
--- /dev/null
+++ b/core/src/main/java/hivemall/anomaly/SSTChangePoint.java
@@ -0,0 +1,118 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2015 Makoto YUI
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package hivemall.anomaly;
+
+import hivemall.anomaly.SSTChangePointUDF.SSTChangePointInterface;
+import hivemall.anomaly.SSTChangePointUDF.Parameters;
+import hivemall.utils.collections.DoubleRingBuffer;
+import org.apache.commons.math3.linear.MatrixUtils;
+import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.commons.math3.linear.SingularValueDecomposition;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
+
+import java.util.Arrays;
+
+import javax.annotation.Nonnull;
+
+final class SSTChangePoint implements SSTChangePointInterface {
+
+ @Nonnull
+ private final PrimitiveObjectInspector oi;
+
+ @Nonnull
+ private final int window;
+ @Nonnull
+ private final int nPastWindow;
+ @Nonnull
+ private final int nCurrentWindow;
+ @Nonnull
+ private final int pastSize;
+ @Nonnull
+ private final int currentSize;
+ @Nonnull
+ private final int currentOffset;
+ @Nonnull
+ private final int r;
+
+ @Nonnull
+ private final DoubleRingBuffer xRing;
+ @Nonnull
+ private final double[] xSeries;
+
+ SSTChangePoint(@Nonnull Parameters params, @Nonnull PrimitiveObjectInspector oi) {
+ this.oi = oi;
+
+ this.window = params.w;
+ this.nPastWindow = params.n;
+ this.nCurrentWindow = params.m;
+ this.pastSize = window + nPastWindow;
+ this.currentSize = window + nCurrentWindow;
+ this.currentOffset = params.g;
+ this.r = params.r;
+
+ // (w + n) past samples for the n-past-windows
+ // (w + m) current samples for the m-current-windows, starting from offset g
+ // => need to hold past (w + n + g + w + m) samples from the latest sample
+ int holdSampleSize = pastSize + currentOffset + currentSize;
+
+ this.xRing = new DoubleRingBuffer(holdSampleSize);
+ this.xSeries = new double[holdSampleSize];
+ }
+
+ @Override
+ public void update(@Nonnull final Object arg, @Nonnull final double[] outScores)
+ throws HiveException {
+ double x = PrimitiveObjectInspectorUtils.getDouble(arg, oi);
+ xRing.add(x).toArray(xSeries, true /* FIFO */);
+
+ // need to wait until the buffer is filled
+ if (!xRing.isFull()) {
+ outScores[0] = 0.d;
+ } else {
+ outScores[0] = computeScore();
+ }
+ }
+
+ private double computeScore() {
+ // create past trajectory matrix and find its left singular vectors
+ RealMatrix H = MatrixUtils.createRealMatrix(window, nPastWindow);
+ for (int i = 0; i < nPastWindow; i++) {
+ H.setColumn(i, Arrays.copyOfRange(xSeries, i, i + window));
+ }
+ SingularValueDecomposition svdH = new SingularValueDecomposition(H);
+ RealMatrix UT = svdH.getUT();
+
+ // create current trajectory matrix and find its left singular vectors
+ RealMatrix G = MatrixUtils.createRealMatrix(window, nCurrentWindow);
+ int currentHead = pastSize + currentOffset;
+ for (int i = 0; i < nCurrentWindow; i++) {
+ G.setColumn(i, Arrays.copyOfRange(xSeries, currentHead + i, currentHead + i + window));
+ }
+ SingularValueDecomposition svdG = new SingularValueDecomposition(G);
+ RealMatrix Q = svdG.getU();
+
+ // find the largest singular value for the r principal components
+ RealMatrix UTQ = UT.getSubMatrix(0, r - 1, 0, window - 1).multiply(Q.getSubMatrix(0, window - 1, 0, r - 1));
+ SingularValueDecomposition svdUTQ = new SingularValueDecomposition(UTQ);
+ double[] s = svdUTQ.getSingularValues();
+
+ return 1.d - s[0];
+ }
+}
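For reference, the score returned by computeScore() above can be written in LaTeX notation (ours, not part of the patch) as
$\mathrm{score} = 1 - \sigma_{\max}\!\left(U_r^{\top} Q_r\right)$,
where $U_r$ holds the first r left singular vectors of the past trajectory matrix H, $Q_r$ holds the
first r left singular vectors of the current trajectory matrix G, and $\sigma_{\max}$ is the largest
singular value. A score close to 1 means the dominant subspaces of the past and current windows have
diverged, i.e. a likely change point.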
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3ebd771e/core/src/main/java/hivemall/anomaly/SSTChangePointUDF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/anomaly/SSTChangePointUDF.java b/core/src/main/java/hivemall/anomaly/SSTChangePointUDF.java
new file mode 100644
index 0000000..3ab5ae8
--- /dev/null
+++ b/core/src/main/java/hivemall/anomaly/SSTChangePointUDF.java
@@ -0,0 +1,197 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2015 Makoto YUI
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package hivemall.anomaly;
+
+import hivemall.UDFWithOptions;
+import hivemall.utils.collections.DoubleRingBuffer;
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.lang.Preconditions;
+import hivemall.utils.lang.Primitives;
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.Options;
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.serde2.io.DoubleWritable;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.io.BooleanWritable;
+
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+
+@Description(
+ name = "sst_changepoint",
+ value = "_FUNC_(double|array<double> x [, const string options])"
+ + " - Returns change-point scores and decisions using Singular Spectrum Transformation (SST)."
+ + " It will return a tuple <double changepoint_score [, boolean is_changepoint]>")
+public final class SSTChangePointUDF extends UDFWithOptions {
+
+ private transient Parameters _params;
+ private transient SSTChangePoint _sst;
+
+ private transient double[] _scores;
+ private transient Object[] _result;
+ private transient DoubleWritable _changepointScore;
+ @Nullable
+ private transient BooleanWritable _isChangepoint = null;
+
+ public SSTChangePointUDF() {}
+
+ // Visible for testing
+ Parameters getParameters() {
+ return _params;
+ }
+
+ @Override
+ protected Options getOptions() {
+ Options opts = new Options();
+ opts.addOption("w", "window", true, "Number of samples which affects change-point score [default: 30]");
+ opts.addOption("n", "n_past", true,
+ "Number of past windows for change-point scoring [default: equal to `w` = 30]");
+ opts.addOption("m", "n_current", true,
+ "Number of current windows for change-point scoring [default: equal to `w` = 30]");
+ opts.addOption("g", "current_offset", true,
+ "Offset of the current windows from the updating sample [default: `-w` = -30]");
+ opts.addOption("r", "n_component", true,
+ "Number of singular vectors (i.e. principal components) [default: 3]");
+ opts.addOption("k", "n_dim", true,
+ "Number of dimensions for the Krylov subspaces [default: 5 (`2*r` if `r` is even, `2*r-1` otherwise)]");
+ opts.addOption("th", "threshold", true,
+ "Score threshold (inclusive) for determining change-point existence [default: -1, do not output decision]");
+ return opts;
+ }
+
+ @Override
+ protected CommandLine processOptions(String optionValues) throws UDFArgumentException {
+ CommandLine cl = parseOptions(optionValues);
+
+ this._params.w = Primitives.parseInt(cl.getOptionValue("w"), _params.w);
+ this._params.n = Primitives.parseInt(cl.getOptionValue("n"), _params.w);
+ this._params.m = Primitives.parseInt(cl.getOptionValue("m"), _params.w);
+ this._params.g = Primitives.parseInt(cl.getOptionValue("g"), -1 * _params.w);
+ this._params.r = Primitives.parseInt(cl.getOptionValue("r"), _params.r);
+ this._params.k = Primitives.parseInt(
+ cl.getOptionValue("k"), (_params.r % 2 == 0) ? (2 * _params.r) : (2 * _params.r - 1));
+ this._params.changepointThreshold = Primitives.parseDouble(
+ cl.getOptionValue("th"), _params.changepointThreshold);
+
+ Preconditions.checkArgument(_params.w >= 2, "w must be greater than 1: " + _params.w);
+ Preconditions.checkArgument(_params.r >= 1, "r must be greater than 0: " + _params.r);
+ Preconditions.checkArgument(_params.k >= 1, "k must be greater than 0: " + _params.k);
+ Preconditions.checkArgument(_params.changepointThreshold > 0.d && _params.changepointThreshold < 1.d,
+ "changepointThreshold must be in range (0, 1): " + _params.changepointThreshold);
+
+ return cl;
+ }
+
+ @Override
+ public ObjectInspector initialize(@Nonnull ObjectInspector[] argOIs)
+ throws UDFArgumentException {
+ if (argOIs.length < 1 || argOIs.length > 2) {
+ throw new UDFArgumentException(
+ "_FUNC_(double|array<double> x [, const string options]) takes 1 or 2 arguments: "
+ + Arrays.toString(argOIs));
+ }
+
+ this._params = new Parameters();
+ if (argOIs.length == 2) {
+ String options = HiveUtils.getConstString(argOIs[1]);
+ processOptions(options);
+ }
+
+ ObjectInspector argOI0 = argOIs[0];
+ PrimitiveObjectInspector xOI = HiveUtils.asDoubleCompatibleOI(argOI0);
+ this._sst = new SSTChangePoint(_params, xOI);
+
+ this._scores = new double[1];
+
+ final Object[] result;
+ final ArrayList<String> fieldNames = new ArrayList<String>();
+ final ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
+ fieldNames.add("changepoint_score");
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
+ if (_params.changepointThreshold != -1d) {
+ fieldNames.add("is_changepoint");
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableBooleanObjectInspector);
+ result = new Object[2];
+ this._isChangepoint = new BooleanWritable(false);
+ result[1] = _isChangepoint;
+ } else {
+ result = new Object[1];
+ }
+ this._changepointScore = new DoubleWritable(0.d);
+ result[0] = _changepointScore;
+ this._result = result;
+
+ return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
+ }
+
+ @Override
+ public Object[] evaluate(@Nonnull DeferredObject[] args) throws HiveException {
+ Object x = args[0].get();
+ if (x == null) {
+ return _result;
+ }
+
+ _sst.update(x, _scores);
+
+ double changepointScore = _scores[0];
+ _changepointScore.set(changepointScore);
+ if (_isChangepoint != null) {
+ _isChangepoint.set(changepointScore >= _params.changepointThreshold);
+ }
+
+ return _result;
+ }
+
+ @Override
+ public void close() throws IOException {
+ this._result = null;
+ this._changepointScore = null;
+ this._isChangepoint = null;
+ }
+
+ @Override
+ public String getDisplayString(String[] children) {
+ return "sst(" + Arrays.toString(children) + ")";
+ }
+
+ static final class Parameters {
+ int w = 30;
+ int n = 30;
+ int m = 30;
+ int g = -30;
+ int r = 3;
+ int k = 5;
+ double changepointThreshold = -1.d;
+
+ Parameters() {}
+ }
+
+ public interface SSTChangePointInterface {
+ void update(@Nonnull Object arg, @Nonnull double[] outScores) throws HiveException;
+ }
+
+}
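A minimal usage sketch of the UDF above (assuming it is registered under the name given in the
@Description annotation; the table name `timeseries` and column `x` are illustrative):

  SELECT sst_changepoint(x, "-th 0.005") AS result
  FROM timeseries;

With the "-th" option the result is a struct<changepoint_score:double, is_changepoint:boolean>;
without it, only changepoint_score is returned.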
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3ebd771e/core/src/test/java/hivemall/anomaly/SSTChangePointTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/anomaly/SSTChangePointTest.java b/core/src/test/java/hivemall/anomaly/SSTChangePointTest.java
new file mode 100644
index 0000000..b41d474
--- /dev/null
+++ b/core/src/test/java/hivemall/anomaly/SSTChangePointTest.java
@@ -0,0 +1,111 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2015 Makoto YUI
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package hivemall.anomaly;
+
+import hivemall.anomaly.SSTChangePointUDF.Parameters;
+
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.util.zip.GZIPInputStream;
+
+import javax.annotation.Nonnull;
+
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.junit.Assert;
+import org.junit.Test;
+
+public class SSTChangePointTest {
+ private static final boolean DEBUG = false;
+
+ @Test
+ public void testSST() throws IOException, HiveException {
+ Parameters params = new Parameters();
+ PrimitiveObjectInspector oi = PrimitiveObjectInspectorFactory.javaDoubleObjectInspector;
+ SSTChangePoint sst = new SSTChangePoint(params, oi);
+ double[] outScores = new double[1];
+
+ BufferedReader reader = readFile("cf1d.csv");
+ println("x change");
+ String line;
+ int numChangepoints = 0;
+ while ((line = reader.readLine()) != null) {
+ double x = Double.parseDouble(line);
+ sst.update(x, outScores);
+ printf("%f %f%n", x, outScores[0]);
+ if (outScores[0] > 0.95d) {
+ numChangepoints++;
+ }
+ }
+ Assert.assertTrue("#changepoints SHOULD be greater than 0: " + numChangepoints,
+ numChangepoints > 0);
+ Assert.assertTrue("#changepoints SHOULD be less than 5: " + numChangepoints,
+ numChangepoints < 5);
+ }
+
+ @Test
+ public void testTwitterData() throws IOException, HiveException {
+ Parameters params = new Parameters();
+ PrimitiveObjectInspector oi = PrimitiveObjectInspectorFactory.javaDoubleObjectInspector;
+ SSTChangePoint sst = new SSTChangePoint(params, oi);
+ double[] outScores = new double[1];
+
+ BufferedReader reader = readFile("twitter.csv.gz");
+ println("# time x change");
+ String line;
+ int i = 1, numChangepoints = 0;
+ while ((line = reader.readLine()) != null) {
+ double x = Double.parseDouble(line);
+ sst.update(x, outScores);
+ printf("%d %f %f%n", i, x, outScores[0]);
+ if (outScores[0] > 0.005d) {
+ numChangepoints++;
+ }
+ i++;
+ }
+ Assert.assertTrue("#changepoints SHOULD be greater than 0: " + numChangepoints,
+ numChangepoints > 0);
+ Assert.assertTrue("#changepoints SHOULD be less than 5: " + numChangepoints,
+ numChangepoints < 5);
+ }
+
+ private static void println(String msg) {
+ if (DEBUG) {
+ System.out.println(msg);
+ }
+ }
+
+ private static void printf(String format, Object... args) {
+ if (DEBUG) {
+ System.out.printf(format, args);
+ }
+ }
+
+ @Nonnull
+ private static BufferedReader readFile(@Nonnull String fileName) throws IOException {
+ InputStream is = SSTChangePointTest.class.getResourceAsStream(fileName);
+ if (fileName.endsWith(".gz")) {
+ is = new GZIPInputStream(is);
+ }
+ return new BufferedReader(new InputStreamReader(is));
+ }
+
+}
[41/50] [abbrv] incubator-hivemall git commit: Refine access
modifiers/calls
Posted by my...@apache.org.
Refine access modifiers/calls
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/ddd8dc2d
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/ddd8dc2d
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/ddd8dc2d
Branch: refs/heads/JIRA-22/pr-336
Commit: ddd8dc2dbf8222c9d9d84b038dbdcd9aef1f1a87
Parents: 7447dde
Author: amaya <gi...@sapphire.in.net>
Authored: Fri Nov 18 04:22:51 2016 +0900
Committer: amaya <gi...@sapphire.in.net>
Committed: Fri Nov 18 04:22:51 2016 +0900
----------------------------------------------------------------------
.../systemtest/runner/HiveSystemTestRunner.java | 4 +-
.../systemtest/runner/SystemTestRunner.java | 40 +++++++++++---------
.../systemtest/runner/SystemTestTeam.java | 8 +---
.../systemtest/runner/TDSystemTestRunner.java | 24 ++++++------
4 files changed, 36 insertions(+), 40 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ddd8dc2d/systemtest/src/main/java/hivemall/systemtest/runner/HiveSystemTestRunner.java
----------------------------------------------------------------------
diff --git a/systemtest/src/main/java/hivemall/systemtest/runner/HiveSystemTestRunner.java b/systemtest/src/main/java/hivemall/systemtest/runner/HiveSystemTestRunner.java
index 25a2125..db1edc7 100644
--- a/systemtest/src/main/java/hivemall/systemtest/runner/HiveSystemTestRunner.java
+++ b/systemtest/src/main/java/hivemall/systemtest/runner/HiveSystemTestRunner.java
@@ -101,7 +101,7 @@ public class HiveSystemTestRunner extends SystemTestRunner {
}
@Override
- protected void finRunner() {
+ void finRunner() {
if (container != null) {
container.tearDown();
}
@@ -111,7 +111,7 @@ public class HiveSystemTestRunner extends SystemTestRunner {
}
@Override
- protected List<String> exec(@Nonnull final RawHQ hq) {
+ public List<String> exec(@Nonnull final RawHQ hq) {
logger.info("executing: `" + hq.query + "`");
return hShell.executeQuery(hq.query);
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ddd8dc2d/systemtest/src/main/java/hivemall/systemtest/runner/SystemTestRunner.java
----------------------------------------------------------------------
diff --git a/systemtest/src/main/java/hivemall/systemtest/runner/SystemTestRunner.java b/systemtest/src/main/java/hivemall/systemtest/runner/SystemTestRunner.java
index f16da90..e142174 100644
--- a/systemtest/src/main/java/hivemall/systemtest/runner/SystemTestRunner.java
+++ b/systemtest/src/main/java/hivemall/systemtest/runner/SystemTestRunner.java
@@ -45,7 +45,6 @@ import javax.annotation.Nullable;
import java.io.FileInputStream;
import java.io.InputStream;
import java.util.ArrayList;
-import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
@@ -55,9 +54,9 @@ import java.util.Set;
public abstract class SystemTestRunner extends ExternalResource {
static final Logger logger = LoggerFactory.getLogger(SystemTestRunner.class);
@Nonnull
- final List<HQBase> classInitHqs;
+ private final List<HQBase> classInitHqs;
@Nonnull
- final Set<String> immutableTables;
+ private final Set<String> immutableTables;
@Nonnull
final String dbName;
@Nonnull
@@ -98,7 +97,7 @@ public abstract class SystemTestRunner extends ExternalResource {
@Override
protected void after() {
try {
- resetDB(); // clean up database
+ cleanDB(); // clean up database
} catch (Exception ex) {
throw new QueryExecutionException("Failed to clean up temporary database. "
+ ex.getMessage());
@@ -111,16 +110,16 @@ public abstract class SystemTestRunner extends ExternalResource {
abstract void finRunner();
- public void initBy(@Nonnull final HQBase hq) {
+ protected void initBy(@Nonnull final HQBase hq) {
classInitHqs.add(hq);
}
- public void initBy(@Nonnull final List<? extends HQBase> hqs) {
+ protected void initBy(@Nonnull final List<? extends HQBase> hqs) {
classInitHqs.addAll(hqs);
}
// fix to temporary database and user-defined init (should be called per Test class)
- void prepareDB() throws Exception {
+ private void prepareDB() throws Exception {
createDB(dbName);
use(dbName);
for (HQBase q : classInitHqs) {
@@ -136,15 +135,21 @@ public abstract class SystemTestRunner extends ExternalResource {
}
// drop temporary database (should be called per Test class)
- void resetDB() throws Exception {
+ private void cleanDB() throws Exception {
dropDB(dbName);
}
- public final boolean isImmutableTable(final String tableName) {
- return immutableTables.contains(tableName);
+ // drop temporary tables (should be called per Test method)
+ void resetDB() throws Exception {
+ final List<String> tables = tableList();
+ for (String t : tables) {
+ if (!immutableTables.contains(t)) {
+ dropTable(HQ.dropTable(t));
+ }
+ }
}
- // execute HQBase
+ // >execute HQBase
public List<String> exec(@Nonnull final HQBase hq) throws Exception {
if (hq instanceof RawHQ) {
return exec((RawHQ) hq);
@@ -157,10 +162,10 @@ public abstract class SystemTestRunner extends ExternalResource {
}
}
- //// execute RawHQ
+ // >>execute RawHQ
abstract protected List<String> exec(@Nonnull final RawHQ hq) throws Exception;
- //// execute TableHQ
+ // >>execute TableHQ
List<String> exec(@Nonnull final TableHQ hq) throws Exception {
if (hq instanceof CreateTableHQ) {
return createTable((CreateTableHQ) hq);
@@ -175,7 +180,7 @@ public abstract class SystemTestRunner extends ExternalResource {
}
}
- ////// execute UploadFileHQ
+ // >>>execute UploadFileHQ
List<String> exec(@Nonnull final UploadFileHQ hq) throws Exception {
if (hq instanceof UploadFileAsNewTableHQ) {
return uploadFileAsNewTable((UploadFileAsNewTableHQ) hq);
@@ -187,8 +192,8 @@ public abstract class SystemTestRunner extends ExternalResource {
}
// matching HQBase
- public void matching(@Nonnull final HQBase hq, @CheckForNull final String answer,
- final boolean ordered) throws Exception {
+ void matching(@Nonnull final HQBase hq, @CheckForNull final String answer, final boolean ordered)
+ throws Exception {
Preconditions.checkNotNull(answer);
List<String> result = exec(hq);
@@ -203,8 +208,7 @@ public abstract class SystemTestRunner extends ExternalResource {
}
// matching HQBase (ordered == false)
- public void matching(@Nonnull final HQBase hq, @CheckForNull final String answer)
- throws Exception {
+ void matching(@Nonnull final HQBase hq, @CheckForNull final String answer) throws Exception {
matching(hq, answer, false);
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ddd8dc2d/systemtest/src/main/java/hivemall/systemtest/runner/SystemTestTeam.java
----------------------------------------------------------------------
diff --git a/systemtest/src/main/java/hivemall/systemtest/runner/SystemTestTeam.java b/systemtest/src/main/java/hivemall/systemtest/runner/SystemTestTeam.java
index 86065e4..fcd2fcb 100644
--- a/systemtest/src/main/java/hivemall/systemtest/runner/SystemTestTeam.java
+++ b/systemtest/src/main/java/hivemall/systemtest/runner/SystemTestTeam.java
@@ -19,7 +19,6 @@
package hivemall.systemtest.runner;
import hivemall.systemtest.exception.QueryExecutionException;
-import hivemall.systemtest.model.HQ;
import hivemall.systemtest.model.HQBase;
import hivemall.systemtest.model.RawHQ;
import hivemall.systemtest.model.lazy.LazyMatchingResource;
@@ -66,12 +65,7 @@ public class SystemTestTeam extends ExternalResource {
for (SystemTestRunner runner : reachGoal) {
try {
- final List<String> tables = runner.exec(HQ.tableList());
- for (String t : tables) {
- if (!runner.isImmutableTable(t)) {
- runner.exec(HQ.dropTable(t));
- }
- }
+ runner.resetDB();
} catch (Exception ex) {
throw new QueryExecutionException("Failed to resetPerMethod database. "
+ ex.getMessage());
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/ddd8dc2d/systemtest/src/main/java/hivemall/systemtest/runner/TDSystemTestRunner.java
----------------------------------------------------------------------
diff --git a/systemtest/src/main/java/hivemall/systemtest/runner/TDSystemTestRunner.java b/systemtest/src/main/java/hivemall/systemtest/runner/TDSystemTestRunner.java
index 87dd835..b2f8290 100644
--- a/systemtest/src/main/java/hivemall/systemtest/runner/TDSystemTestRunner.java
+++ b/systemtest/src/main/java/hivemall/systemtest/runner/TDSystemTestRunner.java
@@ -70,7 +70,7 @@ public class TDSystemTestRunner extends SystemTestRunner {
}
@Override
- protected void initRunner() {
+ void initRunner() {
// optional
if (props.containsKey("execFinishRetryLimit")) {
execFinishRetryLimit = Integer.valueOf(props.getProperty("execFinishRetryLimit"));
@@ -102,14 +102,14 @@ public class TDSystemTestRunner extends SystemTestRunner {
}
@Override
- protected void finRunner() {
+ void finRunner() {
if (client != null) {
client.close();
}
}
@Override
- protected List<String> exec(@Nonnull final RawHQ hq) throws Exception {
+ public List<String> exec(@Nonnull final RawHQ hq) throws Exception {
logger.info("executing: `" + hq.query + "`");
final TDJobRequest req = TDJobRequest.newHiveQuery(dbName, hq.query);
@@ -157,7 +157,7 @@ public class TDSystemTestRunner extends SystemTestRunner {
}
@Override
- protected List<String> createDB(@Nonnull final String dbName) throws Exception {
+ List<String> createDB(@Nonnull final String dbName) throws Exception {
logger.info("executing: create database if not exists " + dbName);
client.createDatabaseIfNotExists(dbName);
@@ -165,7 +165,7 @@ public class TDSystemTestRunner extends SystemTestRunner {
}
@Override
- protected List<String> dropDB(@Nonnull final String dbName) throws Exception {
+ List<String> dropDB(@Nonnull final String dbName) throws Exception {
logger.info("executing: drop database if exists " + dbName);
client.deleteDatabaseIfExists(dbName);
@@ -173,13 +173,13 @@ public class TDSystemTestRunner extends SystemTestRunner {
}
@Override
- protected List<String> use(@Nonnull final String dbName) throws Exception {
+ List<String> use(@Nonnull final String dbName) throws Exception {
return Collections.singletonList("No need to execute `USE` statement on TD, so skipped `USE "
+ dbName + "`");
}
@Override
- protected List<String> tableList() throws Exception {
+ List<String> tableList() throws Exception {
logger.info("executing: show tables on " + dbName);
final List<TDTable> tables = client.listTables(dbName);
@@ -191,7 +191,7 @@ public class TDSystemTestRunner extends SystemTestRunner {
}
@Override
- protected List<String> createTable(@Nonnull final CreateTableHQ hq) throws Exception {
+ List<String> createTable(@Nonnull final CreateTableHQ hq) throws Exception {
logger.info("executing: create table " + hq.tableName + " if not exists on " + dbName);
final List<TDColumn> columns = new ArrayList<TDColumn>();
@@ -204,7 +204,7 @@ public class TDSystemTestRunner extends SystemTestRunner {
}
@Override
- protected List<String> dropTable(@Nonnull final DropTableHQ hq) throws Exception {
+ List<String> dropTable(@Nonnull final DropTableHQ hq) throws Exception {
logger.info("executing: drop table " + hq.tableName + " if exists on " + dbName);
client.deleteTableIfExists(dbName, hq.tableName);
@@ -212,8 +212,7 @@ public class TDSystemTestRunner extends SystemTestRunner {
}
@Override
- protected List<String> uploadFileAsNewTable(@Nonnull final UploadFileAsNewTableHQ hq)
- throws Exception {
+ List<String> uploadFileAsNewTable(@Nonnull final UploadFileAsNewTableHQ hq) throws Exception {
logger.info("executing: create " + hq.tableName + " based on " + hq.file.getPath()
+ " if not exists on " + dbName);
@@ -292,8 +291,7 @@ public class TDSystemTestRunner extends SystemTestRunner {
}
@Override
- protected List<String> uploadFileToExisting(@Nonnull final UploadFileToExistingHQ hq)
- throws Exception {
+ List<String> uploadFileToExisting(@Nonnull final UploadFileToExistingHQ hq) throws Exception {
logger.info("executing: insert " + hq.file.getPath() + " into " + hq.tableName + " on "
+ dbName);
[17/50] [abbrv] incubator-hivemall git commit: refine
transpose_and_dot
Posted by my...@apache.org.
refine transpose_and_dot
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/abbf5492
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/abbf5492
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/abbf5492
Branch: refs/heads/JIRA-22/pr-385
Commit: abbf5492b95dd69e347580c59ac044a78627c547
Parents: a16a3fd
Author: amaya <gi...@sapphire.in.net>
Authored: Wed Sep 21 13:11:00 2016 +0900
Committer: amaya <gi...@sapphire.in.net>
Committed: Wed Sep 21 13:40:54 2016 +0900
----------------------------------------------------------------------
.../tools/matrix/TransposeAndDotUDAF.java | 32 +++++++++++---------
1 file changed, 18 insertions(+), 14 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/abbf5492/core/src/main/java/hivemall/tools/matrix/TransposeAndDotUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/tools/matrix/TransposeAndDotUDAF.java b/core/src/main/java/hivemall/tools/matrix/TransposeAndDotUDAF.java
index 1e54004..9d68f93 100644
--- a/core/src/main/java/hivemall/tools/matrix/TransposeAndDotUDAF.java
+++ b/core/src/main/java/hivemall/tools/matrix/TransposeAndDotUDAF.java
@@ -127,33 +127,37 @@ public final class TransposeAndDotUDAF extends AbstractGenericUDAFResolver {
@Override
public AbstractAggregationBuffer getNewAggregationBuffer() throws HiveException {
- TransposeAndDotAggregationBuffer myAgg = new TransposeAndDotAggregationBuffer();
+ final TransposeAndDotAggregationBuffer myAgg = new TransposeAndDotAggregationBuffer();
reset(myAgg);
return myAgg;
}
@Override
public void reset(AggregationBuffer agg) throws HiveException {
- TransposeAndDotAggregationBuffer myAgg = (TransposeAndDotAggregationBuffer) agg;
+ final TransposeAndDotAggregationBuffer myAgg = (TransposeAndDotAggregationBuffer) agg;
myAgg.reset();
}
@Override
public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException {
- TransposeAndDotAggregationBuffer myAgg = (TransposeAndDotAggregationBuffer) agg;
+ final Object matrix0RowObj = parameters[0];
+ final Object matrix1RowObj = parameters[1];
+ Preconditions.checkNotNull(matrix0RowObj);
+ Preconditions.checkNotNull(matrix1RowObj);
+
+ final TransposeAndDotAggregationBuffer myAgg = (TransposeAndDotAggregationBuffer) agg;
+
+ // init
if (matrix0Row == null) {
- matrix0Row = new double[matrix0RowOI.getListLength(parameters[0])];
+ matrix0Row = new double[matrix0RowOI.getListLength(matrix0RowObj)];
}
if (matrix1Row == null) {
- matrix1Row = new double[matrix1RowOI.getListLength(parameters[1])];
+ matrix1Row = new double[matrix1RowOI.getListLength(matrix1RowObj)];
}
- HiveUtils.toDoubleArray(parameters[0], matrix0RowOI, matrix0ElOI, matrix0Row, false);
- HiveUtils.toDoubleArray(parameters[1], matrix1RowOI, matrix1ElOI, matrix1Row, false);
-
- Preconditions.checkNotNull(matrix0Row);
- Preconditions.checkNotNull(matrix1Row);
+ HiveUtils.toDoubleArray(matrix0RowObj, matrix0RowOI, matrix0ElOI, matrix0Row, false);
+ HiveUtils.toDoubleArray(matrix1RowObj, matrix1RowOI, matrix1ElOI, matrix1Row, false);
if (myAgg.aggMatrix == null) {
myAgg.init(matrix0Row.length, matrix1Row.length);
@@ -172,9 +176,9 @@ public final class TransposeAndDotUDAF extends AbstractGenericUDAFResolver {
return;
}
- TransposeAndDotAggregationBuffer myAgg = (TransposeAndDotAggregationBuffer) agg;
+ final TransposeAndDotAggregationBuffer myAgg = (TransposeAndDotAggregationBuffer) agg;
- List matrix = aggMatrixOI.getList(other);
+ final List matrix = aggMatrixOI.getList(other);
final int n = matrix.size();
final double[] row = new double[aggMatrixRowOI.getListLength(matrix.get(0))];
for (int i = 0; i < n; i++) {
@@ -197,9 +201,9 @@ public final class TransposeAndDotUDAF extends AbstractGenericUDAFResolver {
@Override
public Object terminate(AggregationBuffer agg) throws HiveException {
- TransposeAndDotAggregationBuffer myAgg = (TransposeAndDotAggregationBuffer) agg;
+ final TransposeAndDotAggregationBuffer myAgg = (TransposeAndDotAggregationBuffer) agg;
- List<List<DoubleWritable>> result = new ArrayList<List<DoubleWritable>>();
+ final List<List<DoubleWritable>> result = new ArrayList<List<DoubleWritable>>();
for (double[] row : myAgg.aggMatrix) {
result.add(WritableUtils.toWritableList(row));
}
[22/50] [abbrv] incubator-hivemall git commit: add snr
Posted by my...@apache.org.
add snr
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/22a608ee
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/22a608ee
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/22a608ee
Branch: refs/heads/JIRA-22/pr-385
Commit: 22a608ee1c7239b2953183b5341f80c58b1e7045
Parents: 5088ef3
Author: amaya <gi...@sapphire.in.net>
Authored: Mon Sep 26 17:07:55 2016 +0900
Committer: amaya <gi...@sapphire.in.net>
Committed: Mon Sep 26 17:15:22 2016 +0900
----------------------------------------------------------------------
.../ftvec/selection/SignalNoiseRatioUDAF.java | 327 +++++++++++++++++++
.../selection/SignalNoiseRatioUDAFTest.java | 174 ++++++++++
resources/ddl/define-all-as-permanent.hive | 3 +
resources/ddl/define-all.hive | 3 +
resources/ddl/define-all.spark | 3 +
resources/ddl/define-udfs.td.hql | 1 +
6 files changed, 511 insertions(+)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/22a608ee/core/src/main/java/hivemall/ftvec/selection/SignalNoiseRatioUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/ftvec/selection/SignalNoiseRatioUDAF.java b/core/src/main/java/hivemall/ftvec/selection/SignalNoiseRatioUDAF.java
new file mode 100644
index 0000000..b7b9126
--- /dev/null
+++ b/core/src/main/java/hivemall/ftvec/selection/SignalNoiseRatioUDAF.java
@@ -0,0 +1,327 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2015 Makoto YUI
+ * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package hivemall.ftvec.selection;
+
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.hadoop.WritableUtils;
+import hivemall.utils.lang.Preconditions;
+import org.apache.commons.math3.util.FastMath;
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.parse.SemanticException;
+import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFParameterInfo;
+import org.apache.hadoop.hive.serde2.io.DoubleWritable;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StructField;
+import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.LongObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+@Description(name = "snr", value = "_FUNC_(array<number> features, array<int> one-hot class label)"
+ + " - Returns SNR values of each feature as array<double>")
+public class SignalNoiseRatioUDAF extends AbstractGenericUDAFResolver {
+ @Override
+ public GenericUDAFEvaluator getEvaluator(GenericUDAFParameterInfo info)
+ throws SemanticException {
+ final ObjectInspector[] OIs = info.getParameterObjectInspectors();
+
+ if (OIs.length != 2) {
+ throw new UDFArgumentLengthException("Specify two arguments.");
+ }
+
+ if (!HiveUtils.isNumberListOI(OIs[0])) {
+ throw new UDFArgumentTypeException(0,
+ "Only array<number> type argument is acceptable but " + OIs[0].getTypeName()
+ + " was passed as `features`");
+ }
+
+ if (!HiveUtils.isListOI(OIs[1])
+ || !HiveUtils.isIntegerOI(((ListObjectInspector) OIs[1]).getListElementObjectInspector())) {
+ throw new UDFArgumentTypeException(1,
+ "Only array<int> type argument is acceptable but " + OIs[1].getTypeName()
+ + " was passed as `labels`");
+ }
+
+ return new SignalNoiseRatioUDAFEvaluator();
+ }
+
+ static class SignalNoiseRatioUDAFEvaluator extends GenericUDAFEvaluator {
+ // PARTIAL1 and COMPLETE
+ private ListObjectInspector featuresOI;
+ private PrimitiveObjectInspector featureOI;
+ private ListObjectInspector labelsOI;
+ private PrimitiveObjectInspector labelOI;
+
+ // PARTIAL2 and FINAL
+ private StructObjectInspector structOI;
+ private StructField nsField, meanssField, variancessField;
+ private ListObjectInspector nsOI;
+ private LongObjectInspector nOI;
+ private ListObjectInspector meanssOI;
+ private ListObjectInspector meansOI;
+ private DoubleObjectInspector meanOI;
+ private ListObjectInspector variancessOI;
+ private ListObjectInspector variancesOI;
+ private DoubleObjectInspector varianceOI;
+
+ @AggregationType(estimable = true)
+ static class SignalNoiseRatioAggregationBuffer extends AbstractAggregationBuffer {
+ long[] ns;
+ double[][] meanss;
+ double[][] variancess;
+
+ @Override
+ public int estimate() {
+ return ns == null ? 0 : 8 * ns.length + 8 * meanss.length * meanss[0].length + 8
+ * variancess.length * variancess[0].length;
+ }
+
+ public void init(int nClasses, int nFeatures) {
+ ns = new long[nClasses];
+ meanss = new double[nClasses][nFeatures];
+ variancess = new double[nClasses][nFeatures];
+ }
+
+ public void reset() {
+ if (ns != null) {
+ Arrays.fill(ns, 0);
+ for (double[] means : meanss) {
+ Arrays.fill(means, 0.d);
+ }
+ for (double[] variances : variancess) {
+ Arrays.fill(variances, 0.d);
+ }
+ }
+ }
+ }
+
+ @Override
+ public ObjectInspector init(Mode mode, ObjectInspector[] OIs) throws HiveException {
+ super.init(mode, OIs);
+
+ if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {
+ featuresOI = HiveUtils.asListOI(OIs[0]);
+ featureOI = HiveUtils.asDoubleCompatibleOI(featuresOI.getListElementObjectInspector());
+ labelsOI = HiveUtils.asListOI(OIs[1]);
+ labelOI = HiveUtils.asIntegerOI(labelsOI.getListElementObjectInspector());
+ } else {
+ structOI = (StructObjectInspector) OIs[0];
+ nsField = structOI.getStructFieldRef("ns");
+ nsOI = HiveUtils.asListOI(nsField.getFieldObjectInspector());
+ nOI = HiveUtils.asLongOI(nsOI.getListElementObjectInspector());
+ meanssField = structOI.getStructFieldRef("meanss");
+ meanssOI = HiveUtils.asListOI(meanssField.getFieldObjectInspector());
+ meansOI = HiveUtils.asListOI(meanssOI.getListElementObjectInspector());
+ meanOI = HiveUtils.asDoubleOI(meansOI.getListElementObjectInspector());
+ variancessField = structOI.getStructFieldRef("variancess");
+ variancessOI = HiveUtils.asListOI(variancessField.getFieldObjectInspector());
+ variancesOI = HiveUtils.asListOI(variancessOI.getListElementObjectInspector());
+ varianceOI = HiveUtils.asDoubleOI(variancesOI.getListElementObjectInspector());
+ }
+
+ if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) {
+ List<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
+ fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableLongObjectInspector));
+ fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector)));
+ fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector)));
+ return ObjectInspectorFactory.getStandardStructObjectInspector(
+ Arrays.asList("ns", "meanss", "variancess"), fieldOIs);
+ } else {
+ return ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
+ }
+ }
+
+ @Override
+ public AbstractAggregationBuffer getNewAggregationBuffer() throws HiveException {
+ final SignalNoiseRatioAggregationBuffer myAgg = new SignalNoiseRatioAggregationBuffer();
+ reset(myAgg);
+ return myAgg;
+ }
+
+ @Override
+ public void reset(AggregationBuffer agg) throws HiveException {
+ final SignalNoiseRatioAggregationBuffer myAgg = (SignalNoiseRatioAggregationBuffer) agg;
+ myAgg.reset();
+ }
+
+ @Override
+ public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException {
+ final Object featuresObj = parameters[0];
+ final Object labelsObj = parameters[1];
+
+ Preconditions.checkNotNull(featuresObj);
+ Preconditions.checkNotNull(labelsObj);
+
+ final SignalNoiseRatioAggregationBuffer myAgg = (SignalNoiseRatioAggregationBuffer) agg;
+
+ // read class
+ final List labels = labelsOI.getList(labelsObj);
+ final int nClasses = labels.size();
+
+ // to calc SNR between classes
+ Preconditions.checkArgument(nClasses >= 2);
+
+ int clazz = -1;
+ for (int i = 0; i < nClasses; i++) {
+ int label = PrimitiveObjectInspectorUtils.getInt(labels.get(i), labelOI);
+ if (label == 1 && clazz == -1) {
+ clazz = i;
+ } else if (label == 1) {
+ throw new UDFArgumentException(
+ "Specify one-hot vectorized array. Multiple hot elements found.");
+ }
+ }
+ if (clazz == -1) {
+ throw new UDFArgumentException(
+ "Specify one-hot vectorized array. Hot element not found.");
+ }
+
+ final List features = featuresOI.getList(featuresObj);
+ final int nFeatures = features.size();
+
+ Preconditions.checkArgument(nFeatures >= 1);
+
+ if (myAgg.ns == null) {
+ // init
+ myAgg.init(nClasses, nFeatures);
+ } else {
+ Preconditions.checkArgument(nClasses == myAgg.ns.length);
+ Preconditions.checkArgument(nFeatures == myAgg.meanss[0].length);
+ }
+
+ // calc incrementally
+ final long n = myAgg.ns[clazz];
+ myAgg.ns[clazz]++;
+ for (int i = 0; i < nFeatures; i++) {
+ final double x = PrimitiveObjectInspectorUtils.getDouble(features.get(i), featureOI);
+ final double meanN = myAgg.meanss[clazz][i];
+ final double varianceN = myAgg.variancess[clazz][i];
+ myAgg.meanss[clazz][i] = (n * meanN + x) / (n + 1.d);
+ myAgg.variancess[clazz][i] = (n * varianceN + (x - meanN)
+ * (x - myAgg.meanss[clazz][i]))
+ / (n + 1.d);
+ }
+ }
+
+ @Override
+ public void merge(AggregationBuffer agg, Object other) throws HiveException {
+ if (other == null) {
+ return;
+ }
+
+ final SignalNoiseRatioAggregationBuffer myAgg = (SignalNoiseRatioAggregationBuffer) agg;
+
+ final List ns = nsOI.getList(structOI.getStructFieldData(other, nsField));
+ final List meanss = meanssOI.getList(structOI.getStructFieldData(other, meanssField));
+ final List variancess = variancessOI.getList(structOI.getStructFieldData(other,
+ variancessField));
+
+ final int nClasses = ns.size();
+ final int nFeatures = meansOI.getListLength(meanss.get(0));
+ if (myAgg.ns == null) {
+ // init
+ myAgg.init(nClasses, nFeatures);
+ }
+ for (int i = 0; i < nClasses; i++) {
+ final long n = myAgg.ns[i];
+ final long m = PrimitiveObjectInspectorUtils.getLong(ns.get(i), nOI);
+ final List means = meansOI.getList(meanss.get(i));
+ final List variances = variancesOI.getList(variancess.get(i));
+
+ myAgg.ns[i] += m;
+ for (int j = 0; j < nFeatures; j++) {
+ final double meanN = myAgg.meanss[i][j];
+ final double meanM = PrimitiveObjectInspectorUtils.getDouble(means.get(j),
+ meanOI);
+ final double varianceN = myAgg.variancess[i][j];
+ final double varianceM = PrimitiveObjectInspectorUtils.getDouble(
+ variances.get(j), varianceOI);
+ myAgg.meanss[i][j] = (n * meanN + m * meanM) / (double) (n + m);
+ myAgg.variancess[i][j] = (varianceN * (n - 1) + varianceM * (m - 1) + FastMath.pow(
+ meanN - meanM, 2) * n * m / (n + m))
+ / (n + m - 1);
+ }
+ }
+ }
+
+ @Override
+ public Object terminatePartial(AggregationBuffer agg) throws HiveException {
+ final SignalNoiseRatioAggregationBuffer myAgg = (SignalNoiseRatioAggregationBuffer) agg;
+
+ final Object[] partialResult = new Object[3];
+ partialResult[0] = WritableUtils.toWritableList(myAgg.ns);
+ final List<List<DoubleWritable>> meanss = new ArrayList<List<DoubleWritable>>();
+ for (double[] means : myAgg.meanss) {
+ meanss.add(WritableUtils.toWritableList(means));
+ }
+ partialResult[1] = meanss;
+ final List<List<DoubleWritable>> variancess = new ArrayList<List<DoubleWritable>>();
+ for (double[] variances : myAgg.variancess) {
+ variancess.add(WritableUtils.toWritableList(variances));
+ }
+ partialResult[2] = variancess;
+ return partialResult;
+ }
+
+ @Override
+ public Object terminate(AggregationBuffer agg) throws HiveException {
+ final SignalNoiseRatioAggregationBuffer myAgg = (SignalNoiseRatioAggregationBuffer) agg;
+
+ final int nClasses = myAgg.ns.length;
+ final int nFeatures = myAgg.meanss[0].length;
+
+ // calc SNR between classes each feature
+ final double[] result = new double[nFeatures];
+ final double[] sds = new double[nClasses]; // memo
+ for (int i = 0; i < nFeatures; i++) {
+ sds[0] = FastMath.sqrt(myAgg.variancess[0][i]);
+ for (int j = 1; j < nClasses; j++) {
+ sds[j] = FastMath.sqrt(myAgg.variancess[j][i]);
+ if (Double.isNaN(sds[j])) {
+ continue;
+ }
+ for (int k = 0; k < j; k++) {
+ if (Double.isNaN(sds[k])) {
+ continue;
+ }
+ result[i] += FastMath.abs(myAgg.meanss[j][i] - myAgg.meanss[k][i])
+ / (sds[j] + sds[k]);
+ }
+ }
+ }
+
+ // SUM(snr) GROUP BY feature
+ return WritableUtils.toWritableList(result);
+ }
+ }
+}
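For reference, the per-class running statistics maintained in iterate() above follow the standard
incremental update (our notation, with n samples seen so far for the class and x the new value):

  $\mu_{n+1} = \frac{n\,\mu_n + x}{n + 1}, \qquad
   \sigma^2_{n+1} = \frac{n\,\sigma^2_n + (x - \mu_n)(x - \mu_{n+1})}{n + 1}$

and terminate() sums, for each feature i, the pairwise signal-to-noise ratios over classes:

  $\mathrm{SNR}_i = \sum_{j < k} \frac{\lvert \mu_{j,i} - \mu_{k,i} \rvert}{\sigma_{j,i} + \sigma_{k,i}}$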
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/22a608ee/core/src/test/java/hivemall/ftvec/selection/SignalNoiseRatioUDAFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/ftvec/selection/SignalNoiseRatioUDAFTest.java b/core/src/test/java/hivemall/ftvec/selection/SignalNoiseRatioUDAFTest.java
new file mode 100644
index 0000000..4655545
--- /dev/null
+++ b/core/src/test/java/hivemall/ftvec/selection/SignalNoiseRatioUDAFTest.java
@@ -0,0 +1,174 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2016 Makoto YUI
+ * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package hivemall.ftvec.selection;
+
+import hivemall.utils.hadoop.WritableUtils;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
+import org.apache.hadoop.hive.ql.udf.generic.SimpleGenericUDAFParameterInfo;
+import org.apache.hadoop.hive.serde2.io.DoubleWritable;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.io.IntWritable;
+import org.junit.Assert;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+
+import java.util.ArrayList;
+import java.util.List;
+
+public class SignalNoiseRatioUDAFTest {
+ @Rule
+ public ExpectedException expectedException = ExpectedException.none();
+
+ @Test
+ public void test() throws Exception {
+ final SignalNoiseRatioUDAF snr = new SignalNoiseRatioUDAF();
+ final ObjectInspector[] OIs = new ObjectInspector[] {
+ ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector),
+ ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableIntObjectInspector)};
+ final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator evaluator = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator) snr.getEvaluator(new SimpleGenericUDAFParameterInfo(
+ OIs, false, false));
+ evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, OIs);
+ final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer agg = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer) evaluator.getNewAggregationBuffer();
+ evaluator.reset(agg);
+
+ final double[][] featuress = new double[][] { {5.1, 3.5, 1.4, 0.2}, {4.9, 3.d, 1.4, 0.2},
+ {7.d, 3.2, 4.7, 1.4}, {6.4, 3.2, 4.5, 1.5}, {6.3, 3.3, 6.d, 2.5},
+ {5.8, 2.7, 5.1, 1.9}};
+
+ final int[][] labelss = new int[][] { {1, 0, 0}, {1, 0, 0}, {0, 1, 0}, {0, 1, 0},
+ {0, 0, 1}, {0, 0, 1}};
+
+ for (int i = 0; i < featuress.length; i++) {
+ final List<IntWritable> labels = new ArrayList<IntWritable>();
+ for (int label : labelss[i]) {
+ labels.add(new IntWritable(label));
+ }
+ evaluator.iterate(agg,
+ new Object[] {WritableUtils.toWritableList(featuress[i]), labels});
+ }
+
+ @SuppressWarnings("unchecked")
+ final List<DoubleWritable> resultObj = (ArrayList<DoubleWritable>) evaluator.terminate(agg);
+ final int size = resultObj.size();
+ final double[] result = new double[size];
+ for (int i = 0; i < size; i++) {
+ result[i] = resultObj.get(i).get();
+ }
+ final double[] answer = new double[] {8.431818181818192, 1.3212121212121217,
+ 42.94949494949499, 33.80952380952378};
+ Assert.assertArrayEquals(answer, result, 0.d);
+ }
+
+ @Test
+ public void shouldFail0() throws Exception {
+ expectedException.expect(UDFArgumentException.class);
+
+ final SignalNoiseRatioUDAF snr = new SignalNoiseRatioUDAF();
+ final ObjectInspector[] OIs = new ObjectInspector[] {
+ ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector),
+ ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableIntObjectInspector)};
+ final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator evaluator = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator) snr.getEvaluator(new SimpleGenericUDAFParameterInfo(
+ OIs, false, false));
+ evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, OIs);
+ final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer agg = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer) evaluator.getNewAggregationBuffer();
+ evaluator.reset(agg);
+
+ final double[][] featuress = new double[][] { {5.1, 3.5, 1.4, 0.2}, {4.9, 3.d, 1.4, 0.2},
+ {7.d, 3.2, 4.7, 1.4}, {6.4, 3.2, 4.5, 1.5}, {6.3, 3.3, 6.d, 2.5},
+ {5.8, 2.7, 5.1, 1.9}};
+
+ final int[][] labelss = new int[][] { {0, 0, 0}, // cause UDFArgumentException
+ {1, 0, 0}, {0, 1, 0}, {0, 1, 0}, {0, 0, 1}, {0, 0, 1}};
+
+ for (int i = 0; i < featuress.length; i++) {
+ final List<IntWritable> labels = new ArrayList<IntWritable>();
+ for (int label : labelss[i]) {
+ labels.add(new IntWritable(label));
+ }
+ evaluator.iterate(agg,
+ new Object[] {WritableUtils.toWritableList(featuress[i]), labels});
+ }
+ }
+
+ @Test
+ public void shouldFail1() throws Exception {
+ expectedException.expect(IllegalArgumentException.class);
+
+ final SignalNoiseRatioUDAF snr = new SignalNoiseRatioUDAF();
+ final ObjectInspector[] OIs = new ObjectInspector[] {
+ ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector),
+ ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableIntObjectInspector)};
+ final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator evaluator = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator) snr.getEvaluator(new SimpleGenericUDAFParameterInfo(
+ OIs, false, false));
+ evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, OIs);
+ final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer agg = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer) evaluator.getNewAggregationBuffer();
+ evaluator.reset(agg);
+
+ final double[][] featuress = new double[][] { {5.1, 3.5, 1.4, 0.2},
+ {4.9, 3.d, 1.4}, // cause IllegalArgumentException
+ {7.d, 3.2, 4.7, 1.4}, {6.4, 3.2, 4.5, 1.5}, {6.3, 3.3, 6.d, 2.5},
+ {5.8, 2.7, 5.1, 1.9}};
+
+ final int[][] labelss = new int[][] { {1, 0, 0}, {1, 0, 0}, {0, 1, 0}, {0, 1, 0},
+ {0, 0, 1}, {0, 0, 1}};
+
+ for (int i = 0; i < featuress.length; i++) {
+ final List<IntWritable> labels = new ArrayList<IntWritable>();
+ for (int label : labelss[i]) {
+ labels.add(new IntWritable(label));
+ }
+ evaluator.iterate(agg,
+ new Object[] {WritableUtils.toWritableList(featuress[i]), labels});
+ }
+ }
+
+ @Test
+ public void shouldFail2() throws Exception {
+ expectedException.expect(IllegalArgumentException.class);
+
+ final SignalNoiseRatioUDAF snr = new SignalNoiseRatioUDAF();
+ final ObjectInspector[] OIs = new ObjectInspector[] {
+ ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector),
+ ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableIntObjectInspector)};
+ final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator evaluator = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator) snr.getEvaluator(new SimpleGenericUDAFParameterInfo(
+ OIs, false, false));
+ evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, OIs);
+ final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer agg = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer) evaluator.getNewAggregationBuffer();
+ evaluator.reset(agg);
+
+ final double[][] featuress = new double[][] { {5.1, 3.5, 1.4, 0.2}, {4.9, 3.d, 1.4, 0.2},
+ {7.d, 3.2, 4.7, 1.4}, {6.4, 3.2, 4.5, 1.5}, {6.3, 3.3, 6.d, 2.5},
+ {5.8, 2.7, 5.1, 1.9}};
+
+ final int[][] labelss = new int[][] { {1}, {1}, {1}, {1}, {1}, {1}}; // cause IllegalArgumentException
+
+ for (int i = 0; i < featuress.length; i++) {
+ final List<IntWritable> labels = new ArrayList<IntWritable>();
+ for (int label : labelss[i]) {
+ labels.add(new IntWritable(label));
+ }
+ evaluator.iterate(agg,
+ new Object[] {WritableUtils.toWritableList(featuress[i]), labels});
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/22a608ee/resources/ddl/define-all-as-permanent.hive
----------------------------------------------------------------------
diff --git a/resources/ddl/define-all-as-permanent.hive b/resources/ddl/define-all-as-permanent.hive
index b515b24..10e72b7 100644
--- a/resources/ddl/define-all-as-permanent.hive
+++ b/resources/ddl/define-all-as-permanent.hive
@@ -209,6 +209,9 @@ CREATE FUNCTION l2_normalize as 'hivemall.ftvec.scaling.L2NormalizationUDF' USIN
DROP FUNCTION IF EXISTS chi2;
CREATE FUNCTION chi2 as 'hivemall.ftvec.selection.ChiSquareUDF' USING JAR '${hivemall_jar}';
+DROP FUNCTION IF EXISTS snr;
+CREATE FUNCTION snr as 'hivemall.ftvec.selection.SignalNoiseRatioUDAF' USING JAR '${hivemall_jar}';
+
--------------------
-- misc functions --
--------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/22a608ee/resources/ddl/define-all.hive
----------------------------------------------------------------------
diff --git a/resources/ddl/define-all.hive b/resources/ddl/define-all.hive
index 2124892..04b519e 100644
--- a/resources/ddl/define-all.hive
+++ b/resources/ddl/define-all.hive
@@ -205,6 +205,9 @@ create temporary function l2_normalize as 'hivemall.ftvec.scaling.L2Normalizatio
drop temporary function chi2;
create temporary function chi2 as 'hivemall.ftvec.selection.ChiSquareUDF';
+drop temporary function snr;
+create temporary function snr as 'hivemall.ftvec.selection.SignalNoiseRatioUDAF';
+
-----------------------------------
-- Feature engineering functions --
-----------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/22a608ee/resources/ddl/define-all.spark
----------------------------------------------------------------------
diff --git a/resources/ddl/define-all.spark b/resources/ddl/define-all.spark
index 47f0ce5..65c2346 100644
--- a/resources/ddl/define-all.spark
+++ b/resources/ddl/define-all.spark
@@ -190,6 +190,9 @@ sqlContext.sql("CREATE TEMPORARY FUNCTION normalize AS 'hivemall.ftvec.scaling.L
sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS chi2")
sqlContext.sql("CREATE TEMPORARY FUNCTION chi2 AS 'hivemall.ftvec.selection.ChiSquareUDF'")
+sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS snr")
+sqlContext.sql("CREATE TEMPORARY FUNCTION snr AS 'hivemall.ftvec.selection.SignalNoiseRatioUDAF'")
+
/**
* misc functions
*/
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/22a608ee/resources/ddl/define-udfs.td.hql
----------------------------------------------------------------------
diff --git a/resources/ddl/define-udfs.td.hql b/resources/ddl/define-udfs.td.hql
index fd7dc1d..7aa537a 100644
--- a/resources/ddl/define-udfs.td.hql
+++ b/resources/ddl/define-udfs.td.hql
@@ -51,6 +51,7 @@ create temporary function rescale as 'hivemall.ftvec.scaling.RescaleUDF';
create temporary function zscore as 'hivemall.ftvec.scaling.ZScoreUDF';
create temporary function l2_normalize as 'hivemall.ftvec.scaling.L2NormalizationUDF';
create temporary function chi2 as 'hivemall.ftvec.selection.ChiSquareUDF';
+create temporary function snr as 'hivemall.ftvec.selection.SignalNoiseRatioUDAF';
create temporary function amplify as 'hivemall.ftvec.amplify.AmplifierUDTF';
create temporary function rand_amplify as 'hivemall.ftvec.amplify.RandomAmplifierUDTF';
create temporary function add_bias as 'hivemall.ftvec.AddBiasUDF';
[49/50] [abbrv] incubator-hivemall git commit: Merge branch
'feature/kernelized_pa' of https://github.com/L3Sota/hivemall into
JIRA-22/pr-304
Posted by my...@apache.org.
Merge branch 'feature/kernelized_pa' of https://github.com/L3Sota/hivemall into JIRA-22/pr-304
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/b0a0179b
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/b0a0179b
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/b0a0179b
Branch: refs/heads/JIRA-22/pr-304
Commit: b0a0179b0bc1f50403eb1f5534bfb870113f9777
Parents: 72d6a62 f986803
Author: myui <yu...@gmail.com>
Authored: Fri Dec 2 16:55:49 2016 +0900
Committer: myui <yu...@gmail.com>
Committed: Fri Dec 2 16:55:49 2016 +0900
----------------------------------------------------------------------
.../KernelizedPassiveAggressiveUDTF.java | 396 +++++++++++++++
.../main/java/hivemall/model/FeatureValue.java | 19 +-
.../utils/collections/FloatArrayList.java | 151 ++++++
.../KernelizedPassiveAggressiveUDTFTest.java | 485 ++++++++++++++++++
.../hivemall/classifier/news20-medium.binary | 500 +++++++++++++++++++
.../hivemall/classifier/news20-small.binary | 100 ++++
.../hivemall/classifier/news20-tiny.binary | 2 +
7 files changed, 1652 insertions(+), 1 deletion(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/b0a0179b/core/src/main/java/hivemall/model/FeatureValue.java
----------------------------------------------------------------------
[05/50] [abbrv] incubator-hivemall git commit: add
HiveUtils.isNumberListListOI
Posted by my...@apache.org.
add HiveUtils.isNumberListListOI
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/d0e97e6f
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/d0e97e6f
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/d0e97e6f
Branch: refs/heads/JIRA-22/pr-385
Commit: d0e97e6ff71b2072ec5235cc3ac169162d59da59
Parents: d8f1005
Author: amaya <gi...@sapphire.in.net>
Authored: Tue Sep 20 12:02:28 2016 +0900
Committer: amaya <gi...@sapphire.in.net>
Committed: Tue Sep 20 12:02:28 2016 +0900
----------------------------------------------------------------------
core/src/main/java/hivemall/utils/hadoop/HiveUtils.java | 4 ++++
1 file changed, 4 insertions(+)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/d0e97e6f/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
index 7e8ea7b..dcbf534 100644
--- a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
+++ b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
@@ -235,6 +235,10 @@ public final class HiveUtils {
return isListOI(oi) && isNumberOI(((ListObjectInspector)oi).getListElementObjectInspector());
}
+ public static boolean isNumberListListOI(@Nonnull final ObjectInspector oi) {
+ return isListOI(oi) && isNumberListOI(((ListObjectInspector)oi).getListElementObjectInspector());
+ }
+
public static boolean isPrimitiveTypeInfo(@Nonnull TypeInfo typeInfo) {
return typeInfo.getCategory() == ObjectInspector.Category.PRIMITIVE;
}
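isNumberListListOI() recognizes an ObjectInspector for array<array<number>>, i.e. a list whose elements are themselves lists of numbers. A hypothetical companion helper, mirroring the existing as*OI methods in HiveUtils, could validate and unwrap such an argument in one step (sketch only; the helper name and error text are made up, while isNumberListListOI, asListOI and UDFArgumentTypeException are the existing APIs):

    // Hypothetical helper: reject anything that is not array<array<number>>,
    // then return the outer ListObjectInspector for further unwrapping.
    public static ListObjectInspector asNumberListListOI(@Nonnull final ObjectInspector oi,
            final int argIndex) throws UDFArgumentException {
        if (!isNumberListListOI(oi)) {
            throw new UDFArgumentTypeException(argIndex,
                "Only array<array<number>> type argument is acceptable but "
                        + oi.getTypeName() + " was passed");
        }
        return asListOI(oi);
    }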
[15/50] [abbrv] incubator-hivemall git commit: standardize to chi2
Posted by my...@apache.org.
standardize to chi2
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/6dc23449
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/6dc23449
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/6dc23449
Branch: refs/heads/JIRA-22/pr-385
Commit: 6dc234490dc25f563b22e5659c378e6ebcf8dcdb
Parents: 89c81aa
Author: amaya <gi...@sapphire.in.net>
Authored: Wed Sep 21 11:41:59 2016 +0900
Committer: amaya <gi...@sapphire.in.net>
Committed: Wed Sep 21 13:35:23 2016 +0900
----------------------------------------------------------------------
resources/ddl/define-all-as-permanent.hive | 4 ++--
resources/ddl/define-all.hive | 4 ++--
resources/ddl/define-all.spark | 4 ++--
resources/ddl/define-udfs.td.hql | 2 +-
4 files changed, 7 insertions(+), 7 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/6dc23449/resources/ddl/define-all-as-permanent.hive
----------------------------------------------------------------------
diff --git a/resources/ddl/define-all-as-permanent.hive b/resources/ddl/define-all-as-permanent.hive
index adf6a14..b515b24 100644
--- a/resources/ddl/define-all-as-permanent.hive
+++ b/resources/ddl/define-all-as-permanent.hive
@@ -206,8 +206,8 @@ CREATE FUNCTION l2_normalize as 'hivemall.ftvec.scaling.L2NormalizationUDF' USIN
-- selection functions --
-------------------------
-DROP FUNCTION IF EXISTS chi_square;
-CREATE FUNCTION chi_square as 'hivemall.ftvec.selection.ChiSquareUDF' USING JAR '${hivemall_jar}';
+DROP FUNCTION IF EXISTS chi2;
+CREATE FUNCTION chi2 as 'hivemall.ftvec.selection.ChiSquareUDF' USING JAR '${hivemall_jar}';
--------------------
-- misc functions --
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/6dc23449/resources/ddl/define-all.hive
----------------------------------------------------------------------
diff --git a/resources/ddl/define-all.hive b/resources/ddl/define-all.hive
index 1586d2e..2124892 100644
--- a/resources/ddl/define-all.hive
+++ b/resources/ddl/define-all.hive
@@ -202,8 +202,8 @@ create temporary function l2_normalize as 'hivemall.ftvec.scaling.L2Normalizatio
-- selection functions --
-------------------------
-drop temporary function chi_square;
-create temporary function chi_square as 'hivemall.ftvec.selection.ChiSquareUDF';
+drop temporary function chi2;
+create temporary function chi2 as 'hivemall.ftvec.selection.ChiSquareUDF';
-----------------------------------
-- Feature engineering functions --
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/6dc23449/resources/ddl/define-all.spark
----------------------------------------------------------------------
diff --git a/resources/ddl/define-all.spark b/resources/ddl/define-all.spark
index 50d560b..47f0ce5 100644
--- a/resources/ddl/define-all.spark
+++ b/resources/ddl/define-all.spark
@@ -187,8 +187,8 @@ sqlContext.sql("CREATE TEMPORARY FUNCTION normalize AS 'hivemall.ftvec.scaling.L
* selection functions
*/
-sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS chi_square")
-sqlContext.sql("CREATE TEMPORARY FUNCTION chi_square AS 'hivemall.ftvec.selection.ChiSquareUDF'")
+sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS chi2")
+sqlContext.sql("CREATE TEMPORARY FUNCTION chi2 AS 'hivemall.ftvec.selection.ChiSquareUDF'")
/**
* misc functions
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/6dc23449/resources/ddl/define-udfs.td.hql
----------------------------------------------------------------------
diff --git a/resources/ddl/define-udfs.td.hql b/resources/ddl/define-udfs.td.hql
index 601eead..fd7dc1d 100644
--- a/resources/ddl/define-udfs.td.hql
+++ b/resources/ddl/define-udfs.td.hql
@@ -50,7 +50,7 @@ create temporary function powered_features as 'hivemall.ftvec.pairing.PoweredFea
create temporary function rescale as 'hivemall.ftvec.scaling.RescaleUDF';
create temporary function zscore as 'hivemall.ftvec.scaling.ZScoreUDF';
create temporary function l2_normalize as 'hivemall.ftvec.scaling.L2NormalizationUDF';
-create temporary function chi_square as 'hivemall.ftvec.selection.ChiSquareUDF';
+create temporary function chi2 as 'hivemall.ftvec.selection.ChiSquareUDF';
create temporary function amplify as 'hivemall.ftvec.amplify.AmplifierUDTF';
create temporary function rand_amplify as 'hivemall.ftvec.amplify.RandomAmplifierUDTF';
create temporary function add_bias as 'hivemall.ftvec.AddBiasUDF';
[14/50] [abbrv] incubator-hivemall git commit: change to select_k_best
Posted by my...@apache.org.
change to select_k_best
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/89c81aac
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/89c81aac
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/89c81aac
Branch: refs/heads/JIRA-22/pr-385
Commit: 89c81aacf5b13f6e125723cb5c703333574c10ae
Parents: be1ea37
Author: amaya <gi...@sapphire.in.net>
Authored: Wed Sep 21 10:56:59 2016 +0900
Committer: amaya <gi...@sapphire.in.net>
Committed: Wed Sep 21 13:35:16 2016 +0900
----------------------------------------------------------------------
.../tools/array/ArrayTopKIndicesUDF.java | 115 ---------------
.../hivemall/tools/array/SelectKBestUDF.java | 143 +++++++++++++++++++
.../tools/array/SubarrayByIndicesUDF.java | 111 --------------
resources/ddl/define-all-as-permanent.hive | 9 +-
resources/ddl/define-all.hive | 9 +-
resources/ddl/define-all.spark | 7 +-
resources/ddl/define-udfs.td.hql | 3 +-
7 files changed, 152 insertions(+), 245 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/89c81aac/core/src/main/java/hivemall/tools/array/ArrayTopKIndicesUDF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/tools/array/ArrayTopKIndicesUDF.java b/core/src/main/java/hivemall/tools/array/ArrayTopKIndicesUDF.java
deleted file mode 100644
index f895f9b..0000000
--- a/core/src/main/java/hivemall/tools/array/ArrayTopKIndicesUDF.java
+++ /dev/null
@@ -1,115 +0,0 @@
-/*
- * Hivemall: Hive scalable Machine Learning Library
- *
- * Copyright (C) 2016 Makoto YUI
- * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package hivemall.tools.array;
-
-import hivemall.utils.hadoop.HiveUtils;
-import hivemall.utils.lang.Preconditions;
-import org.apache.hadoop.hive.ql.exec.Description;
-import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
-import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
-import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
-import org.apache.hadoop.hive.ql.metadata.HiveException;
-import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
-import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
-import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
-import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
-import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
-import org.apache.hadoop.io.IntWritable;
-
-import java.util.AbstractMap;
-import java.util.ArrayList;
-import java.util.Comparator;
-import java.util.List;
-import java.util.Map;
-
-@Description(
- name = "array_top_k_indices",
- value = "_FUNC_(array<number> array, const int k) - Returns indices array of top-k as array<int>")
-public class ArrayTopKIndicesUDF extends GenericUDF {
- private ListObjectInspector arrayOI;
- private PrimitiveObjectInspector elementOI;
- private PrimitiveObjectInspector kOI;
-
- @Override
- public ObjectInspector initialize(ObjectInspector[] OIs) throws UDFArgumentException {
- if (OIs.length != 2) {
- throw new UDFArgumentLengthException("Specify two arguments.");
- }
-
- if (!HiveUtils.isNumberListOI(OIs[0])) {
- throw new UDFArgumentTypeException(0,
- "Only array<number> type argument is acceptable but " + OIs[0].getTypeName()
- + " was passed as `array`");
- }
- if (!HiveUtils.isIntegerOI(OIs[1])) {
- throw new UDFArgumentTypeException(1, "Only int type argument is acceptable but "
- + OIs[1].getTypeName() + " was passed as `k`");
- }
-
- arrayOI = HiveUtils.asListOI(OIs[0]);
- elementOI = HiveUtils.asDoubleCompatibleOI(arrayOI.getListElementObjectInspector());
- kOI = HiveUtils.asIntegerOI(OIs[1]);
-
- return ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
- }
-
- @Override
- public Object evaluate(GenericUDF.DeferredObject[] dObj) throws HiveException {
- final double[] array = HiveUtils.asDoubleArray(dObj[0].get(), arrayOI, elementOI);
- final int k = PrimitiveObjectInspectorUtils.getInt(dObj[1].get(), kOI);
-
- Preconditions.checkNotNull(array);
- Preconditions.checkArgument(array.length >= k);
-
- List<Map.Entry<Integer, Double>> list = new ArrayList<Map.Entry<Integer, Double>>();
- for (int i = 0; i < array.length; i++) {
- list.add(new AbstractMap.SimpleEntry<Integer, Double>(i, array[i]));
- }
- list.sort(new Comparator<Map.Entry<Integer, Double>>() {
- @Override
- public int compare(Map.Entry<Integer, Double> o1, Map.Entry<Integer, Double> o2) {
- return o1.getValue() > o2.getValue() ? -1 : 1;
- }
- });
-
- List<IntWritable> result = new ArrayList<IntWritable>();
- for (int i = 0; i < k; i++) {
- result.add(new IntWritable(list.get(i).getKey()));
- }
- return result;
- }
-
- @Override
- public String getDisplayString(String[] children) {
- StringBuilder sb = new StringBuilder();
- sb.append("array_top_k_indices");
- sb.append("(");
- if (children.length > 0) {
- sb.append(children[0]);
- for (int i = 1; i < children.length; i++) {
- sb.append(", ");
- sb.append(children[i]);
- }
- }
- sb.append(")");
- return sb.toString();
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/89c81aac/core/src/main/java/hivemall/tools/array/SelectKBestUDF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/tools/array/SelectKBestUDF.java b/core/src/main/java/hivemall/tools/array/SelectKBestUDF.java
new file mode 100644
index 0000000..bdab5bb
--- /dev/null
+++ b/core/src/main/java/hivemall/tools/array/SelectKBestUDF.java
@@ -0,0 +1,143 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2016 Makoto YUI
+ * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package hivemall.tools.array;
+
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.lang.Preconditions;
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
+import org.apache.hadoop.hive.serde2.io.DoubleWritable;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.IntObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
+
+import java.util.AbstractMap;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+@Description(name = "select_k_best",
+ value = "_FUNC_(array<number> array, const array<number> importance_list, const int k)"
+ + " - Returns selected top-k elements as array<double>")
+public class SelectKBestUDF extends GenericUDF {
+ private ListObjectInspector featuresOI;
+ private PrimitiveObjectInspector featureOI;
+ private ListObjectInspector importanceListOI;
+ private PrimitiveObjectInspector importanceOI;
+ private PrimitiveObjectInspector kOI;
+
+ private int[] topKIndices;
+
+ @Override
+ public ObjectInspector initialize(ObjectInspector[] OIs) throws UDFArgumentException {
+ if (OIs.length != 3) {
+ throw new UDFArgumentLengthException("Specify three arguments.");
+ }
+
+ if (!HiveUtils.isNumberListOI(OIs[0])) {
+ throw new UDFArgumentTypeException(0,
+ "Only array<number> type argument is acceptable but " + OIs[0].getTypeName()
+ + " was passed as `features`");
+ }
+ if (!HiveUtils.isNumberListOI(OIs[1])) {
+ throw new UDFArgumentTypeException(1,
+ "Only array<number> type argument is acceptable but " + OIs[1].getTypeName()
+ + " was passed as `importance_list`");
+ }
+ if (!HiveUtils.isIntegerOI(OIs[2])) {
+ throw new UDFArgumentTypeException(2, "Only int type argument is acceptable but "
+ + OIs[2].getTypeName() + " was passed as `k`");
+ }
+
+ featuresOI = HiveUtils.asListOI(OIs[0]);
+ featureOI = HiveUtils.asDoubleCompatibleOI(featuresOI.getListElementObjectInspector());
+ importanceListOI = HiveUtils.asListOI(OIs[1]);
+ importanceOI = HiveUtils.asDoubleCompatibleOI(importanceListOI.getListElementObjectInspector());
+ kOI = HiveUtils.asIntegerOI(OIs[2]);
+
+ return ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
+ }
+
+ @Override
+ public Object evaluate(GenericUDF.DeferredObject[] dObj) throws HiveException {
+ final double[] features = HiveUtils.asDoubleArray(dObj[0].get(), featuresOI, featureOI);
+ final double[] importanceList = HiveUtils.asDoubleArray(dObj[1].get(), importanceListOI,
+ importanceOI);
+ final int k = PrimitiveObjectInspectorUtils.getInt(dObj[2].get(), kOI);
+
+ Preconditions.checkNotNull(features);
+ Preconditions.checkNotNull(importanceList);
+ Preconditions.checkArgument(features.length == importanceList.length);
+ Preconditions.checkArgument(features.length >= k);
+
+ if (topKIndices == null) {
+ final List<Map.Entry<Integer, Double>> list = new ArrayList<Map.Entry<Integer, Double>>();
+ for (int i = 0; i < importanceList.length; i++) {
+ list.add(new AbstractMap.SimpleEntry<Integer, Double>(i, importanceList[i]));
+ }
+ Collections.sort(list, new Comparator<Map.Entry<Integer, Double>>() {
+ @Override
+ public int compare(Map.Entry<Integer, Double> o1, Map.Entry<Integer, Double> o2) {
+ return o1.getValue() > o2.getValue() ? -1 : 1;
+ }
+ });
+
+ topKIndices = new int[k];
+ for (int i = 0; i < k; i++) {
+ topKIndices[i] = list.get(i).getKey();
+ }
+ }
+
+ final List<DoubleWritable> result = new ArrayList<DoubleWritable>();
+ for (int idx : topKIndices) {
+ result.add(new DoubleWritable(features[idx]));
+ }
+ return result;
+ }
+
+ @Override
+ public String getDisplayString(String[] children) {
+ final StringBuilder sb = new StringBuilder();
+ sb.append("select_k_best");
+ sb.append("(");
+ if (children.length > 0) {
+ sb.append(children[0]);
+ for (int i = 1; i < children.length; i++) {
+ sb.append(", ");
+ sb.append(children[i]);
+ }
+ }
+ sb.append(")");
+ return sb.toString();
+ }
+}
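Because `importance_list` and `k` are declared const in the _FUNC_ signature, the evaluator ranks the importance values only once (the cached topKIndices field) and reuses the same column indices for every subsequent row. The ranking itself is just "sort indices by importance, descending, and keep the first k"; a plain-array sketch of that step, using Double.compare so that ties compare as equal (illustration only, not part of the patch):

    // Rank feature indices by importance (descending) and keep the top k.
    static int[] topKIndices(final double[] importance, final int k) {
        final Integer[] order = new Integer[importance.length];
        for (int i = 0; i < order.length; i++) {
            order[i] = i;
        }
        java.util.Arrays.sort(order, new java.util.Comparator<Integer>() {
            @Override
            public int compare(Integer a, Integer b) {
                return Double.compare(importance[b], importance[a]); // descending
            }
        });
        final int[] topK = new int[k];
        for (int i = 0; i < k; i++) {
            topK[i] = order[i];
        }
        return topK;
    }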
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/89c81aac/core/src/main/java/hivemall/tools/array/SubarrayByIndicesUDF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/tools/array/SubarrayByIndicesUDF.java b/core/src/main/java/hivemall/tools/array/SubarrayByIndicesUDF.java
deleted file mode 100644
index 07e158a..0000000
--- a/core/src/main/java/hivemall/tools/array/SubarrayByIndicesUDF.java
+++ /dev/null
@@ -1,111 +0,0 @@
-/*
- * Hivemall: Hive scalable Machine Learning Library
- *
- * Copyright (C) 2016 Makoto YUI
- * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package hivemall.tools.array;
-
-import hivemall.utils.hadoop.HiveUtils;
-import hivemall.utils.lang.Preconditions;
-import org.apache.hadoop.hive.ql.exec.Description;
-import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
-import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
-import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
-import org.apache.hadoop.hive.ql.metadata.HiveException;
-import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
-import org.apache.hadoop.hive.serde2.io.DoubleWritable;
-import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
-import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
-import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
-import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
-
-import java.util.ArrayList;
-import java.util.List;
-
-@Description(name = "subarray_by_indices",
- value = "_FUNC_(array<number> input, array<int> indices)"
- + " - Returns subarray selected by given indices as array<number>")
-public class SubarrayByIndicesUDF extends GenericUDF {
- private ListObjectInspector inputOI;
- private PrimitiveObjectInspector elementOI;
- private ListObjectInspector indicesOI;
- private PrimitiveObjectInspector indexOI;
-
- @Override
- public ObjectInspector initialize(ObjectInspector[] OIs) throws UDFArgumentException {
- if (OIs.length != 2) {
- throw new UDFArgumentLengthException("Specify two arguments.");
- }
-
- if (!HiveUtils.isListOI(OIs[0])) {
- throw new UDFArgumentTypeException(0,
- "Only array<number> type argument is acceptable but " + OIs[0].getTypeName()
- + " was passed as `input`");
- }
- if (!HiveUtils.isListOI(OIs[1])
- || !HiveUtils.isIntegerOI(((ListObjectInspector) OIs[1]).getListElementObjectInspector())) {
- throw new UDFArgumentTypeException(0,
- "Only array<int> type argument is acceptable but " + OIs[0].getTypeName()
- + " was passed as `indices`");
- }
-
- inputOI = HiveUtils.asListOI(OIs[0]);
- elementOI = HiveUtils.asDoubleCompatibleOI(inputOI.getListElementObjectInspector());
- indicesOI = HiveUtils.asListOI(OIs[1]);
- indexOI = HiveUtils.asIntegerOI(indicesOI.getListElementObjectInspector());
-
- return ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
- }
-
- @Override
- public Object evaluate(GenericUDF.DeferredObject[] dObj) throws HiveException {
- final double[] input = HiveUtils.asDoubleArray(dObj[0].get(), inputOI, elementOI);
- final List indices = indicesOI.getList(dObj[1].get());
-
- Preconditions.checkNotNull(input);
- Preconditions.checkNotNull(indices);
-
- List<DoubleWritable> result = new ArrayList<DoubleWritable>();
- for (Object indexObj : indices) {
- int index = PrimitiveObjectInspectorUtils.getInt(indexObj, indexOI);
- if (index > input.length - 1) {
- throw new ArrayIndexOutOfBoundsException(index);
- }
-
- result.add(new DoubleWritable(input[index]));
- }
-
- return result;
- }
-
- @Override
- public String getDisplayString(String[] children) {
- StringBuilder sb = new StringBuilder();
- sb.append("subarray_by_indices");
- sb.append("(");
- if (children.length > 0) {
- sb.append(children[0]);
- for (int i = 1; i < children.length; i++) {
- sb.append(", ");
- sb.append(children[i]);
- }
- }
- sb.append(")");
- return sb.toString();
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/89c81aac/resources/ddl/define-all-as-permanent.hive
----------------------------------------------------------------------
diff --git a/resources/ddl/define-all-as-permanent.hive b/resources/ddl/define-all-as-permanent.hive
index 52b73a0..adf6a14 100644
--- a/resources/ddl/define-all-as-permanent.hive
+++ b/resources/ddl/define-all-as-permanent.hive
@@ -371,9 +371,6 @@ CREATE FUNCTION subarray_endwith as 'hivemall.tools.array.SubarrayEndWithUDF' US
DROP FUNCTION IF EXISTS subarray_startwith;
CREATE FUNCTION subarray_startwith as 'hivemall.tools.array.SubarrayStartWithUDF' USING JAR '${hivemall_jar}';
-DROP FUNCTION IF EXISTS subarray_by_indices;
-CREATE FUNCTION subarray_by_indices as 'hivemall.tools.array.SubarrayByIndicesUDF' USING JAR '${hivemall_jar}';
-
DROP FUNCTION IF EXISTS array_concat;
CREATE FUNCTION array_concat as 'hivemall.tools.array.ArrayConcatUDF' USING JAR '${hivemall_jar}';
@@ -390,15 +387,15 @@ CREATE FUNCTION array_avg as 'hivemall.tools.array.ArrayAvgGenericUDAF' USING JA
DROP FUNCTION IF EXISTS array_sum;
CREATE FUNCTION array_sum as 'hivemall.tools.array.ArraySumUDAF' USING JAR '${hivemall_jar}';
-DROP FUNCTION array_top_k_indices;
-CREATE FUNCTION array_top_k_indices as 'hivemall.tools.array.ArrayTopKIndicesUDF' USING JAR '${hivemall_jar}';
-
DROP FUNCTION IF EXISTS to_string_array;
CREATE FUNCTION to_string_array as 'hivemall.tools.array.ToStringArrayUDF' USING JAR '${hivemall_jar}';
DROP FUNCTION IF EXISTS array_intersect;
CREATE FUNCTION array_intersect as 'hivemall.tools.array.ArrayIntersectUDF' USING JAR '${hivemall_jar}';
+DROP FUNCTION IF EXISTS select_k_best;
+CREATE FUNCTION select_k_best as 'hivemall.tools.array.SelectKBestUDF' USING JAR '${hivemall_jar}';
+
-----------------------------
-- bit operation functions --
-----------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/89c81aac/resources/ddl/define-all.hive
----------------------------------------------------------------------
diff --git a/resources/ddl/define-all.hive b/resources/ddl/define-all.hive
index a70ae0f..1586d2e 100644
--- a/resources/ddl/define-all.hive
+++ b/resources/ddl/define-all.hive
@@ -367,9 +367,6 @@ create temporary function subarray_endwith as 'hivemall.tools.array.SubarrayEndW
drop temporary function subarray_startwith;
create temporary function subarray_startwith as 'hivemall.tools.array.SubarrayStartWithUDF';
-drop temporary function subarray_by_indices;
-create temporary function subarray_by_indices as 'hivemall.tools.array.SubarrayByIndicesUDF';
-
drop temporary function array_concat;
create temporary function array_concat as 'hivemall.tools.array.ArrayConcatUDF';
@@ -386,15 +383,15 @@ create temporary function array_avg as 'hivemall.tools.array.ArrayAvgGenericUDAF
drop temporary function array_sum;
create temporary function array_sum as 'hivemall.tools.array.ArraySumUDAF';
-drop temporary function array_top_k_indices;
-create temporary function array_top_k_indices as 'hivemall.tools.array.ArrayTopKIndicesUDF';
-
drop temporary function to_string_array;
create temporary function to_string_array as 'hivemall.tools.array.ToStringArrayUDF';
drop temporary function array_intersect;
create temporary function array_intersect as 'hivemall.tools.array.ArrayIntersectUDF';
+drop temporary function select_k_best;
+create temporary function select_k_best as 'hivemall.tools.array.SelectKBestUDF';
+
-----------------------------
-- bit operation functions --
-----------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/89c81aac/resources/ddl/define-all.spark
----------------------------------------------------------------------
diff --git a/resources/ddl/define-all.spark b/resources/ddl/define-all.spark
index e009511..50d560b 100644
--- a/resources/ddl/define-all.spark
+++ b/resources/ddl/define-all.spark
@@ -316,9 +316,6 @@ sqlContext.sql("CREATE TEMPORARY FUNCTION subarray_endwith AS 'hivemall.tools.ar
sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS subarray_startwith")
sqlContext.sql("CREATE TEMPORARY FUNCTION subarray_startwith AS 'hivemall.tools.array.SubarrayStartWithUDF'")
-sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS subarray_by_indices")
-sqlContext.sql("CREATE TEMPORARY FUNCTION subarray_by_indices AS 'hivemall.tools.array.SubarrayByIndicesUDF'")
-
sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS collect_all")
sqlContext.sql("CREATE TEMPORARY FUNCTION collect_all AS 'hivemall.tools.array.CollectAllUDAF'")
@@ -331,8 +328,8 @@ sqlContext.sql("CREATE TEMPORARY FUNCTION subarray AS 'hivemall.tools.array.Suba
sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS array_avg")
sqlContext.sql("CREATE TEMPORARY FUNCTION array_avg AS 'hivemall.tools.array.ArrayAvgGenericUDAF'")
-sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS array_top_k_indices")
-sqlContext.sql("CREATE TEMPORARY FUNCTION array_top_k_indices AS 'hivemall.tools.array.ArrayTopKIndicesUDF'")
+sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS select_k_best")
+sqlContext.sql("CREATE TEMPORARY FUNCTION select_k_best AS 'hivemall.tools.array.SelectKBestUDF'")
/**
* compression functions
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/89c81aac/resources/ddl/define-udfs.td.hql
----------------------------------------------------------------------
diff --git a/resources/ddl/define-udfs.td.hql b/resources/ddl/define-udfs.td.hql
index 92e4003..601eead 100644
--- a/resources/ddl/define-udfs.td.hql
+++ b/resources/ddl/define-udfs.td.hql
@@ -95,14 +95,13 @@ create temporary function array_remove as 'hivemall.tools.array.ArrayRemoveUDF';
create temporary function sort_and_uniq_array as 'hivemall.tools.array.SortAndUniqArrayUDF';
create temporary function subarray_endwith as 'hivemall.tools.array.SubarrayEndWithUDF';
create temporary function subarray_startwith as 'hivemall.tools.array.SubarrayStartWithUDF';
-create temporary function subarray_by_indices as 'hivemall.tools.array.SubarrayByIndicesUDF';
create temporary function array_concat as 'hivemall.tools.array.ArrayConcatUDF';
create temporary function subarray as 'hivemall.tools.array.SubarrayUDF';
create temporary function array_avg as 'hivemall.tools.array.ArrayAvgGenericUDAF';
create temporary function array_sum as 'hivemall.tools.array.ArraySumUDAF';
-create temporary function array_top_k_indices as 'hivemall.tools.array.ArrayTopKIndicesUDF';
create temporary function to_string_array as 'hivemall.tools.array.ToStringArrayUDF';
create temporary function array_intersect as 'hivemall.tools.array.ArrayIntersectUDF';
+create temporary function select_k_best as 'hivemall.tools.array.SelectKBestUDF';
create temporary function bits_collect as 'hivemall.tools.bits.BitsCollectUDAF';
create temporary function to_bits as 'hivemall.tools.bits.ToBitsUDF';
create temporary function unbits as 'hivemall.tools.bits.UnBitsUDF';
[26/50] [abbrv] incubator-hivemall git commit: Add references for the
original SST papers
Posted by my...@apache.org.
Add references for the original SST papers
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/2bfd1270
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/2bfd1270
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/2bfd1270
Branch: refs/heads/JIRA-22/pr-356
Commit: 2bfd1270b1e9b79185a41cbe2568f2ce968d4a71
Parents: bde06e0
Author: Takuya Kitazawa <k....@gmail.com>
Authored: Wed Sep 28 11:16:56 2016 +0900
Committer: Takuya Kitazawa <k....@gmail.com>
Committed: Wed Sep 28 11:22:46 2016 +0900
----------------------------------------------------------------------
.../hivemall/anomaly/SingularSpectrumTransformUDF.java | 11 +++++++++++
1 file changed, 11 insertions(+)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/2bfd1270/core/src/main/java/hivemall/anomaly/SingularSpectrumTransformUDF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/anomaly/SingularSpectrumTransformUDF.java b/core/src/main/java/hivemall/anomaly/SingularSpectrumTransformUDF.java
index 2ec0a91..64b7d20 100644
--- a/core/src/main/java/hivemall/anomaly/SingularSpectrumTransformUDF.java
+++ b/core/src/main/java/hivemall/anomaly/SingularSpectrumTransformUDF.java
@@ -41,6 +41,17 @@ import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
+/**
+ * Change-point detection based on Singular Spectrum Transformation (SST).
+ *
+ * References:
+ * <ul>
+ * <li>T. Ide and K. Inoue,
+ * "Knowledge Discovery from Heterogeneous Dynamic Systems using Change-Point Correlations", SDM'05.</li>
+ * <li>T. Ide and K. Tsuda, "Change-point detection using Krylov subspace learning", SDM'07.</li>
+ * </ul>
+ */
+
@Description(
name = "sst",
value = "_FUNC_(double|array<double> x [, const string options])"
[32/50] [abbrv] incubator-hivemall git commit: minor fix
Posted by my...@apache.org.
minor fix
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/8d9f0d4c
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/8d9f0d4c
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/8d9f0d4c
Branch: refs/heads/JIRA-22/pr-385
Commit: 8d9f0d4c00758324029d342eb4b892e046ca4a49
Parents: 80be81e
Author: amaya <gi...@sapphire.in.net>
Authored: Thu Sep 29 11:02:14 2016 +0900
Committer: amaya <gi...@sapphire.in.net>
Committed: Thu Sep 29 11:02:14 2016 +0900
----------------------------------------------------------------------
.../test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8d9f0d4c/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
----------------------------------------------------------------------
diff --git a/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala b/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
index 7b62b92..fe73a1b 100644
--- a/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
+++ b/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
@@ -743,8 +743,8 @@ final class HivemallOpsWithFeatureSuite extends HivemallFeatureQueryTest {
val df0 = Seq((1, Seq(1, 2, 3), Seq(5, 6, 7)), (1, Seq(3, 4, 5), Seq(7, 8, 9)))
.toDF("c0", "arg0", "arg1")
- df0.groupby($"c0").transpose_and_dot("arg0", "arg1").collect() shouldEqual
- Seq(Row(1, Seq(Seq(26.0, 30.0, 34.0), Seq(38.0, 44.0, 50.0), Seq(50.0, 58.0, 66.0))))
+ checkAnswer(df0.groupby($"c0").transpose_and_dot("arg0", "arg1"),
+ Seq(Row(1, Seq(Seq(26.0, 30.0, 34.0), Seq(38.0, 44.0, 50.0), Seq(50.0, 58.0, 66.0)))))
}
}
[07/50] [abbrv] incubator-hivemall git commit: add array_top_k_indices
Posted by my...@apache.org.
add array_top_k_indices
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/e9d1a94f
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/e9d1a94f
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/e9d1a94f
Branch: refs/heads/JIRA-22/pr-385
Commit: e9d1a94f29f31e2910a54add7c2625825d715318
Parents: 7b07e4a
Author: amaya <gi...@sapphire.in.net>
Authored: Tue Sep 20 16:55:57 2016 +0900
Committer: amaya <gi...@sapphire.in.net>
Committed: Tue Sep 20 18:37:38 2016 +0900
----------------------------------------------------------------------
.../tools/array/ArrayTopKIndicesUDF.java | 96 ++++++++++++++++++++
1 file changed, 96 insertions(+)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e9d1a94f/core/src/main/java/hivemall/tools/array/ArrayTopKIndicesUDF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/tools/array/ArrayTopKIndicesUDF.java b/core/src/main/java/hivemall/tools/array/ArrayTopKIndicesUDF.java
new file mode 100644
index 0000000..bf9fe15
--- /dev/null
+++ b/core/src/main/java/hivemall/tools/array/ArrayTopKIndicesUDF.java
@@ -0,0 +1,96 @@
+package hivemall.tools.array;
+
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.lang.Preconditions;
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
+import org.apache.hadoop.io.IntWritable;
+
+import java.util.AbstractMap;
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.List;
+import java.util.Map;
+
+@Description(name = "array_top_k_indices",
+ value = "_FUNC_(array<number> array, const int k) - Returns indices array of top-k as array<int>")
+public class ArrayTopKIndicesUDF extends GenericUDF {
+ private ListObjectInspector arrayOI;
+ private PrimitiveObjectInspector elementOI;
+ private PrimitiveObjectInspector kOI;
+
+ @Override
+ public ObjectInspector initialize(ObjectInspector[] OIs) throws UDFArgumentException {
+ if (OIs.length != 2) {
+ throw new UDFArgumentLengthException("Specify two arguments.");
+ }
+
+ if (!HiveUtils.isNumberListOI(OIs[0])) {
+ throw new UDFArgumentTypeException(0, "Only array<number> type argument is acceptable but "
+ + OIs[0].getTypeName() + " was passed as `array`");
+ }
+ if (!HiveUtils.isIntegerOI(OIs[1])) {
+ throw new UDFArgumentTypeException(1, "Only int type argument is acceptable but "
+ + OIs[1].getTypeName() + " was passed as `k`");
+ }
+
+ arrayOI = HiveUtils.asListOI(OIs[0]);
+ elementOI = HiveUtils.asDoubleCompatibleOI(arrayOI.getListElementObjectInspector());
+ kOI = HiveUtils.asIntegerOI(OIs[1]);
+
+ return ObjectInspectorFactory.getStandardListObjectInspector(
+ PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+ }
+
+ @Override
+ public Object evaluate(GenericUDF.DeferredObject[] dObj) throws HiveException {
+ final double[] array = HiveUtils.asDoubleArray(dObj[0].get(), arrayOI, elementOI);
+ final int k = PrimitiveObjectInspectorUtils.getInt(dObj[1].get(), kOI);
+
+ Preconditions.checkNotNull(array);
+ Preconditions.checkArgument(array.length >= k);
+
+ List<Map.Entry<Integer, Double>> list = new ArrayList<Map.Entry<Integer, Double>>();
+ for (int i = 0; i < array.length; i++) {
+ list.add(new AbstractMap.SimpleEntry<Integer, Double>(i, array[i]));
+ }
+ list.sort(new Comparator<Map.Entry<Integer, Double>>() {
+ @Override
+ public int compare(Map.Entry<Integer, Double> o1, Map.Entry<Integer, Double> o2) {
+ return o1.getValue() > o2.getValue() ? -1 : 1;
+ }
+ });
+
+ List<IntWritable> result = new ArrayList<IntWritable>();
+ for (int i = 0; i < k; i++) {
+ result.add(new IntWritable(list.get(i).getKey()));
+ }
+ return result;
+ }
+
+ @Override
+ public String getDisplayString(String[] children) {
+ StringBuilder sb = new StringBuilder();
+ sb.append("array_top_k_indices");
+ sb.append("(");
+ if (children.length > 0) {
+ sb.append(children[0]);
+ for (int i = 1; i < children.length; i++) {
+ sb.append(", ");
+ sb.append(children[i]);
+ }
+ }
+ sb.append(")");
+ return sb.toString();
+ }
+}
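A side note on the comparator used above (the same one later appears in SelectKBestUDF): it never returns 0 and yields 1 for both compare(a, b) and compare(b, a) when the values are equal, which breaks the Comparator contract and, on larger inputs containing ties, can make TimSort raise "Comparison method violates its general contract!". A contract-safe descending sort at the same call site would be (sketch only, not part of the patch):

    list.sort(new Comparator<Map.Entry<Integer, Double>>() {
        @Override
        public int compare(Map.Entry<Integer, Double> o1, Map.Entry<Integer, Double> o2) {
            // Descending order; equal values compare as 0, keeping the contract intact.
            return Double.compare(o2.getValue(), o1.getValue());
        }
    });

Note also that List.sort(Comparator) requires Java 8, whereas the later SelectKBestUDF switches to Collections.sort(...), which also runs on older JDKs.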
[24/50] [abbrv] incubator-hivemall git commit: integrate chi2 and SNR
into hivemall.spark
Posted by my...@apache.org.
integrate chi2 and SNR into hivemall.spark
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/a1f8f958
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/a1f8f958
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/a1f8f958
Branch: refs/heads/JIRA-22/pr-385
Commit: a1f8f958c99f3cde9e48b6d80d364004f6d98cc2
Parents: 22a608e
Author: amaya <gi...@sapphire.in.net>
Authored: Tue Sep 27 15:58:33 2016 +0900
Committer: amaya <gi...@sapphire.in.net>
Committed: Tue Sep 27 15:58:33 2016 +0900
----------------------------------------------------------------------
.../apache/spark/sql/hive/GroupedDataEx.scala | 24 ++++++++
.../org/apache/spark/sql/hive/HivemallOps.scala | 19 ++++++
.../spark/sql/hive/HivemallOpsSuite.scala | 63 ++++++++++++++++++-
.../org/apache/spark/sql/hive/HivemallOps.scala | 20 ++++++
.../sql/hive/RelationalGroupedDatasetEx.scala | 26 ++++++++
.../spark/sql/hive/HivemallOpsSuite.scala | 65 +++++++++++++++++++-
6 files changed, 212 insertions(+), 5 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/a1f8f958/spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/GroupedDataEx.scala
----------------------------------------------------------------------
diff --git a/spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/GroupedDataEx.scala b/spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/GroupedDataEx.scala
index 37d5423..2482c62 100644
--- a/spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/GroupedDataEx.scala
+++ b/spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/GroupedDataEx.scala
@@ -264,4 +264,28 @@ final class GroupedDataEx protected[sql](
.toAggregateExpression()
toDF((Alias(udaf, udaf.prettyString)() :: Nil).toSeq)
}
+
+ /**
+ * @see hivemall.ftvec.selection.SignalNoiseRatioUDAF
+ */
+ def snr(X: String, Y: String): DataFrame = {
+ val udaf = HiveUDAFFunction(
+ new HiveFunctionWrapper("hivemall.ftvec.selection.SignalNoiseRatioUDAF"),
+ Seq(X, Y).map(df.col(_).expr),
+ isUDAFBridgeRequired = false)
+ .toAggregateExpression()
+ toDF(Seq(Alias(udaf, udaf.prettyString)()))
+ }
+
+ /**
+ * @see hivemall.tools.matrix.TransposeAndDotUDAF
+ */
+ def transpose_and_dot(X: String, Y: String): DataFrame = {
+ val udaf = HiveUDAFFunction(
+ new HiveFunctionWrapper("hivemall.tools.matrix.TransposeAndDotUDAF"),
+ Seq(X, Y).map(df.col(_).expr),
+ isUDAFBridgeRequired = false)
+ .toAggregateExpression()
+ toDF(Seq(Alias(udaf, udaf.prettyString)()))
+ }
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/a1f8f958/spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
----------------------------------------------------------------------
diff --git a/spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala b/spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
index 133f1d5..5970b83 100644
--- a/spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
+++ b/spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
@@ -1006,6 +1006,15 @@ object HivemallOps {
}
/**
+ * @see hivemall.ftvec.selection.ChiSquareUDF
+ * @group ftvec.selection
+ */
+ def chi2(exprs: Column*): Column = {
+ HiveGenericUDF(new HiveFunctionWrapper(
+ "hivemall.ftvec.selection.ChiSquareUDF"), exprs.map(_.expr))
+ }
+
+ /**
* @see hivemall.ftvec.conv.ToDenseFeaturesUDF
* @group ftvec.conv
*/
@@ -1078,6 +1087,16 @@ object HivemallOps {
}
/**
+ * @see hivemall.tools.array.SelectKBestUDF
+ * @group tools.array
+ */
+ @scala.annotation.varargs
+ def select_k_best(exprs: Column*): Column = {
+ HiveGenericUDF(new HiveFunctionWrapper(
+ "hivemall.tools.array.SelectKBestUDF"), exprs.map(_.expr))
+ }
+
+ /**
* @see hivemall.tools.math.SigmoidUDF
* @group misc
*/
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/a1f8f958/spark/spark-1.6/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
----------------------------------------------------------------------
diff --git a/spark/spark-1.6/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala b/spark/spark-1.6/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
index 4be1e5e..148e5a2 100644
--- a/spark/spark-1.6/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
+++ b/spark/spark-1.6/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
@@ -17,8 +17,6 @@
package org.apache.spark.sql.hive
-import scala.collection.mutable.Seq
-
import org.apache.spark.sql.{Column, Row}
import org.apache.spark.sql.hive.HivemallOps._
import org.apache.spark.sql.hive.HivemallUtils._
@@ -188,6 +186,22 @@ final class HivemallOpsSuite extends HivemallQueryTest {
Row(Seq("1:1.0"))))
}
+ test("ftvec.selection - chi2") {
+ import hiveContext.implicits._
+
+ val df = Seq(Seq(
+ Seq(250.29999999999998, 170.90000000000003, 73.2, 12.199999999999996),
+ Seq(296.8, 138.50000000000003, 212.99999999999997, 66.3),
+ Seq(329.3999999999999, 148.7, 277.59999999999997, 101.29999999999998)) -> Seq(
+ Seq(292.1666753739119, 152.70000455081467, 187.93333893418327, 59.93333511948589),
+ Seq(292.1666753739119, 152.70000455081467, 187.93333893418327, 59.93333511948589),
+ Seq(292.1666753739119, 152.70000455081467, 187.93333893418327, 59.93333511948589))).toDF("arg0", "arg1")
+
+ assert(df.select(chi2(df("arg0"), df("arg1"))).collect.toSet ===
+ Set(Row(Row(Seq(10.817820878493995, 3.5944990176817315, 116.16984746363957, 67.24482558215503),
+ Seq(0.004476514990225833, 0.16575416718561453, 0d, 2.55351295663786e-15)))))
+ }
+
test("ftvec.conv - quantify") {
import hiveContext.implicits._
val testDf = Seq((1, "aaa", true), (2, "bbb", false), (3, "aaa", false)).toDF
@@ -340,6 +354,18 @@ final class HivemallOpsSuite extends HivemallQueryTest {
checkAnswer(predicted, Seq(Row(0), Row(1)))
}
+ test("tools.array - select_k_best") {
+ import hiveContext.implicits._
+
+ val data = Seq(Tuple1(Seq(0, 1, 3)), Tuple1(Seq(2, 4, 1)), Tuple1(Seq(5, 4, 9)))
+ val importance = Seq(3, 1, 2)
+ val k = 2
+ val df = data.toDF("features")
+
+ assert(df.select(select_k_best(df("features"), importance, k)).collect.toSeq ===
+ data.map(s => Row(Seq(s._1(0).toDouble, s._1(2).toDouble))))
+ }
+
test("misc - sigmoid") {
import hiveContext.implicits._
/**
@@ -536,4 +562,37 @@ final class HivemallOpsSuite extends HivemallQueryTest {
val row4 = df4.groupby($"c0").f1score("c1", "c2").collect
assert(row4(0).getDouble(1) ~== 0.25)
}
+
+ test("user-defined aggregators for ftvec.selection") {
+ import hiveContext.implicits._
+
+ // +-----------------+-------+
+ // | features | class |
+ // +-----------------+-------+
+ // | 5.1,3.5,1.4,0.2 | 0 |
+ // | 4.9,3.0,1.4,0.2 | 0 |
+ // | 7.0,3.2,4.7,1.4 | 1 |
+ // | 6.4,3.2,4.5,1.5 | 1 |
+ // | 6.3,3.3,6.0,2.5 | 2 |
+ // | 5.8,2.7,5.1,1.9 | 2 |
+ // +-----------------+-------+
+ val df0 = Seq(
+ (1, Seq(5.1, 3.5, 1.4, 0.2), Seq(1, 0, 0)), (1, Seq(4.9, 3.0, 1.4, 0.2), Seq(1, 0, 0)),
+ (1, Seq(7.0, 3.2, 4.7, 1.4), Seq(0, 1, 0)), (1, Seq(6.4, 3.2, 4.5, 1.5), Seq(0, 1, 0)),
+ (1, Seq(6.3, 3.3, 6.0, 2.5), Seq(0, 0, 1)), (1, Seq(5.8, 2.7, 5.1, 1.9), Seq(0, 0, 1)))
+ .toDF.as("c0", "arg0", "arg1")
+ val row0 = df0.groupby($"c0").snr("arg0", "arg1").collect
+ assert(row0(0).getAs[Seq[Double]](1) ===
+ Seq(8.431818181818192, 1.3212121212121217, 42.94949494949499, 33.80952380952378))
+ }
+
+ test("user-defined aggregators for tools.matrix") {
+ import hiveContext.implicits._
+
+ // | 1 2 3 |T | 5 6 7 |
+ // | 3 4 5 | * | 7 8 9 |
+ val df0 = Seq((1, Seq(1, 2, 3), Seq(5, 6, 7)), (1, Seq(3, 4, 5), Seq(7, 8, 9))).toDF.as("c0", "arg0", "arg1")
+ val row0 = df0.groupby($"c0").transpose_and_dot("arg0", "arg1").collect
+ assert(row0(0).getAs[Seq[Double]](1) === Seq(Seq(26.0, 30.0, 34.0), Seq(38.0, 44.0, 50.0), Seq(50.0, 58.0, 66.0)))
+ }
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/a1f8f958/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
----------------------------------------------------------------------
diff --git a/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala b/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
index 4a583db..e9a1aeb 100644
--- a/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
+++ b/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
@@ -1228,6 +1228,16 @@ object HivemallOps {
}
/**
+ * @see hivemall.ftvec.selection.ChiSquareUDF
+ * @group ftvec.selection
+ */
+ def chi2(exprs: Column*): Column = withExpr {
+ HiveGenericUDF("chi2",
+ new HiveFunctionWrapper("hivemall.ftvec.selection.ChiSquareUDF"),
+ exprs.map(_.expr))
+ }
+
+ /**
* @see hivemall.ftvec.conv.ToDenseFeaturesUDF
* @group ftvec.conv
*/
@@ -1307,6 +1317,16 @@ object HivemallOps {
}
/**
+ * @see hivemall.tools.array.SelectKBestUDF
+ * @group tools.array
+ */
+ def select_k_best(exprs: Column*): Column = withExpr {
+ HiveGenericUDF("select_k_best",
+ new HiveFunctionWrapper("hivemall.tools.array.SelectKBestUDF"),
+ exprs.map(_.expr))
+ }
+
+ /**
* @see hivemall.tools.math.SigmoidUDF
* @group misc
*/
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/a1f8f958/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/RelationalGroupedDatasetEx.scala
----------------------------------------------------------------------
diff --git a/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/RelationalGroupedDatasetEx.scala b/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/RelationalGroupedDatasetEx.scala
index e365197..be0673f 100644
--- a/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/RelationalGroupedDatasetEx.scala
+++ b/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/RelationalGroupedDatasetEx.scala
@@ -274,4 +274,30 @@ final class RelationalGroupedDatasetEx protected[sql](
.toAggregateExpression()
toDF((Alias(udaf, udaf.prettyName)() :: Nil).toSeq)
}
+
+ /**
+ * @see hivemall.ftvec.selection.SignalNoiseRatioUDAF
+ */
+ def snr(X: String, Y: String): DataFrame = {
+ val udaf = HiveUDAFFunction(
+ "snr",
+ new HiveFunctionWrapper("hivemall.ftvec.selection.SignalNoiseRatioUDAF"),
+ Seq(X, Y).map(df.col(_).expr),
+ isUDAFBridgeRequired = false)
+ .toAggregateExpression()
+ toDF(Seq(Alias(udaf, udaf.prettyName)()))
+ }
+
+ /**
+ * @see hivemall.tools.matrix.TransposeAndDotUDAF
+ */
+ def transpose_and_dot(X: String, Y: String): DataFrame = {
+ val udaf = HiveUDAFFunction(
+ "transpose_and_dot",
+ new HiveFunctionWrapper("hivemall.tools.matrix.TransposeAndDotUDAF"),
+ Seq(X, Y).map(df.col(_).expr),
+ isUDAFBridgeRequired = false)
+ .toAggregateExpression()
+ toDF(Seq(Alias(udaf, udaf.prettyName)()))
+ }
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/a1f8f958/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
----------------------------------------------------------------------
diff --git a/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala b/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
index 99cb1a7..039a492 100644
--- a/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
+++ b/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
@@ -17,9 +17,7 @@
package org.apache.spark.sql.hive
-import scala.collection.mutable.Seq
-
-import org.apache.spark.sql.{Column, Row}
+import org.apache.spark.sql.{AnalysisException, Column, Row}
import org.apache.spark.sql.functions
import org.apache.spark.sql.hive.HivemallOps._
import org.apache.spark.sql.hive.HivemallUtils._
@@ -189,6 +187,22 @@ final class HivemallOpsWithFeatureSuite extends HivemallFeatureQueryTest {
Row(Seq("1:1.0"))))
}
+ test("ftvec.selection - chi2") {
+ import hiveContext.implicits._
+
+ val df = Seq(Seq(
+ Seq(250.29999999999998, 170.90000000000003, 73.2, 12.199999999999996),
+ Seq(296.8, 138.50000000000003, 212.99999999999997, 66.3),
+ Seq(329.3999999999999, 148.7, 277.59999999999997, 101.29999999999998)) -> Seq(
+ Seq(292.1666753739119, 152.70000455081467, 187.93333893418327, 59.93333511948589),
+ Seq(292.1666753739119, 152.70000455081467, 187.93333893418327, 59.93333511948589),
+ Seq(292.1666753739119, 152.70000455081467, 187.93333893418327, 59.93333511948589))).toDF("arg0", "arg1")
+
+ assert(df.select(chi2(df("arg0"), df("arg1"))).collect.toSet ===
+ Set(Row(Row(Seq(10.817820878493995, 3.5944990176817315, 116.16984746363957, 67.24482558215503),
+ Seq(0.004476514990225833, 0.16575416718561453, 0d, 2.55351295663786e-15)))))
+ }
+
test("ftvec.conv - quantify") {
import hiveContext.implicits._
val testDf = Seq((1, "aaa", true), (2, "bbb", false), (3, "aaa", false)).toDF
@@ -342,6 +356,18 @@ final class HivemallOpsWithFeatureSuite extends HivemallFeatureQueryTest {
checkAnswer(predicted, Seq(Row(0), Row(1)))
}
+ test("tools.array - select_k_best") {
+ import hiveContext.implicits._
+
+ val data = Seq(Seq(0, 1, 3), Seq(2, 4, 1), Seq(5, 4, 9))
+ val importance = Seq(3, 1, 2)
+ val k = 2
+ val df = data.toDF("features")
+
+ assert(df.select(select_k_best(df("features"), importance, k)).collect.toSeq ===
+ data.map(s => Row(Seq(s(0).toDouble, s(2).toDouble))))
+ }
+
test("misc - sigmoid") {
import hiveContext.implicits._
/**
@@ -631,6 +657,39 @@ final class HivemallOpsWithFeatureSuite extends HivemallFeatureQueryTest {
val row4 = df4.groupby($"c0").f1score("c1", "c2").collect
assert(row4(0).getDouble(1) ~== 0.25)
}
+
+ test("user-defined aggregators for ftvec.selection") {
+ import hiveContext.implicits._
+
+ // +-----------------+-------+
+ // | features | class |
+ // +-----------------+-------+
+ // | 5.1,3.5,1.4,0.2 | 0 |
+ // | 4.9,3.0,1.4,0.2 | 0 |
+ // | 7.0,3.2,4.7,1.4 | 1 |
+ // | 6.4,3.2,4.5,1.5 | 1 |
+ // | 6.3,3.3,6.0,2.5 | 2 |
+ // | 5.8,2.7,5.1,1.9 | 2 |
+ // +-----------------+-------+
+ val df0 = Seq(
+ (1, Seq(5.1, 3.5, 1.4, 0.2), Seq(1, 0, 0)), (1, Seq(4.9, 3.0, 1.4, 0.2), Seq(1, 0, 0)),
+ (1, Seq(7.0, 3.2, 4.7, 1.4), Seq(0, 1, 0)), (1, Seq(6.4, 3.2, 4.5, 1.5), Seq(0, 1, 0)),
+ (1, Seq(6.3, 3.3, 6.0, 2.5), Seq(0, 0, 1)), (1, Seq(5.8, 2.7, 5.1, 1.9), Seq(0, 0, 1)))
+ .toDF.as("c0", "arg0", "arg1")
+ val row0 = df0.groupby($"c0").snr("arg0", "arg1").collect
+ assert(row0(0).getAs[Seq[Double]](1) ===
+ Seq(8.431818181818192, 1.3212121212121217, 42.94949494949499, 33.80952380952378))
+ }
+
+ test("user-defined aggregators for tools.matrix") {
+ import hiveContext.implicits._
+
+ // | 1 2 3 |T | 5 6 7 |
+ // | 3 4 5 | * | 7 8 9 |
+ val df0 = Seq((1, Seq(1, 2, 3), Seq(5, 6, 7)), (1, Seq(3, 4, 5), Seq(7, 8, 9))).toDF.as("c0", "arg0", "arg1")
+ val row0 = df0.groupby($"c0").transpose_and_dot("arg0", "arg1").collect
+ assert(row0(0).getAs[Seq[Double]](1) === Seq(Seq(26.0, 30.0, 34.0), Seq(38.0, 44.0, 50.0), Seq(50.0, 58.0, 66.0)))
+ }
}
final class HivemallOpsWithVectorSuite extends VectorQueryTest {
[50/50] [abbrv] incubator-hivemall git commit: Merge branch
'AddOptimizers' of https://github.com/maropu/hivemall into JIRA-22/pr-285
Posted by my...@apache.org.
Merge branch 'AddOptimizers' of https://github.com/maropu/hivemall into JIRA-22/pr-285
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/9ca8bce7
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/9ca8bce7
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/9ca8bce7
Branch: refs/heads/JIRA-22/pr-285
Commit: 9ca8bce75333a3ebb43d2492bf5df815ccf869dc
Parents: 72d6a62 3620eb8
Author: myui <yu...@gmail.com>
Authored: Fri Dec 2 16:56:40 2016 +0900
Committer: myui <yu...@gmail.com>
Committed: Fri Dec 2 16:56:40 2016 +0900
----------------------------------------------------------------------
.../src/main/java/hivemall/LearnerBaseUDTF.java | 55 +++
.../hivemall/classifier/AROWClassifierUDTF.java | 2 +-
.../hivemall/classifier/AdaGradRDAUDTF.java | 6 +-
.../classifier/BinaryOnlineClassifierUDTF.java | 13 +
.../classifier/GeneralClassifierUDTF.java | 122 +++++
.../classifier/PassiveAggressiveUDTF.java | 2 +-
.../main/java/hivemall/common/EtaEstimator.java | 160 -------
.../java/hivemall/common/LossFunctions.java | 467 -------------------
.../java/hivemall/fm/FMHyperParameters.java | 2 +-
.../hivemall/fm/FactorizationMachineModel.java | 2 +-
.../hivemall/fm/FactorizationMachineUDTF.java | 8 +-
.../fm/FieldAwareFactorizationMachineModel.java | 1 +
.../hivemall/mf/BPRMatrixFactorizationUDTF.java | 2 +-
.../hivemall/mf/MatrixFactorizationSGDUDTF.java | 2 +-
.../main/java/hivemall/model/DenseModel.java | 5 +
.../main/java/hivemall/model/IWeightValue.java | 16 +-
.../main/java/hivemall/model/NewDenseModel.java | 293 ++++++++++++
.../model/NewSpaceEfficientDenseModel.java | 317 +++++++++++++
.../java/hivemall/model/NewSparseModel.java | 197 ++++++++
.../java/hivemall/model/PredictionModel.java | 2 +
.../model/SpaceEfficientDenseModel.java | 5 +
.../main/java/hivemall/model/SparseModel.java | 5 +
.../model/SynchronizedModelWrapper.java | 10 +
.../main/java/hivemall/model/WeightValue.java | 162 ++++++-
.../hivemall/model/WeightValueWithClock.java | 167 ++++++-
.../optimizer/DenseOptimizerFactory.java | 215 +++++++++
.../java/hivemall/optimizer/EtaEstimator.java | 191 ++++++++
.../java/hivemall/optimizer/LossFunctions.java | 467 +++++++++++++++++++
.../main/java/hivemall/optimizer/Optimizer.java | 246 ++++++++++
.../java/hivemall/optimizer/Regularization.java | 99 ++++
.../optimizer/SparseOptimizerFactory.java | 171 +++++++
.../hivemall/regression/AROWRegressionUDTF.java | 2 +-
.../java/hivemall/regression/AdaDeltaUDTF.java | 5 +-
.../java/hivemall/regression/AdaGradUDTF.java | 5 +-
.../regression/GeneralRegressionUDTF.java | 126 +++++
.../java/hivemall/regression/LogressUDTF.java | 10 +-
.../PassiveAggressiveRegressionUDTF.java | 2 +-
.../hivemall/regression/RegressionBaseUDTF.java | 26 +-
.../NewSpaceEfficientNewDenseModelTest.java | 60 +++
.../model/SpaceEfficientDenseModelTest.java | 60 ---
.../java/hivemall/optimizer/OptimizerTest.java | 172 +++++++
.../java/hivemall/mix/server/MixServerTest.java | 18 +-
resources/ddl/define-all-as-permanent.hive | 13 +-
resources/ddl/define-all.hive | 12 +-
.../hivemall/mix/server/MixServerSuite.scala | 6 +-
45 files changed, 3195 insertions(+), 734 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9ca8bce7/core/src/main/java/hivemall/LearnerBaseUDTF.java
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9ca8bce7/core/src/main/java/hivemall/classifier/AROWClassifierUDTF.java
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9ca8bce7/core/src/main/java/hivemall/classifier/AdaGradRDAUDTF.java
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9ca8bce7/core/src/main/java/hivemall/classifier/BinaryOnlineClassifierUDTF.java
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9ca8bce7/core/src/main/java/hivemall/classifier/PassiveAggressiveUDTF.java
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9ca8bce7/core/src/main/java/hivemall/fm/FMHyperParameters.java
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9ca8bce7/core/src/main/java/hivemall/fm/FactorizationMachineModel.java
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9ca8bce7/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9ca8bce7/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineModel.java
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9ca8bce7/core/src/main/java/hivemall/mf/BPRMatrixFactorizationUDTF.java
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9ca8bce7/core/src/main/java/hivemall/mf/MatrixFactorizationSGDUDTF.java
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9ca8bce7/core/src/main/java/hivemall/model/DenseModel.java
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9ca8bce7/core/src/main/java/hivemall/model/IWeightValue.java
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9ca8bce7/core/src/main/java/hivemall/model/PredictionModel.java
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9ca8bce7/core/src/main/java/hivemall/model/SpaceEfficientDenseModel.java
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9ca8bce7/core/src/main/java/hivemall/model/SparseModel.java
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9ca8bce7/core/src/main/java/hivemall/model/SynchronizedModelWrapper.java
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9ca8bce7/core/src/main/java/hivemall/model/WeightValue.java
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9ca8bce7/core/src/main/java/hivemall/model/WeightValueWithClock.java
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9ca8bce7/core/src/main/java/hivemall/optimizer/EtaEstimator.java
----------------------------------------------------------------------
diff --cc core/src/main/java/hivemall/optimizer/EtaEstimator.java
index 0000000,ac1d112..a17c349
mode 000000,100644..100644
--- a/core/src/main/java/hivemall/optimizer/EtaEstimator.java
+++ b/core/src/main/java/hivemall/optimizer/EtaEstimator.java
@@@ -1,0 -1,191 +1,191 @@@
+ /*
- * Hivemall: Hive scalable Machine Learning Library
++ * Licensed to the Apache Software Foundation (ASF) under one
++ * or more contributor license agreements. See the NOTICE file
++ * distributed with this work for additional information
++ * regarding copyright ownership. The ASF licenses this file
++ * to you under the Apache License, Version 2.0 (the
++ * "License"); you may not use this file except in compliance
++ * with the License. You may obtain a copy of the License at
+ *
- * Copyright (C) 2015 Makoto YUI
- * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
++ * http://www.apache.org/licenses/LICENSE-2.0
+ *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
++ * Unless required by applicable law or agreed to in writing,
++ * software distributed under the License is distributed on an
++ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
++ * KIND, either express or implied. See the License for the
++ * specific language governing permissions and limitations
++ * under the License.
+ */
+ package hivemall.optimizer;
+
+ import hivemall.utils.lang.NumberUtils;
+ import hivemall.utils.lang.Primitives;
+
+ import java.util.Map;
+ import javax.annotation.Nonnegative;
+ import javax.annotation.Nonnull;
+ import javax.annotation.Nullable;
+
+ import org.apache.commons.cli.CommandLine;
+ import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+
+ public abstract class EtaEstimator {
+
+ protected final float eta0;
+
+ public EtaEstimator(float eta0) {
+ this.eta0 = eta0;
+ }
+
+ public float eta0() {
+ return eta0;
+ }
+
+ public abstract float eta(long t);
+
+ public void update(@Nonnegative float multipler) {}
+
+ public static final class FixedEtaEstimator extends EtaEstimator {
+
+ public FixedEtaEstimator(float eta) {
+ super(eta);
+ }
+
+ @Override
+ public float eta(long t) {
+ return eta0;
+ }
+
+ }
+
+ public static final class SimpleEtaEstimator extends EtaEstimator {
+
+ private final float finalEta;
+ private final double total_steps;
+
+ public SimpleEtaEstimator(float eta0, long total_steps) {
+ super(eta0);
+ this.finalEta = (float) (eta0 / 2.d);
+ this.total_steps = total_steps;
+ }
+
+ @Override
+ public float eta(final long t) {
+ if (t > total_steps) {
+ return finalEta;
+ }
+ return (float) (eta0 / (1.d + (t / total_steps)));
+ }
+
+ }
+
+ public static final class InvscalingEtaEstimator extends EtaEstimator {
+
+ private final double power_t;
+
+ public InvscalingEtaEstimator(float eta0, double power_t) {
+ super(eta0);
+ this.power_t = power_t;
+ }
+
+ @Override
+ public float eta(final long t) {
+ return (float) (eta0 / Math.pow(t, power_t));
+ }
+
+ }
+
+ /**
+ * bold driver: Gemulla et al., Large-scale matrix factorization with distributed stochastic
+ * gradient descent, KDD 2011.
+ */
+ public static final class AdjustingEtaEstimator extends EtaEstimator {
+
+ private float eta;
+
+ public AdjustingEtaEstimator(float eta) {
+ super(eta);
+ this.eta = eta;
+ }
+
+ @Override
+ public float eta(long t) {
+ return eta;
+ }
+
+ @Override
+ public void update(@Nonnegative float multipler) {
+ float newEta = eta * multipler;
+ if (!NumberUtils.isFinite(newEta)) {
+ // avoid NaN or INFINITY
+ return;
+ }
+ this.eta = Math.min(eta0, newEta); // never be larger than eta0
+ }
+
+ }
+
+ @Nonnull
+ public static EtaEstimator get(@Nullable CommandLine cl) throws UDFArgumentException {
+ return get(cl, 0.1f);
+ }
+
+ @Nonnull
+ public static EtaEstimator get(@Nullable CommandLine cl, float defaultEta0)
+ throws UDFArgumentException {
+ if (cl == null) {
+ return new InvscalingEtaEstimator(defaultEta0, 0.1d);
+ }
+
+ if (cl.hasOption("boldDriver")) {
+ float eta = Primitives.parseFloat(cl.getOptionValue("eta"), 0.3f);
+ return new AdjustingEtaEstimator(eta);
+ }
+
+ String etaValue = cl.getOptionValue("eta");
+ if (etaValue != null) {
+ float eta = Float.parseFloat(etaValue);
+ return new FixedEtaEstimator(eta);
+ }
+
+ float eta0 = Primitives.parseFloat(cl.getOptionValue("eta0"), defaultEta0);
+ if (cl.hasOption("t")) {
+ long t = Long.parseLong(cl.getOptionValue("t"));
+ return new SimpleEtaEstimator(eta0, t);
+ }
+
+ double power_t = Primitives.parseDouble(cl.getOptionValue("power_t"), 0.1d);
+ return new InvscalingEtaEstimator(eta0, power_t);
+ }
+
+ @Nonnull
+ public static EtaEstimator get(@Nonnull final Map<String, String> options)
+ throws IllegalArgumentException {
+ final String etaName = options.get("eta");
+ if(etaName == null) {
+ return new FixedEtaEstimator(1.f);
+ }
+ float eta0 = 0.1f;
+ if(options.containsKey("eta0")) {
+ eta0 = Float.parseFloat(options.get("eta0"));
+ }
+ if(etaName.toLowerCase().equals("fixed")) {
+ return new FixedEtaEstimator(eta0);
+ } else if(etaName.toLowerCase().equals("simple")) {
+ long t = 10000;
+ if(options.containsKey("t")) {
+ t = Long.parseLong(options.get("t"));
+ }
+ return new SimpleEtaEstimator(eta0, t);
+ } else if(etaName.toLowerCase().equals("inverse")) {
+ double power_t = 0.1;
+ if(options.containsKey("power_t")) {
+ power_t = Double.parseDouble(options.get("power_t"));
+ }
+ return new InvscalingEtaEstimator(eta0, power_t);
+ } else {
+ throw new IllegalArgumentException("Unsupported ETA name: " + etaName);
+ }
+ }
+
+ }
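The Map-based factory added here selects the learning-rate schedule by name. A rough Scala sketch of exercising it (option keys follow the get(Map) overload above; with "inverse", the rate decays as eta0 / t^power_t):

  import hivemall.optimizer.EtaEstimator
  import scala.collection.JavaConverters._

  val options = Map("eta" -> "inverse", "eta0" -> "0.1", "power_t" -> "0.1").asJava
  val eta: EtaEstimator = EtaEstimator.get(options)

  // InvscalingEtaEstimator: eta(t) = eta0 / t^power_t
  (1L to 5L).foreach(t => println(s"t=$t eta=${eta.eta(t)}"))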
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9ca8bce7/core/src/main/java/hivemall/optimizer/LossFunctions.java
----------------------------------------------------------------------
diff --cc core/src/main/java/hivemall/optimizer/LossFunctions.java
index 0000000,d11be9b..07f7cb8
mode 000000,100644..100644
--- a/core/src/main/java/hivemall/optimizer/LossFunctions.java
+++ b/core/src/main/java/hivemall/optimizer/LossFunctions.java
@@@ -1,0 -1,467 +1,467 @@@
+ /*
- * Hivemall: Hive scalable Machine Learning Library
++ * Licensed to the Apache Software Foundation (ASF) under one
++ * or more contributor license agreements. See the NOTICE file
++ * distributed with this work for additional information
++ * regarding copyright ownership. The ASF licenses this file
++ * to you under the Apache License, Version 2.0 (the
++ * "License"); you may not use this file except in compliance
++ * with the License. You may obtain a copy of the License at
+ *
- * Copyright (C) 2015 Makoto YUI
- * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
++ * http://www.apache.org/licenses/LICENSE-2.0
+ *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
++ * Unless required by applicable law or agreed to in writing,
++ * software distributed under the License is distributed on an
++ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
++ * KIND, either express or implied. See the License for the
++ * specific language governing permissions and limitations
++ * under the License.
+ */
+ package hivemall.optimizer;
+
+ import hivemall.utils.math.MathUtils;
+
+ /**
+ * @link https://github.com/JohnLangford/vowpal_wabbit/wiki/Loss-functions
+ */
+ public final class LossFunctions {
+
+ public enum LossType {
+ SquaredLoss, LogLoss, HingeLoss, SquaredHingeLoss, QuantileLoss, EpsilonInsensitiveLoss
+ }
+
+ public static LossFunction getLossFunction(String type) {
+ if ("SquaredLoss".equalsIgnoreCase(type)) {
+ return new SquaredLoss();
+ } else if ("LogLoss".equalsIgnoreCase(type)) {
+ return new LogLoss();
+ } else if ("HingeLoss".equalsIgnoreCase(type)) {
+ return new HingeLoss();
+ } else if ("SquaredHingeLoss".equalsIgnoreCase(type)) {
+ return new SquaredHingeLoss();
+ } else if ("QuantileLoss".equalsIgnoreCase(type)) {
+ return new QuantileLoss();
+ } else if ("EpsilonInsensitiveLoss".equalsIgnoreCase(type)) {
+ return new EpsilonInsensitiveLoss();
+ }
+ throw new IllegalArgumentException("Unsupported type: " + type);
+ }
+
+ public static LossFunction getLossFunction(LossType type) {
+ switch (type) {
+ case SquaredLoss:
+ return new SquaredLoss();
+ case LogLoss:
+ return new LogLoss();
+ case HingeLoss:
+ return new HingeLoss();
+ case SquaredHingeLoss:
+ return new SquaredHingeLoss();
+ case QuantileLoss:
+ return new QuantileLoss();
+ case EpsilonInsensitiveLoss:
+ return new EpsilonInsensitiveLoss();
+ default:
+ throw new IllegalArgumentException("Unsupported type: " + type);
+ }
+ }
+
+ public interface LossFunction {
+
+ /**
+ * Evaluate the loss function.
+ *
+ * @param p The prediction, p = w^T x
+ * @param y The true value (aka target)
+ * @return The loss evaluated at `p` and `y`.
+ */
+ public float loss(float p, float y);
+
+ public double loss(double p, double y);
+
+ /**
+ * Evaluate the derivative of the loss function with respect to the prediction `p`.
+ *
+ * @param p The prediction, p = w^T x
+ * @param y The true value (aka target)
+ * @return The derivative of the loss function w.r.t. `p`.
+ */
+ public float dloss(float p, float y);
+
+ public boolean forBinaryClassification();
+
+ public boolean forRegression();
+
+ }
+
+ public static abstract class BinaryLoss implements LossFunction {
+
+ protected static void checkTarget(float y) {
+ if (!(y == 1.f || y == -1.f)) {
+ throw new IllegalArgumentException("target must be [+1,-1]: " + y);
+ }
+ }
+
+ protected static void checkTarget(double y) {
+ if (!(y == 1.d || y == -1.d)) {
+ throw new IllegalArgumentException("target must be [+1,-1]: " + y);
+ }
+ }
+
+ @Override
+ public boolean forBinaryClassification() {
+ return true;
+ }
+
+ @Override
+ public boolean forRegression() {
+ return false;
+ }
+ }
+
+ public static abstract class RegressionLoss implements LossFunction {
+
+ @Override
+ public boolean forBinaryClassification() {
+ return false;
+ }
+
+ @Override
+ public boolean forRegression() {
+ return true;
+ }
+
+ }
+
+ /**
+ * Squared loss for regression problems.
+ *
+ * If you're trying to minimize the mean error, use squared-loss.
+ */
+ public static final class SquaredLoss extends RegressionLoss {
+
+ @Override
+ public float loss(float p, float y) {
+ final float z = p - y;
+ return z * z * 0.5f;
+ }
+
+ @Override
+ public double loss(double p, double y) {
+ final double z = p - y;
+ return z * z * 0.5d;
+ }
+
+ @Override
+ public float dloss(float p, float y) {
+ return p - y; // 2 (p - y) / 2
+ }
+ }
+
+ /**
+ * Logistic regression loss for binary classification with y in {-1, 1}.
+ */
+ public static final class LogLoss extends BinaryLoss {
+
+ /**
+ * <code>logloss(p,y) = log(1+exp(-p*y))</code>
+ */
+ @Override
+ public float loss(float p, float y) {
+ checkTarget(y);
+
+ final float z = y * p;
+ if (z > 18.f) {
+ return (float) Math.exp(-z);
+ }
+ if (z < -18.f) {
+ return -z;
+ }
+ return (float) Math.log(1.d + Math.exp(-z));
+ }
+
+ @Override
+ public double loss(double p, double y) {
+ checkTarget(y);
+
+ final double z = y * p;
+ if (z > 18.d) {
+ return Math.exp(-z);
+ }
+ if (z < -18.d) {
+ return -z;
+ }
+ return Math.log(1.d + Math.exp(-z));
+ }
+
+ @Override
+ public float dloss(float p, float y) {
+ checkTarget(y);
+
+ float z = y * p;
+ if (z > 18.f) {
+ return (float) Math.exp(-z) * -y;
+ }
+ if (z < -18.f) {
+ return -y;
+ }
+ return -y / ((float) Math.exp(z) + 1.f);
+ }
+ }
+
+ /**
+ * Hinge loss for binary classification tasks with y in {-1,1}.
+ */
+ public static final class HingeLoss extends BinaryLoss {
+
+ private float threshold;
+
+ public HingeLoss() {
+ this(1.f);
+ }
+
+ /**
+ * @param threshold Margin threshold. When threshold=1.0, one gets the loss used by SVM.
+ * When threshold=0.0, one gets the loss used by the Perceptron.
+ */
+ public HingeLoss(float threshold) {
+ this.threshold = threshold;
+ }
+
+ public void setThreshold(float threshold) {
+ this.threshold = threshold;
+ }
+
+ @Override
+ public float loss(float p, float y) {
+ float loss = hingeLoss(p, y, threshold);
+ return (loss > 0.f) ? loss : 0.f;
+ }
+
+ @Override
+ public double loss(double p, double y) {
+ double loss = hingeLoss(p, y, threshold);
+ return (loss > 0.d) ? loss : 0.d;
+ }
+
+ @Override
+ public float dloss(float p, float y) {
+ float loss = hingeLoss(p, y, threshold);
+ return (loss > 0.f) ? -y : 0.f;
+ }
+ }
+
+ /**
+ * Squared Hinge loss for binary classification tasks with y in {-1,1}.
+ */
+ public static final class SquaredHingeLoss extends BinaryLoss {
+
+ @Override
+ public float loss(float p, float y) {
+ return squaredHingeLoss(p, y);
+ }
+
+ @Override
+ public double loss(double p, double y) {
+ return squaredHingeLoss(p, y);
+ }
+
+ @Override
+ public float dloss(float p, float y) {
+ checkTarget(y);
+
+ float d = 1 - (y * p);
+ return (d > 0.f) ? -2.f * d * y : 0.f;
+ }
+
+ }
+
+ /**
+ * Quantile loss is useful for predicting rank/order when you do not mind the mean error increasing
+ * as long as the relative order is correct.
+ *
+ * @link http://en.wikipedia.org/wiki/Quantile_regression
+ */
+ public static final class QuantileLoss extends RegressionLoss {
+
+ private float tau;
+
+ public QuantileLoss() {
+ this.tau = 0.5f;
+ }
+
+ public QuantileLoss(float tau) {
+ setTau(tau);
+ }
+
+ public void setTau(float tau) {
+ if (tau <= 0 || tau >= 1.0) {
+ throw new IllegalArgumentException("tau must be in range (0, 1): " + tau);
+ }
+ this.tau = tau;
+ }
+
+ @Override
+ public float loss(float p, float y) {
+ float e = y - p;
+ if (e > 0.f) {
+ return tau * e;
+ } else {
+ return -(1.f - tau) * e;
+ }
+ }
+
+ @Override
+ public double loss(double p, double y) {
+ double e = y - p;
+ if (e > 0.d) {
+ return tau * e;
+ } else {
+ return -(1.d - tau) * e;
+ }
+ }
+
+ @Override
+ public float dloss(float p, float y) {
+ float e = y - p;
+ if (e == 0.f) {
+ return 0.f;
+ }
+ return (e > 0.f) ? -tau : (1.f - tau);
+ }
+
+ }
+
+ /**
+ * Epsilon-Insensitive loss used by Support Vector Regression (SVR).
+ * <code>loss = max(0, |y - p| - epsilon)</code>
+ */
+ public static final class EpsilonInsensitiveLoss extends RegressionLoss {
+
+ private float epsilon;
+
+ public EpsilonInsensitiveLoss() {
+ this(0.1f);
+ }
+
+ public EpsilonInsensitiveLoss(float epsilon) {
+ this.epsilon = epsilon;
+ }
+
+ public void setEpsilon(float epsilon) {
+ this.epsilon = epsilon;
+ }
+
+ @Override
+ public float loss(float p, float y) {
+ float loss = Math.abs(y - p) - epsilon;
+ return (loss > 0.f) ? loss : 0.f;
+ }
+
+ @Override
+ public double loss(double p, double y) {
+ double loss = Math.abs(y - p) - epsilon;
+ return (loss > 0.d) ? loss : 0.d;
+ }
+
+ @Override
+ public float dloss(float p, float y) {
+ if ((y - p) > epsilon) {// real value > predicted value + epsilon
+ return -1.f;
+ }
+ if ((p - y) > epsilon) {// real value < predicted value - epsilon
+ return 1.f;
+ }
+ return 0.f;
+ }
+
+ }
+
+ public static float logisticLoss(final float target, final float predicted) {
+ if (predicted > -100.d) {
+ return target - (float) MathUtils.sigmoid(predicted);
+ } else {
+ return target;
+ }
+ }
+
+ public static float logLoss(final float p, final float y) {
+ BinaryLoss.checkTarget(y);
+
+ final float z = y * p;
+ if (z > 18.f) {
+ return (float) Math.exp(-z);
+ }
+ if (z < -18.f) {
+ return -z;
+ }
+ return (float) Math.log(1.d + Math.exp(-z));
+ }
+
+ public static double logLoss(final double p, final double y) {
+ BinaryLoss.checkTarget(y);
+
+ final double z = y * p;
+ if (z > 18.d) {
+ return Math.exp(-z);
+ }
+ if (z < -18.d) {
+ return -z;
+ }
+ return Math.log(1.d + Math.exp(-z));
+ }
+
+ public static float squaredLoss(float p, float y) {
+ final float z = p - y;
+ return z * z * 0.5f;
+ }
+
+ public static double squaredLoss(double p, double y) {
+ final double z = p - y;
+ return z * z * 0.5d;
+ }
+
+ public static float hingeLoss(final float p, final float y, final float threshold) {
+ BinaryLoss.checkTarget(y);
+
+ float z = y * p;
+ return threshold - z;
+ }
+
+ public static double hingeLoss(final double p, final double y, final double threshold) {
+ BinaryLoss.checkTarget(y);
+
+ double z = y * p;
+ return threshold - z;
+ }
+
+ public static float hingeLoss(float p, float y) {
+ return hingeLoss(p, y, 1.f);
+ }
+
+ public static double hingeLoss(double p, double y) {
+ return hingeLoss(p, y, 1.d);
+ }
+
+ public static float squaredHingeLoss(final float p, final float y) {
+ BinaryLoss.checkTarget(y);
+
+ float z = y * p;
+ float d = 1.f - z;
+ return (d > 0.f) ? (d * d) : 0.f;
+ }
+
+ public static double squaredHingeLoss(final double p, final double y) {
+ BinaryLoss.checkTarget(y);
+
+ double z = y * p;
+ double d = 1.d - z;
+ return (d > 0.d) ? d * d : 0.d;
+ }
+
+ /**
+ * Math.abs(target - predicted) - epsilon
+ */
+ public static float epsilonInsensitiveLoss(float predicted, float target, float epsilon) {
+ return Math.abs(target - predicted) - epsilon;
+ }
+ }
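The relocated LossFunctions keeps the lookup-by-name API. A small Scala sketch of evaluating a loss and its derivative (values illustrative):

  import hivemall.optimizer.LossFunctions

  val logloss = LossFunctions.getLossFunction("LogLoss")
  val p = 0.8f  // prediction, p = w^T x
  val y = 1.0f  // target in {-1, +1}
  println(s"loss=${logloss.loss(p, y)}, dloss=${logloss.dloss(p, y)}")

  // the static helpers are also available directly
  println(LossFunctions.hingeLoss(p, y))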
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9ca8bce7/core/src/main/java/hivemall/regression/AROWRegressionUDTF.java
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9ca8bce7/core/src/main/java/hivemall/regression/AdaDeltaUDTF.java
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9ca8bce7/core/src/main/java/hivemall/regression/AdaGradUDTF.java
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9ca8bce7/core/src/main/java/hivemall/regression/LogressUDTF.java
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9ca8bce7/core/src/main/java/hivemall/regression/PassiveAggressiveRegressionUDTF.java
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9ca8bce7/core/src/main/java/hivemall/regression/RegressionBaseUDTF.java
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9ca8bce7/core/src/test/java/hivemall/model/NewSpaceEfficientNewDenseModelTest.java
----------------------------------------------------------------------
diff --cc core/src/test/java/hivemall/model/NewSpaceEfficientNewDenseModelTest.java
index 0000000,dd9c4ec..c892071
mode 000000,100644..100644
--- a/core/src/test/java/hivemall/model/NewSpaceEfficientNewDenseModelTest.java
+++ b/core/src/test/java/hivemall/model/NewSpaceEfficientNewDenseModelTest.java
@@@ -1,0 -1,60 +1,60 @@@
+ /*
- * Hivemall: Hive scalable Machine Learning Library
++ * Licensed to the Apache Software Foundation (ASF) under one
++ * or more contributor license agreements. See the NOTICE file
++ * distributed with this work for additional information
++ * regarding copyright ownership. The ASF licenses this file
++ * to you under the Apache License, Version 2.0 (the
++ * "License"); you may not use this file except in compliance
++ * with the License. You may obtain a copy of the License at
+ *
- * Copyright (C) 2015 Makoto YUI
- * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
++ * http://www.apache.org/licenses/LICENSE-2.0
+ *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
++ * Unless required by applicable law or agreed to in writing,
++ * software distributed under the License is distributed on an
++ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
++ * KIND, either express or implied. See the License for the
++ * specific language governing permissions and limitations
++ * under the License.
+ */
+ package hivemall.model;
+
+ import static org.junit.Assert.assertEquals;
+ import hivemall.utils.collections.IMapIterator;
+ import hivemall.utils.lang.HalfFloat;
+
+ import java.util.Random;
+
+ import org.junit.Test;
+
+ public class NewSpaceEfficientNewDenseModelTest {
+
+ @Test
+ public void testGetSet() {
+ final int size = 1 << 12;
+
+ final NewSpaceEfficientDenseModel model1 = new NewSpaceEfficientDenseModel(size);
+ //model1.configureClock();
+ final NewDenseModel model2 = new NewDenseModel(size);
+ //model2.configureClock();
+
+ final Random rand = new Random();
+ for (int t = 0; t < 1000; t++) {
+ int i = rand.nextInt(size);
+ float f = HalfFloat.MAX_FLOAT * rand.nextFloat();
+ IWeightValue w = new WeightValue(f);
+ model1.set(i, w);
+ model2.set(i, w);
+ }
+
+ assertEquals(model2.size(), model1.size());
+
+ IMapIterator<Integer, IWeightValue> itor = model1.entries();
+ while (itor.next() != -1) {
+ int k = itor.getKey();
+ float expected = itor.getValue().get();
+ float actual = model2.getWeight(k);
+ assertEquals(expected, actual, 32f);
+ }
+ }
+
+ }
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9ca8bce7/mixserv/src/test/java/hivemall/mix/server/MixServerTest.java
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9ca8bce7/resources/ddl/define-all-as-permanent.hive
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9ca8bce7/resources/ddl/define-all.hive
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9ca8bce7/spark/spark-2.0/src/test/scala/hivemall/mix/server/MixServerSuite.scala
----------------------------------------------------------------------
diff --cc spark/spark-2.0/src/test/scala/hivemall/mix/server/MixServerSuite.scala
index dbb818b,c0ee72f..3d53bec
--- a/spark/spark-2.0/src/test/scala/hivemall/mix/server/MixServerSuite.scala
+++ b/spark/spark-2.0/src/test/scala/hivemall/mix/server/MixServerSuite.scala
@@@ -19,13 -17,16 +19,15 @@@
package hivemall.mix.server
import java.util.Random
-import java.util.concurrent.{TimeUnit, ExecutorService, Executors}
+import java.util.concurrent.{Executors, ExecutorService, TimeUnit}
import java.util.logging.Logger
+ import org.scalatest.{BeforeAndAfter, FunSuite}
+
-import hivemall.model.{NewDenseModel, PredictionModel, WeightValue}
import hivemall.mix.MixMessage.MixEventName
import hivemall.mix.client.MixClient
import hivemall.mix.server.MixServer.ServerState
- import hivemall.model.{DenseModel, PredictionModel, WeightValue}
++import hivemall.model.{NewDenseModel, PredictionModel, WeightValue}
import hivemall.utils.io.IOUtils
import hivemall.utils.lang.CommandLineUtils
import hivemall.utils.net.NetUtils
[20/50] [abbrv] incubator-hivemall git commit: add tests
Posted by my...@apache.org.
add tests
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/5088ef36
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/5088ef36
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/5088ef36
Branch: refs/heads/JIRA-22/pr-385
Commit: 5088ef36367df1cd51ae62f1c044933676975e2e
Parents: a882c5f
Author: amaya <gi...@sapphire.in.net>
Authored: Wed Sep 21 16:22:09 2016 +0900
Committer: amaya <gi...@sapphire.in.net>
Committed: Wed Sep 21 18:00:35 2016 +0900
----------------------------------------------------------------------
.../tools/matrix/TransposeAndDotUDAF.java | 2 +-
.../ftvec/selection/ChiSquareUDFTest.java | 80 ++++++++++++++++++++
.../tools/array/SelectKBeatUDFTest.java | 65 ++++++++++++++++
.../tools/matrix/TransposeAndDotUDAFTest.java | 58 ++++++++++++++
4 files changed, 204 insertions(+), 1 deletion(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/5088ef36/core/src/main/java/hivemall/tools/matrix/TransposeAndDotUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/tools/matrix/TransposeAndDotUDAF.java b/core/src/main/java/hivemall/tools/matrix/TransposeAndDotUDAF.java
index 9d68f93..9df9305 100644
--- a/core/src/main/java/hivemall/tools/matrix/TransposeAndDotUDAF.java
+++ b/core/src/main/java/hivemall/tools/matrix/TransposeAndDotUDAF.java
@@ -70,7 +70,7 @@ public final class TransposeAndDotUDAF extends AbstractGenericUDAFResolver {
return new TransposeAndDotUDAFEvaluator();
}
- private static final class TransposeAndDotUDAFEvaluator extends GenericUDAFEvaluator {
+ static final class TransposeAndDotUDAFEvaluator extends GenericUDAFEvaluator {
// PARTIAL1 and COMPLETE
private ListObjectInspector matrix0RowOI;
private PrimitiveObjectInspector matrix0ElOI;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/5088ef36/core/src/test/java/hivemall/ftvec/selection/ChiSquareUDFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/ftvec/selection/ChiSquareUDFTest.java b/core/src/test/java/hivemall/ftvec/selection/ChiSquareUDFTest.java
new file mode 100644
index 0000000..38f7f57
--- /dev/null
+++ b/core/src/test/java/hivemall/ftvec/selection/ChiSquareUDFTest.java
@@ -0,0 +1,80 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2016 Makoto YUI
+ * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package hivemall.ftvec.selection;
+
+import hivemall.utils.hadoop.WritableUtils;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
+import org.apache.hadoop.hive.serde2.io.DoubleWritable;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.util.ArrayList;
+import java.util.List;
+
+public class ChiSquareUDFTest {
+
+ @Test
+ public void test() throws Exception {
+ // this test is based on iris data set
+ final ChiSquareUDF chi2 = new ChiSquareUDF();
+ final List<List<DoubleWritable>> observed = new ArrayList<List<DoubleWritable>>();
+ final List<List<DoubleWritable>> expected = new ArrayList<List<DoubleWritable>>();
+ final GenericUDF.DeferredObject[] dObjs = new GenericUDF.DeferredObject[] {
+ new GenericUDF.DeferredJavaObject(observed),
+ new GenericUDF.DeferredJavaObject(expected)};
+
+ final double[][] matrix0 = new double[][] {
+ {250.29999999999998, 170.90000000000003, 73.2, 12.199999999999996},
+ {296.8, 138.50000000000003, 212.99999999999997, 66.3},
+ {329.3999999999999, 148.7, 277.59999999999997, 101.29999999999998}};
+ final double[][] matrix1 = new double[][] {
+ {292.1666753739119, 152.70000455081467, 187.93333893418327, 59.93333511948589},
+ {292.1666753739119, 152.70000455081467, 187.93333893418327, 59.93333511948589},
+ {292.1666753739119, 152.70000455081467, 187.93333893418327, 59.93333511948589}};
+
+ for (double[] row : matrix0) {
+ observed.add(WritableUtils.toWritableList(row));
+ }
+ for (double[] row : matrix1) {
+ expected.add(WritableUtils.toWritableList(row));
+ }
+
+ chi2.initialize(new ObjectInspector[] {
+ ObjectInspectorFactory.getStandardListObjectInspector(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector)),
+ ObjectInspectorFactory.getStandardListObjectInspector(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector))});
+ final Object[] result = (Object[]) chi2.evaluate(dObjs);
+ final double[] result0 = new double[matrix0[0].length];
+ final double[] result1 = new double[matrix0[0].length];
+ for (int i = 0; i < result0.length; i++) {
+ result0[i] = Double.valueOf(((List) result[0]).get(i).toString());
+ result1[i] = Double.valueOf(((List) result[1]).get(i).toString());
+ }
+
+ final double[] answer0 = new double[] {10.817820878493995, 3.5944990176817315,
+ 116.16984746363957, 67.24482558215503};
+ final double[] answer1 = new double[] {0.004476514990225833, 0.16575416718561453, 0.d,
+ 2.55351295663786e-15};
+
+ Assert.assertArrayEquals(answer0, result0, 0.d);
+ Assert.assertArrayEquals(answer1, result1, 0.d);
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/5088ef36/core/src/test/java/hivemall/tools/array/SelectKBeatUDFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/tools/array/SelectKBeatUDFTest.java b/core/src/test/java/hivemall/tools/array/SelectKBeatUDFTest.java
new file mode 100644
index 0000000..b86db5c
--- /dev/null
+++ b/core/src/test/java/hivemall/tools/array/SelectKBeatUDFTest.java
@@ -0,0 +1,65 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2016 Makoto YUI
+ * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package hivemall.tools.array;
+
+import hivemall.utils.hadoop.WritableUtils;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.io.IntWritable;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.util.List;
+
+public class SelectKBeatUDFTest {
+
+ @Test
+ public void test() throws Exception {
+ final SelectKBestUDF selectKBest = new SelectKBestUDF();
+ final int k = 2;
+ final double[] data = new double[] {250.29999999999998, 170.90000000000003, 73.2,
+ 12.199999999999996};
+ final double[] importanceList = new double[] {292.1666753739119, 152.70000455081467,
+ 187.93333893418327, 59.93333511948589};
+
+ final GenericUDF.DeferredObject[] dObjs = new GenericUDF.DeferredObject[] {
+ new GenericUDF.DeferredJavaObject(WritableUtils.toWritableList(data)),
+ new GenericUDF.DeferredJavaObject(WritableUtils.toWritableList(importanceList)),
+ new GenericUDF.DeferredJavaObject(new IntWritable(k))};
+
+ selectKBest.initialize(new ObjectInspector[] {
+ ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector),
+ ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector),
+ PrimitiveObjectInspectorFactory.writableIntObjectInspector});
+ final List resultObj = (List) selectKBest.evaluate(dObjs);
+
+ Assert.assertEquals(resultObj.size(), k);
+
+ final double[] result = new double[k];
+ for (int i = 0; i < k; i++) {
+ result[i] = Double.valueOf(resultObj.get(i).toString());
+ }
+
+ final double[] answer = new double[] {250.29999999999998, 73.2};
+
+ Assert.assertArrayEquals(answer, result, 0.d);
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/5088ef36/core/src/test/java/hivemall/tools/matrix/TransposeAndDotUDAFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/tools/matrix/TransposeAndDotUDAFTest.java b/core/src/test/java/hivemall/tools/matrix/TransposeAndDotUDAFTest.java
new file mode 100644
index 0000000..93c6ef1
--- /dev/null
+++ b/core/src/test/java/hivemall/tools/matrix/TransposeAndDotUDAFTest.java
@@ -0,0 +1,58 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2016 Makoto YUI
+ * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package hivemall.tools.matrix;
+
+import hivemall.utils.hadoop.WritableUtils;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
+import org.apache.hadoop.hive.ql.udf.generic.SimpleGenericUDAFParameterInfo;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.junit.Assert;
+import org.junit.Test;
+
+public class TransposeAndDotUDAFTest {
+
+ @Test
+ public void test() throws Exception {
+ final TransposeAndDotUDAF tad = new TransposeAndDotUDAF();
+
+ final double[][] matrix0 = new double[][] { {1, -2}, {-1, 3}};
+ final double[][] matrix1 = new double[][] { {1, 2}, {3, 4}};
+
+ final ObjectInspector[] OIs = new ObjectInspector[] {
+ ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector),
+ ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector)};
+ final GenericUDAFEvaluator evaluator = tad.getEvaluator(new SimpleGenericUDAFParameterInfo(
+ OIs, false, false));
+ evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, OIs);
+ TransposeAndDotUDAF.TransposeAndDotUDAFEvaluator.TransposeAndDotAggregationBuffer agg = (TransposeAndDotUDAF.TransposeAndDotUDAFEvaluator.TransposeAndDotAggregationBuffer) evaluator.getNewAggregationBuffer();
+ evaluator.reset(agg);
+ for (int i = 0; i < matrix0.length; i++) {
+ evaluator.iterate(agg, new Object[] {WritableUtils.toWritableList(matrix0[i]),
+ WritableUtils.toWritableList(matrix1[i])});
+ }
+
+ final double[][] answer = new double[][] { {-2.0, -2.0}, {7.0, 8.0}};
+
+ for (int i = 0; i < answer.length; i++) {
+ Assert.assertArrayEquals(answer[i], agg.aggMatrix[i], 0.d);
+ }
+ }
+}
[46/50] [abbrv] incubator-hivemall git commit: Fix syntax errors in
spark (#387)
Posted by my...@apache.org.
Fix syntax errors in spark (#387)
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/4c8dcbfc
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/4c8dcbfc
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/4c8dcbfc
Branch: refs/heads/JIRA-22/pr-385
Commit: 4c8dcbfcdd9dd584fc97e28db39a12d12dfd7b48
Parents: 6549ef5
Author: Takeshi Yamamuro <li...@gmail.com>
Authored: Thu Nov 24 03:13:25 2016 +0900
Committer: Makoto YUI <yu...@gmail.com>
Committed: Thu Nov 24 03:13:25 2016 +0900
----------------------------------------------------------------------
.../apache/spark/sql/hive/GroupedDataEx.scala | 8 +--
.../org/apache/spark/sql/hive/HivemallOps.scala | 6 +--
.../spark/sql/hive/HivemallOpsSuite.scala | 7 ++-
.../spark/sql/hive/HivemallGroupedDataset.scala | 51 ++++++++++----------
.../spark/sql/hive/HivemallOpsSuite.scala | 13 ++---
5 files changed, 41 insertions(+), 44 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/4c8dcbfc/spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/GroupedDataEx.scala
----------------------------------------------------------------------
diff --git a/spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/GroupedDataEx.scala b/spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/GroupedDataEx.scala
index 8f78a7f..dd6db6c 100644
--- a/spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/GroupedDataEx.scala
+++ b/spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/GroupedDataEx.scala
@@ -271,9 +271,11 @@ final class GroupedDataEx protected[sql](
*/
def onehot_encoding(features: String*): DataFrame = {
val udaf = HiveUDAFFunction(
- new HiveFunctionWrapper("hivemall.ftvec.trans.OnehotEncodingUDAF"),
- features.map(df.col(_).expr),
- isUDAFBridgeRequired = false)
+ new HiveFunctionWrapper("hivemall.ftvec.trans.OnehotEncodingUDAF"),
+ features.map(df.col(_).expr),
+ isUDAFBridgeRequired = false)
+ toDF(Seq(Alias(udaf, udaf.prettyString)()))
+ }
/**
* @see hivemall.ftvec.selection.SignalNoiseRatioUDAF
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/4c8dcbfc/spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
----------------------------------------------------------------------
diff --git a/spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala b/spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
index 27cffc7..8583e1c 100644
--- a/spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
+++ b/spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
@@ -1010,9 +1010,9 @@ object HivemallOps {
}
/**
- * @see hivemall.ftvec.selection.ChiSquareUDF
- * @group ftvec.selection
- */
+ * @see hivemall.ftvec.selection.ChiSquareUDF
+ * @group ftvec.selection
+ */
def chi2(observed: Column, expected: Column): Column = {
HiveGenericUDF(new HiveFunctionWrapper(
"hivemall.ftvec.selection.ChiSquareUDF"), Seq(observed.expr, expected.expr))
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/4c8dcbfc/spark/spark-1.6/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
----------------------------------------------------------------------
diff --git a/spark/spark-1.6/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala b/spark/spark-1.6/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
index c231105..4c77f18 100644
--- a/spark/spark-1.6/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
+++ b/spark/spark-1.6/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
@@ -22,7 +22,6 @@ import org.apache.spark.sql.{Column, Row}
import org.apache.spark.sql.hive.HivemallOps._
import org.apache.spark.sql.hive.HivemallUtils._
import org.apache.spark.sql.types._
-import org.apache.spark.sql.{Column, Row}
import org.apache.spark.test.HivemallQueryTest
import org.apache.spark.test.TestDoubleWrapper._
import org.apache.spark.test.TestUtils._
@@ -575,14 +574,13 @@ final class HivemallOpsSuite extends HivemallQueryTest {
assert(row4(0).getDouble(1) ~== 0.25)
}
- test("user-defined aggregators for ftvec.trans") {
+ ignore("user-defined aggregators for ftvec.trans") {
import hiveContext.implicits._
val df0 = Seq((1, "cat", "mammal", 9), (1, "dog", "mammal", 10), (1, "human", "mammal", 10),
(1, "seahawk", "bird", 101), (1, "wasp", "insect", 3), (1, "wasp", "insect", 9),
(1, "cat", "mammal", 101), (1, "dog", "mammal", 1), (1, "human", "mammal", 9))
- .toDF("col0", "cat1", "cat2", "cat3")
-
+ .toDF("col0", "cat1", "cat2", "cat3")
val row00 = df0.groupby($"col0").onehot_encoding("cat1")
val row01 = df0.groupby($"col0").onehot_encoding("cat1", "cat2", "cat3")
@@ -600,6 +598,7 @@ final class HivemallOpsSuite extends HivemallQueryTest {
assert(result011.values.toSet === Set(6, 7, 8))
assert(result012.keySet === Set(1, 3, 9, 10, 101))
assert(result012.values.toSet === Set(9, 10, 11, 12, 13))
+ }
test("user-defined aggregators for ftvec.selection") {
import hiveContext.implicits._
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/4c8dcbfc/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallGroupedDataset.scala
----------------------------------------------------------------------
diff --git a/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallGroupedDataset.scala b/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallGroupedDataset.scala
index 73757f6..bdeff98 100644
--- a/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallGroupedDataset.scala
+++ b/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallGroupedDataset.scala
@@ -133,6 +133,19 @@ final class HivemallGroupedDataset(groupBy: RelationalGroupedDataset) {
}
/**
+ * @see hivemall.tools.matrix.TransposeAndDotUDAF
+ */
+ def transpose_and_dot(X: String, Y: String): DataFrame = {
+ val udaf = HiveUDAFFunction(
+ "transpose_and_dot",
+ new HiveFunctionWrapper("hivemall.tools.matrix.TransposeAndDotUDAF"),
+ Seq(X, Y).map(df.col(_).expr),
+ isUDAFBridgeRequired = false)
+ .toAggregateExpression()
+ toDF(Seq(Alias(udaf, udaf.prettyName)()))
+ }
+
+ /**
* @see hivemall.ftvec.trans.OnehotEncodingUDAF
* @group ftvec.trans
*/
@@ -147,6 +160,19 @@ final class HivemallGroupedDataset(groupBy: RelationalGroupedDataset) {
}
/**
+ * @see hivemall.ftvec.selection.SignalNoiseRatioUDAF
+ */
+ def snr(X: String, Y: String): DataFrame = {
+ val udaf = HiveUDAFFunction(
+ "snr",
+ new HiveFunctionWrapper("hivemall.ftvec.selection.SignalNoiseRatioUDAF"),
+ Seq(X, Y).map(df.col(_).expr),
+ isUDAFBridgeRequired = false)
+ .toAggregateExpression()
+ toDF(Seq(Alias(udaf, udaf.prettyName)()))
+ }
+
+ /**
* @see hivemall.evaluation.MeanAbsoluteErrorUDAF
* @group evaluation
*/
@@ -273,30 +299,5 @@ object HivemallGroupedDataset {
implicit def relationalGroupedDatasetToHivemallOne(
groupBy: RelationalGroupedDataset): HivemallGroupedDataset = {
new HivemallGroupedDataset(groupBy)
-
- /**
- * @see hivemall.ftvec.selection.SignalNoiseRatioUDAF
- */
- def snr(X: String, Y: String): DataFrame = {
- val udaf = HiveUDAFFunction(
- "snr",
- new HiveFunctionWrapper("hivemall.ftvec.selection.SignalNoiseRatioUDAF"),
- Seq(X, Y).map(df.col(_).expr),
- isUDAFBridgeRequired = false)
- .toAggregateExpression()
- toDF(Seq(Alias(udaf, udaf.prettyName)()))
- }
-
- /**
- * @see hivemall.tools.matrix.TransposeAndDotUDAF
- */
- def transpose_and_dot(X: String, Y: String): DataFrame = {
- val udaf = HiveUDAFFunction(
- "transpose_and_dot",
- new HiveFunctionWrapper("hivemall.tools.matrix.TransposeAndDotUDAF"),
- Seq(X, Y).map(df.col(_).expr),
- isUDAFBridgeRequired = false)
- .toAggregateExpression()
- toDF(Seq(Alias(udaf, udaf.prettyName)()))
}
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/4c8dcbfc/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
----------------------------------------------------------------------
diff --git a/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala b/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
index 8bea975..d969abf 100644
--- a/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
+++ b/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
@@ -26,12 +26,6 @@ import org.apache.spark.sql.hive.HivemallUtils._
import org.apache.spark.sql.types._
import org.apache.spark.test.{HivemallFeatureQueryTest, TestUtils, VectorQueryTest}
import org.apache.spark.test.TestDoubleWrapper._
-import org.apache.spark.sql.hive.HivemallOps._
-import org.apache.spark.sql.hive.HivemallUtils._
-import org.apache.spark.sql.types._
-import org.apache.spark.sql.{AnalysisException, Column, Row, functions}
-import org.apache.spark.test.TestDoubleWrapper._
-import org.apache.spark.test.{HivemallFeatureQueryTest, TestUtils, VectorQueryTest}
final class HivemallOpsWithFeatureSuite extends HivemallFeatureQueryTest {
@@ -705,6 +699,7 @@ final class HivemallOpsWithFeatureSuite extends HivemallFeatureQueryTest {
assert(result011.values.toSet === Set(6, 7, 8))
assert(result012.keySet === Set(1, 3, 9, 10, 101))
assert(result012.values.toSet === Set(9, 10, 11, 12, 13))
+ }
test("user-defined aggregators for ftvec.selection") {
import hiveContext.implicits._
@@ -726,7 +721,7 @@ final class HivemallOpsWithFeatureSuite extends HivemallFeatureQueryTest {
(1, Seq(4.7, 3.2, 1.3, 0.2), Seq(1, 0)), (1, Seq(7.0, 3.2, 4.7, 1.4), Seq(0, 1)),
(1, Seq(6.4, 3.2, 4.5, 1.5), Seq(0, 1)), (1, Seq(6.9, 3.1, 4.9, 1.5), Seq(0, 1)))
.toDF("c0", "arg0", "arg1")
- val row0 = df0.groupby($"c0").snr("arg0", "arg1").collect
+ val row0 = df0.groupBy($"c0").snr("arg0", "arg1").collect
(row0(0).getAs[Seq[Double]](1), Seq(4.38425236, 0.26390002, 15.83984511, 26.87005769))
.zipped
.foreach((actual, expected) => assert(actual ~== expected))
@@ -747,7 +742,7 @@ final class HivemallOpsWithFeatureSuite extends HivemallFeatureQueryTest {
(1, Seq(7.0, 3.2, 4.7, 1.4), Seq(0, 1, 0)), (1, Seq(6.4, 3.2, 4.5, 1.5), Seq(0, 1, 0)),
(1, Seq(6.3, 3.3, 6.0, 2.5), Seq(0, 0, 1)), (1, Seq(5.8, 2.7, 5.1, 1.9), Seq(0, 0, 1)))
.toDF("c0", "arg0", "arg1")
- val row1 = df1.groupby($"c0").snr("arg0", "arg1").collect
+ val row1 = df1.groupBy($"c0").snr("arg0", "arg1").collect
(row1(0).getAs[Seq[Double]](1), Seq(8.43181818, 1.32121212, 42.94949495, 33.80952381))
.zipped
.foreach((actual, expected) => assert(actual ~== expected))
@@ -761,7 +756,7 @@ final class HivemallOpsWithFeatureSuite extends HivemallFeatureQueryTest {
val df0 = Seq((1, Seq(1, 2, 3), Seq(5, 6, 7)), (1, Seq(3, 4, 5), Seq(7, 8, 9)))
.toDF("c0", "arg0", "arg1")
- checkAnswer(df0.groupby($"c0").transpose_and_dot("arg0", "arg1"),
+ checkAnswer(df0.groupBy($"c0").transpose_and_dot("arg0", "arg1"),
Seq(Row(1, Seq(Seq(26.0, 30.0, 34.0), Seq(38.0, 44.0, 50.0), Seq(50.0, 58.0, 66.0)))))
}
}
[36/50] [abbrv] incubator-hivemall git commit: Add exception
Posted by my...@apache.org.
Add exception
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/33eab26f
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/33eab26f
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/33eab26f
Branch: refs/heads/JIRA-22/pr-336
Commit: 33eab26f383dbdbce00a209e742b611a63d953cf
Parents: ba91267
Author: amaya <gi...@sapphire.in.net>
Authored: Thu Nov 17 14:16:14 2016 +0900
Committer: amaya <gi...@sapphire.in.net>
Committed: Thu Nov 17 14:16:14 2016 +0900
----------------------------------------------------------------------
.../main/java/hivemall/systemtest/runner/HiveSystemTestRunner.java | 1 +
1 file changed, 1 insertion(+)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/33eab26f/systemtest/src/main/java/hivemall/systemtest/runner/HiveSystemTestRunner.java
----------------------------------------------------------------------
diff --git a/systemtest/src/main/java/hivemall/systemtest/runner/HiveSystemTestRunner.java b/systemtest/src/main/java/hivemall/systemtest/runner/HiveSystemTestRunner.java
index 6b41855..25a2125 100644
--- a/systemtest/src/main/java/hivemall/systemtest/runner/HiveSystemTestRunner.java
+++ b/systemtest/src/main/java/hivemall/systemtest/runner/HiveSystemTestRunner.java
@@ -132,6 +132,7 @@ public class HiveSystemTestRunner extends SystemTestRunner {
hShell.insertInto(dbName, hq.tableName).addRowsFromTsv(hq.file).commit();
break;
case MSGPACK:
+ throw new Exception("MessagePack is not supported in HiveSystemTestRunner");
case UNKNOWN:
throw new Exception("Input csv or tsv");
}
[27/50] [abbrv] incubator-hivemall git commit: refine feature selection in spark integration
Posted by my...@apache.org.
refine feature selection in spark integration
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/1347de98
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/1347de98
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/1347de98
Branch: refs/heads/JIRA-22/pr-385
Commit: 1347de985ea6f8028c9d381f8827882ad39ad3a7
Parents: aa7d529
Author: amaya <gi...@sapphire.in.net>
Authored: Wed Sep 28 14:22:05 2016 +0900
Committer: amaya <gi...@sapphire.in.net>
Committed: Wed Sep 28 14:22:05 2016 +0900
----------------------------------------------------------------------
.../org/apache/spark/sql/hive/HivemallOps.scala | 9 +-
.../spark/sql/hive/HivemallOpsSuite.scala | 94 ++++++++++++++------
.../org/apache/spark/sql/hive/HivemallOps.scala | 8 +-
.../spark/sql/hive/HivemallOpsSuite.scala | 89 ++++++++++++------
4 files changed, 138 insertions(+), 62 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/1347de98/spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
----------------------------------------------------------------------
diff --git a/spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala b/spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
index 41a4065..255f697 100644
--- a/spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
+++ b/spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
@@ -1006,9 +1006,9 @@ object HivemallOps {
* @see hivemall.ftvec.selection.ChiSquareUDF
* @group ftvec.selection
*/
- def chi2(exprs: Column*): Column = {
+ def chi2(observed: Column, expected: Column): Column = {
HiveGenericUDF(new HiveFunctionWrapper(
- "hivemall.ftvec.selection.ChiSquareUDF"), exprs.map(_.expr))
+ "hivemall.ftvec.selection.ChiSquareUDF"), Seq(observed.expr, expected.expr))
}
/**
@@ -1087,10 +1087,9 @@ object HivemallOps {
* @see hivemall.tools.array.SelectKBestUDF
* @group tools.array
*/
- @scala.annotation.varargs
- def select_k_best(exprs: Column*): Column = {
+ def select_k_best(X: Column, importanceList: Column, k: Column): Column = {
HiveGenericUDF(new HiveFunctionWrapper(
- "hivemall.tools.array.SelectKBestUDF"), exprs.map(_.expr))
+ "hivemall.tools.array.SelectKBestUDF"), Seq(X.expr, importanceList.expr, k.expr))
}
/**
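The test changes below expect select_k_best to keep, per row, the k feature values whose importance scores are largest, returned in their original column order (importance [3, 1, 2] with k = 2 keeps columns 0 and 2). The following is a standalone sketch of that behaviour under this assumption; the names are illustrative and not taken from SelectKBestUDF.
import java.util.Arrays;
import java.util.Comparator;

public class SelectKBestSketch {
    static double[] selectKBest(double[] features, double[] importance, int k) {
        Integer[] idx = new Integer[importance.length];
        for (int i = 0; i < idx.length; i++) {
            idx[i] = i;
        }
        // order column indices by decreasing importance, keep the first k
        Arrays.sort(idx, Comparator.comparingDouble((Integer i) -> importance[i]).reversed());
        int[] keep = new int[k];
        for (int i = 0; i < k; i++) {
            keep[i] = idx[i];
        }
        Arrays.sort(keep);   // restore the original column order
        double[] out = new double[k];
        for (int i = 0; i < k; i++) {
            out[i] = features[keep[i]];
        }
        return out;
    }

    public static void main(String[] args) {
        // importance [3, 1, 2] and k = 2 keep columns 0 and 2: prints [0.0, 3.0]
        System.out.println(Arrays.toString(
            selectKBest(new double[] {0, 1, 3}, new double[] {3, 1, 2}, 2)));
    }
}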
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/1347de98/spark/spark-1.6/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
----------------------------------------------------------------------
diff --git a/spark/spark-1.6/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala b/spark/spark-1.6/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
index e118257..cce22ce 100644
--- a/spark/spark-1.6/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
+++ b/spark/spark-1.6/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
@@ -17,13 +17,14 @@
package org.apache.spark.sql.hive
-import org.apache.spark.sql.{Column, Row}
import org.apache.spark.sql.hive.HivemallOps._
import org.apache.spark.sql.hive.HivemallUtils._
import org.apache.spark.sql.types._
+import org.apache.spark.sql.{Column, Row}
import org.apache.spark.test.HivemallQueryTest
import org.apache.spark.test.TestDoubleWrapper._
import org.apache.spark.test.TestUtils._
+import org.scalatest.Matchers._
final class HivemallOpsSuite extends HivemallQueryTest {
@@ -188,18 +189,32 @@ final class HivemallOpsSuite extends HivemallQueryTest {
test("ftvec.selection - chi2") {
import hiveContext.implicits._
-
- val df = Seq(Seq(
- Seq(250.29999999999998, 170.90000000000003, 73.2, 12.199999999999996),
- Seq(296.8, 138.50000000000003, 212.99999999999997, 66.3),
- Seq(329.3999999999999, 148.7, 277.59999999999997, 101.29999999999998)) -> Seq(
- Seq(292.1666753739119, 152.70000455081467, 187.93333893418327, 59.93333511948589),
- Seq(292.1666753739119, 152.70000455081467, 187.93333893418327, 59.93333511948589),
- Seq(292.1666753739119, 152.70000455081467, 187.93333893418327, 59.93333511948589))).toDF("arg0", "arg1")
-
- assert(df.select(chi2(df("arg0"), df("arg1"))).collect.toSet ===
- Set(Row(Row(Seq(10.817820878493995, 3.5944990176817315, 116.16984746363957, 67.24482558215503),
- Seq(0.004476514990225833, 0.16575416718561453, 0d, 2.55351295663786e-15)))))
+ implicit val doubleEquality = org.scalactic.TolerantNumerics.tolerantDoubleEquality(1e-5)
+
+ // see also hivemall.ftvec.selection.ChiSquareUDFTest
+ val df = Seq(
+ Seq(
+ Seq(250.29999999999998, 170.90000000000003, 73.2, 12.199999999999996),
+ Seq(296.8, 138.50000000000003, 212.99999999999997, 66.3),
+ Seq(329.3999999999999, 148.7, 277.59999999999997, 101.29999999999998)
+ ) -> Seq(
+ Seq(292.1666753739119, 152.70000455081467, 187.93333893418327, 59.93333511948589),
+ Seq(292.1666753739119, 152.70000455081467, 187.93333893418327, 59.93333511948589),
+ Seq(292.1666753739119, 152.70000455081467, 187.93333893418327, 59.93333511948589)))
+ .toDF("arg0", "arg1")
+
+ val result = df.select(chi2(df("arg0"), df("arg1"))).collect
+ result should have length 1
+ val chi2Val = result.head.getAs[Row](0).getAs[Seq[Double]](0)
+ val pVal = result.head.getAs[Row](0).getAs[Seq[Double]](1)
+
+ (chi2Val, Seq(10.81782088, 3.59449902, 116.16984746, 67.24482759))
+ .zipped
+ .foreach((actual, expected) => actual shouldEqual expected)
+
+ (pVal, Seq(4.47651499e-03, 1.65754167e-01, 5.94344354e-26, 2.50017968e-15))
+ .zipped
+ .foreach((actual, expected) => actual shouldEqual expected)
}
test("ftvec.conv - quantify") {
@@ -352,13 +367,11 @@ final class HivemallOpsSuite extends HivemallQueryTest {
test("tools.array - select_k_best") {
import hiveContext.implicits._
- val data = Seq(Tuple1(Seq(0, 1, 3)), Tuple1(Seq(2, 4, 1)), Tuple1(Seq(5, 4, 9)))
- val importance = Seq(3, 1, 2)
- val k = 2
- val df = data.toDF("features")
+ val data = Seq(Seq(0, 1, 3), Seq(2, 4, 1), Seq(5, 4, 9))
+ val df = data.map(d => (d, Seq(3, 1, 2), 2)).toDF("features", "importance_list", "k")
- assert(df.select(select_k_best(df("features"), importance, k)).collect.toSeq ===
- data.map(s => Row(Seq(s._1(0).toDouble, s._1(2).toDouble))))
+ df.select(select_k_best(df("features"), df("importance_list"), df("k"))).collect shouldEqual
+ data.map(s => Row(Seq(s(0).toDouble, s(2).toDouble)))
}
test("misc - sigmoid") {
@@ -560,7 +573,31 @@ final class HivemallOpsSuite extends HivemallQueryTest {
test("user-defined aggregators for ftvec.selection") {
import hiveContext.implicits._
+ implicit val doubleEquality = org.scalactic.TolerantNumerics.tolerantDoubleEquality(1e-5)
+
+ // see also hivemall.ftvec.selection.SignalNoiseRatioUDAFTest
+ // binary class
+ // +-----------------+-------+
+ // | features | class |
+ // +-----------------+-------+
+ // | 5.1,3.5,1.4,0.2 | 0 |
+ // | 4.9,3.0,1.4,0.2 | 0 |
+ // | 4.7,3.2,1.3,0.2 | 0 |
+ // | 7.0,3.2,4.7,1.4 | 1 |
+ // | 6.4,3.2,4.5,1.5 | 1 |
+ // | 6.9,3.1,4.9,1.5 | 1 |
+ // +-----------------+-------+
+ val df0 = Seq(
+ (1, Seq(5.1, 3.5, 1.4, 0.2), Seq(1, 0)), (1, Seq(4.9, 3.0, 1.4, 0.2), Seq(1, 0)),
+ (1, Seq(4.7, 3.2, 1.3, 0.2), Seq(1, 0)), (1, Seq(7.0, 3.2, 4.7, 1.4), Seq(0, 1)),
+ (1, Seq(6.4, 3.2, 4.5, 1.5), Seq(0, 1)), (1, Seq(6.9, 3.1, 4.9, 1.5), Seq(0, 1)))
+ .toDF("c0", "arg0", "arg1")
+ val row0 = df0.groupby($"c0").snr("arg0", "arg1").collect
+ (row0(0).getAs[Seq[Double]](1), Seq(4.38425236, 0.26390002, 15.83984511, 26.87005769))
+ .zipped
+ .foreach((actual, expected) => actual shouldEqual expected)
+ // multiple class
// +-----------------+-------+
// | features | class |
// +-----------------+-------+
@@ -571,14 +608,15 @@ final class HivemallOpsSuite extends HivemallQueryTest {
// | 6.3,3.3,6.0,2.5 | 2 |
// | 5.8,2.7,5.1,1.9 | 2 |
// +-----------------+-------+
- val df0 = Seq(
+ val df1 = Seq(
(1, Seq(5.1, 3.5, 1.4, 0.2), Seq(1, 0, 0)), (1, Seq(4.9, 3.0, 1.4, 0.2), Seq(1, 0, 0)),
(1, Seq(7.0, 3.2, 4.7, 1.4), Seq(0, 1, 0)), (1, Seq(6.4, 3.2, 4.5, 1.5), Seq(0, 1, 0)),
(1, Seq(6.3, 3.3, 6.0, 2.5), Seq(0, 0, 1)), (1, Seq(5.8, 2.7, 5.1, 1.9), Seq(0, 0, 1)))
- .toDF.as("c0", "arg0", "arg1")
- val row0 = df0.groupby($"c0").snr("arg0", "arg1").collect
- assert(row0(0).getAs[Seq[Double]](1) ===
- Seq(8.431818181818192, 1.3212121212121217, 42.94949494949499, 33.80952380952378))
+ .toDF("c0", "arg0", "arg1")
+ val row1 = df1.groupby($"c0").snr("arg0", "arg1").collect
+ (row1(0).getAs[Seq[Double]](1), Seq(8.43181818, 1.32121212, 42.94949495, 33.80952381))
+ .zipped
+ .foreach((actual, expected) => actual shouldEqual expected)
}
test("user-defined aggregators for tools.matrix") {
@@ -586,8 +624,10 @@ final class HivemallOpsSuite extends HivemallQueryTest {
// | 1 2 3 |T | 5 6 7 |
// | 3 4 5 | * | 7 8 9 |
- val df0 = Seq((1, Seq(1, 2, 3), Seq(5, 6, 7)), (1, Seq(3, 4, 5), Seq(7, 8, 9))).toDF.as("c0", "arg0", "arg1")
- val row0 = df0.groupby($"c0").transpose_and_dot("arg0", "arg1").collect
- assert(row0(0).getAs[Seq[Double]](1) === Seq(Seq(26.0, 30.0, 34.0), Seq(38.0, 44.0, 50.0), Seq(50.0, 58.0, 66.0)))
+ val df0 = Seq((1, Seq(1, 2, 3), Seq(5, 6, 7)), (1, Seq(3, 4, 5), Seq(7, 8, 9)))
+ .toDF("c0", "arg0", "arg1")
+
+ df0.groupby($"c0").transpose_and_dot("arg0", "arg1").collect() shouldEqual
+ Seq(Row(1, Seq(Seq(26.0, 30.0, 34.0), Seq(38.0, 44.0, 50.0), Seq(50.0, 58.0, 66.0))))
}
}
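The SNR values asserted in this suite are consistent with a per-feature signal-to-noise ratio of |mean0 - mean1| / (sd0 + sd1) between the two classes, using population (biased) standard deviations. That reading is an assumption, not a quote of SignalNoiseRatioUDAF; the sketch below reproduces the first expected value (~4.38425) from the binary-class data in the test.
public class SnrSketch {
    static double snr(double[] class0, double[] class1) {
        double mu0 = mean(class0), mu1 = mean(class1);
        return Math.abs(mu0 - mu1) / (sd(class0, mu0) + sd(class1, mu1));
    }

    static double mean(double[] x) {
        double s = 0.d;
        for (double v : x) {
            s += v;
        }
        return s / x.length;
    }

    static double sd(double[] x, double mu) {
        double s = 0.d;
        for (double v : x) {
            s += (v - mu) * (v - mu);
        }
        return Math.sqrt(s / x.length);   // population (biased) variance
    }

    public static void main(String[] args) {
        // first feature, class 0 vs. class 1 of the binary test above: prints ~4.38425
        System.out.println(snr(new double[] {5.1, 4.9, 4.7}, new double[] {7.0, 6.4, 6.9}));
    }
}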
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/1347de98/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
----------------------------------------------------------------------
diff --git a/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala b/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
index f12992e..628c2ea 100644
--- a/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
+++ b/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
@@ -1252,10 +1252,10 @@ object HivemallOps {
* @see hivemall.ftvec.selection.ChiSquareUDF
* @group ftvec.selection
*/
- def chi2(exprs: Column*): Column = withExpr {
+ def chi2(observed: Column, expected: Column): Column = withExpr {
HiveGenericUDF("chi2",
new HiveFunctionWrapper("hivemall.ftvec.selection.ChiSquareUDF"),
- exprs.map(_.expr))
+ Seq(observed.expr, expected.expr))
}
/**
@@ -1341,10 +1341,10 @@ object HivemallOps {
* @see hivemall.tools.array.SelectKBestUDF
* @group tools.array
*/
- def select_k_best(exprs: Column*): Column = withExpr {
+ def select_k_best(X: Column, importanceList: Column, k: Column): Column = withExpr {
HiveGenericUDF("select_k_best",
new HiveFunctionWrapper("hivemall.tools.array.SelectKBestUDF"),
- exprs.map(_.expr))
+ Seq(X.expr, importanceList.expr, k.expr))
}
/**
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/1347de98/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
----------------------------------------------------------------------
diff --git a/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala b/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
index d750916..2e18280 100644
--- a/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
+++ b/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
@@ -17,14 +17,13 @@
package org.apache.spark.sql.hive
-import org.apache.spark.sql.{AnalysisException, Column, Row}
-import org.apache.spark.sql.functions
import org.apache.spark.sql.hive.HivemallOps._
import org.apache.spark.sql.hive.HivemallUtils._
import org.apache.spark.sql.types._
-import org.apache.spark.test.HivemallFeatureQueryTest
+import org.apache.spark.sql.{AnalysisException, Column, Row, functions}
import org.apache.spark.test.TestDoubleWrapper._
-import org.apache.spark.test.{TestUtils, VectorQueryTest}
+import org.apache.spark.test.{HivemallFeatureQueryTest, TestUtils, VectorQueryTest}
+import org.scalatest.Matchers._
final class HivemallOpsWithFeatureSuite extends HivemallFeatureQueryTest {
@@ -189,18 +188,32 @@ final class HivemallOpsWithFeatureSuite extends HivemallFeatureQueryTest {
test("ftvec.selection - chi2") {
import hiveContext.implicits._
+ implicit val doubleEquality = org.scalactic.TolerantNumerics.tolerantDoubleEquality(1e-5)
- val df = Seq(Seq(
- Seq(250.29999999999998, 170.90000000000003, 73.2, 12.199999999999996),
- Seq(296.8, 138.50000000000003, 212.99999999999997, 66.3),
- Seq(329.3999999999999, 148.7, 277.59999999999997, 101.29999999999998)) -> Seq(
- Seq(292.1666753739119, 152.70000455081467, 187.93333893418327, 59.93333511948589),
- Seq(292.1666753739119, 152.70000455081467, 187.93333893418327, 59.93333511948589),
- Seq(292.1666753739119, 152.70000455081467, 187.93333893418327, 59.93333511948589))).toDF("arg0", "arg1")
+ // see also hivemall.ftvec.selection.ChiSquareUDFTest
+ val df = Seq(
+ Seq(
+ Seq(250.29999999999998, 170.90000000000003, 73.2, 12.199999999999996),
+ Seq(296.8, 138.50000000000003, 212.99999999999997, 66.3),
+ Seq(329.3999999999999, 148.7, 277.59999999999997, 101.29999999999998)
+ ) -> Seq(
+ Seq(292.1666753739119, 152.70000455081467, 187.93333893418327, 59.93333511948589),
+ Seq(292.1666753739119, 152.70000455081467, 187.93333893418327, 59.93333511948589),
+ Seq(292.1666753739119, 152.70000455081467, 187.93333893418327, 59.93333511948589)))
+ .toDF("arg0", "arg1")
+
+ val result = df.select(chi2(df("arg0"), df("arg1"))).collect
+ result should have length 1
+ val chi2Val = result.head.getAs[Row](0).getAs[Seq[Double]](0)
+ val pVal = result.head.getAs[Row](0).getAs[Seq[Double]](1)
+
+ (chi2Val, Seq(10.81782088, 3.59449902, 116.16984746, 67.24482759))
+ .zipped
+ .foreach((actual, expected) => actual shouldEqual expected)
- assert(df.select(chi2(df("arg0"), df("arg1"))).collect.toSet ===
- Set(Row(Row(Seq(10.817820878493995, 3.5944990176817315, 116.16984746363957, 67.24482558215503),
- Seq(0.004476514990225833, 0.16575416718561453, 0d, 2.55351295663786e-15)))))
+ (pVal, Seq(4.47651499e-03, 1.65754167e-01, 5.94344354e-26, 2.50017968e-15))
+ .zipped
+ .foreach((actual, expected) => actual shouldEqual expected)
}
test("ftvec.conv - quantify") {
@@ -378,12 +391,10 @@ final class HivemallOpsWithFeatureSuite extends HivemallFeatureQueryTest {
import hiveContext.implicits._
val data = Seq(Seq(0, 1, 3), Seq(2, 4, 1), Seq(5, 4, 9))
- val importance = Seq(3, 1, 2)
- val k = 2
- val df = data.toDF("features")
+ val df = data.map(d => (d, Seq(3, 1, 2), 2)).toDF("features", "importance_list", "k")
- assert(df.select(select_k_best(df("features"), importance, k)).collect.toSeq ===
- data.map(s => Row(Seq(s(0).toDouble, s(2).toDouble))))
+ df.select(select_k_best(df("features"), df("importance_list"), df("k"))).collect shouldEqual
+ data.map(s => Row(Seq(s(0).toDouble, s(2).toDouble)))
}
test("misc - sigmoid") {
@@ -678,7 +689,31 @@ final class HivemallOpsWithFeatureSuite extends HivemallFeatureQueryTest {
test("user-defined aggregators for ftvec.selection") {
import hiveContext.implicits._
+ implicit val doubleEquality = org.scalactic.TolerantNumerics.tolerantDoubleEquality(1e-5)
+ // see also hivemall.ftvec.selection.SignalNoiseRatioUDAFTest
+ // binary class
+ // +-----------------+-------+
+ // | features | class |
+ // +-----------------+-------+
+ // | 5.1,3.5,1.4,0.2 | 0 |
+ // | 4.9,3.0,1.4,0.2 | 0 |
+ // | 4.7,3.2,1.3,0.2 | 0 |
+ // | 7.0,3.2,4.7,1.4 | 1 |
+ // | 6.4,3.2,4.5,1.5 | 1 |
+ // | 6.9,3.1,4.9,1.5 | 1 |
+ // +-----------------+-------+
+ val df0 = Seq(
+ (1, Seq(5.1, 3.5, 1.4, 0.2), Seq(1, 0)), (1, Seq(4.9, 3.0, 1.4, 0.2), Seq(1, 0)),
+ (1, Seq(4.7, 3.2, 1.3, 0.2), Seq(1, 0)), (1, Seq(7.0, 3.2, 4.7, 1.4), Seq(0, 1)),
+ (1, Seq(6.4, 3.2, 4.5, 1.5), Seq(0, 1)), (1, Seq(6.9, 3.1, 4.9, 1.5), Seq(0, 1)))
+ .toDF("c0", "arg0", "arg1")
+ val row0 = df0.groupby($"c0").snr("arg0", "arg1").collect
+ (row0(0).getAs[Seq[Double]](1), Seq(4.38425236, 0.26390002, 15.83984511, 26.87005769))
+ .zipped
+ .foreach((actual, expected) => actual shouldEqual expected)
+
+ // multiple class
// +-----------------+-------+
// | features | class |
// +-----------------+-------+
@@ -689,14 +724,15 @@ final class HivemallOpsWithFeatureSuite extends HivemallFeatureQueryTest {
// | 6.3,3.3,6.0,2.5 | 2 |
// | 5.8,2.7,5.1,1.9 | 2 |
// +-----------------+-------+
- val df0 = Seq(
+ val df1 = Seq(
(1, Seq(5.1, 3.5, 1.4, 0.2), Seq(1, 0, 0)), (1, Seq(4.9, 3.0, 1.4, 0.2), Seq(1, 0, 0)),
(1, Seq(7.0, 3.2, 4.7, 1.4), Seq(0, 1, 0)), (1, Seq(6.4, 3.2, 4.5, 1.5), Seq(0, 1, 0)),
(1, Seq(6.3, 3.3, 6.0, 2.5), Seq(0, 0, 1)), (1, Seq(5.8, 2.7, 5.1, 1.9), Seq(0, 0, 1)))
- .toDF.as("c0", "arg0", "arg1")
- val row0 = df0.groupby($"c0").snr("arg0", "arg1").collect
- assert(row0(0).getAs[Seq[Double]](1) ===
- Seq(8.431818181818192, 1.3212121212121217, 42.94949494949499, 33.80952380952378))
+ .toDF("c0", "arg0", "arg1")
+ val row1 = df1.groupby($"c0").snr("arg0", "arg1").collect
+ (row1(0).getAs[Seq[Double]](1), Seq(8.43181818, 1.32121212, 42.94949495, 33.80952381))
+ .zipped
+ .foreach((actual, expected) => actual shouldEqual expected)
}
test("user-defined aggregators for tools.matrix") {
@@ -705,8 +741,9 @@ final class HivemallOpsWithFeatureSuite extends HivemallFeatureQueryTest {
// | 1 2 3 |T | 5 6 7 |
// | 3 4 5 | * | 7 8 9 |
val df0 = Seq((1, Seq(1, 2, 3), Seq(5, 6, 7)), (1, Seq(3, 4, 5), Seq(7, 8, 9))).toDF.as("c0", "arg0", "arg1")
- val row0 = df0.groupby($"c0").transpose_and_dot("arg0", "arg1").collect
- assert(row0(0).getAs[Seq[Double]](1) === Seq(Seq(26.0, 30.0, 34.0), Seq(38.0, 44.0, 50.0), Seq(50.0, 58.0, 66.0)))
+
+ df0.groupby($"c0").transpose_and_dot("arg0", "arg1").collect() shouldEqual
+ Seq(Row(1, Seq(Seq(26.0, 30.0, 34.0), Seq(38.0, 44.0, 50.0), Seq(50.0, 58.0, 66.0))))
}
}
[38/50] [abbrv] incubator-hivemall git commit: Mod assert methods
Posted by my...@apache.org.
Mod assert methods
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/3550fd30
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/3550fd30
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/3550fd30
Branch: refs/heads/JIRA-22/pr-336
Commit: 3550fd30af3a01f4217c075a3b814952b406aebe
Parents: 1f3df54
Author: amaya <gi...@sapphire.in.net>
Authored: Fri Nov 18 01:57:47 2016 +0900
Committer: amaya <gi...@sapphire.in.net>
Committed: Fri Nov 18 01:57:47 2016 +0900
----------------------------------------------------------------------
.../main/java/hivemall/systemtest/runner/SystemTestRunner.java | 6 ++----
1 file changed, 2 insertions(+), 4 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3550fd30/systemtest/src/main/java/hivemall/systemtest/runner/SystemTestRunner.java
----------------------------------------------------------------------
diff --git a/systemtest/src/main/java/hivemall/systemtest/runner/SystemTestRunner.java b/systemtest/src/main/java/hivemall/systemtest/runner/SystemTestRunner.java
index 77091f2..f16da90 100644
--- a/systemtest/src/main/java/hivemall/systemtest/runner/SystemTestRunner.java
+++ b/systemtest/src/main/java/hivemall/systemtest/runner/SystemTestRunner.java
@@ -195,12 +195,10 @@ public abstract class SystemTestRunner extends ExternalResource {
if (ordered) {
// take order into consideration (like list)
- Assert.assertThat(Arrays.asList(answer.split(IO.RD)),
- Matchers.contains(result.toArray()));
+ Assert.assertThat(result, Matchers.contains(answer.split(IO.RD)));
} else {
// not take order into consideration (like multiset)
- Assert.assertThat(Arrays.asList(answer.split(IO.RD)),
- Matchers.containsInAnyOrder(result.toArray()));
+ Assert.assertThat(result, Matchers.containsInAnyOrder(answer.split(IO.RD)));
}
}
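Background for the swap: Hamcrest's assertThat takes the actual value first and the matcher second, while the previous calls passed the expected answer as the actual value and built matchers from the result. A tiny illustration of the corrected ordering; the values are illustrative only.
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.containsInAnyOrder;

import java.util.Arrays;
import java.util.List;
import org.junit.Assert;

public class AssertOrderingExample {
    public static void main(String[] args) {
        List<String> result = Arrays.asList("a", "b", "c");
        Assert.assertThat(result, contains("a", "b", "c"));             // order matters (list-like)
        Assert.assertThat(result, containsInAnyOrder("c", "a", "b"));   // order ignored (multiset-like)
    }
}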
[04/50] [abbrv] incubator-hivemall git commit: mod number format
Posted by my...@apache.org.
mod number format
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/d8f1005b
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/d8f1005b
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/d8f1005b
Branch: refs/heads/JIRA-22/pr-385
Commit: d8f1005bb9fbf769b117290582bed18d7607a94a
Parents: d3009be
Author: amaya <gi...@sapphire.in.net>
Authored: Tue Sep 20 12:01:46 2016 +0900
Committer: amaya <gi...@sapphire.in.net>
Committed: Tue Sep 20 12:01:46 2016 +0900
----------------------------------------------------------------------
.../hivemall/tools/matrix/TransposeAndDotUDAF.java | 2 +-
.../src/main/java/hivemall/utils/math/StatsUtils.java | 14 +++++++-------
2 files changed, 8 insertions(+), 8 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/d8f1005b/core/src/main/java/hivemall/tools/matrix/TransposeAndDotUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/tools/matrix/TransposeAndDotUDAF.java b/core/src/main/java/hivemall/tools/matrix/TransposeAndDotUDAF.java
index 4fa5ce4..3dcbb93 100644
--- a/core/src/main/java/hivemall/tools/matrix/TransposeAndDotUDAF.java
+++ b/core/src/main/java/hivemall/tools/matrix/TransposeAndDotUDAF.java
@@ -81,7 +81,7 @@ public final class TransposeAndDotUDAF extends AbstractGenericUDAFResolver {
public void reset() {
if (aggMatrix != null) {
for (double[] row : aggMatrix) {
- Arrays.fill(row, 0.0);
+ Arrays.fill(row, 0.d);
}
}
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/d8f1005b/core/src/main/java/hivemall/utils/math/StatsUtils.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/math/StatsUtils.java b/core/src/main/java/hivemall/utils/math/StatsUtils.java
index ffccea3..7633419 100644
--- a/core/src/main/java/hivemall/utils/math/StatsUtils.java
+++ b/core/src/main/java/hivemall/utils/math/StatsUtils.java
@@ -198,22 +198,22 @@ public final class StatsUtils {
public static double chiSquare(@Nonnull final double[] expected, @Nonnull final double[] observed) {
Preconditions.checkArgument(expected.length == observed.length);
- double sumExpected = 0.0D;
- double sumObserved = 0.0D;
+ double sumExpected = 0.d;
+ double sumObserved = 0.d;
for (int ratio = 0; ratio < observed.length; ++ratio) {
sumExpected += expected[ratio];
sumObserved += observed[ratio];
}
- double var15 = 1.0D;
+ double var15 = 1.d;
boolean rescale = false;
- if (Math.abs(sumExpected - sumObserved) > 1.0E-5D) {
+ if (Math.abs(sumExpected - sumObserved) > 1.e-5) {
var15 = sumObserved / sumExpected;
rescale = true;
}
- double sumSq = 0.0D;
+ double sumSq = 0.d;
for (int i = 0; i < observed.length; ++i) {
double dev;
@@ -235,7 +235,7 @@ public final class StatsUtils {
* @return p-value
*/
public static double chiSquareTest(@Nonnull final double[] expected,@Nonnull final double[] observed) {
- ChiSquaredDistribution distribution = new ChiSquaredDistribution(null, (double)expected.length - 1.0D);
- return 1.0D - distribution.cumulativeProbability(chiSquare(expected, observed));
+ ChiSquaredDistribution distribution = new ChiSquaredDistribution(null, (double)expected.length - 1.d);
+ return 1.d - distribution.cumulativeProbability(chiSquare(expected, observed));
}
}
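As a worked example of what chiSquare() and chiSquareTest() compute when the expected and observed totals already match (so the rescaling branch is skipped): chi2 = sum((o_i - e_i)^2 / e_i), and the p-value is the upper tail of a chi-squared distribution with (length - 1) degrees of freedom. The numbers below are illustrative only.
import org.apache.commons.math3.distribution.ChiSquaredDistribution;

public class ChiSquareExample {
    public static void main(String[] args) {
        double[] expected = {25.d, 25.d, 25.d, 25.d};
        double[] observed = {20.d, 30.d, 25.d, 25.d};
        double chi2 = 0.d;
        for (int i = 0; i < observed.length; i++) {
            double dev = observed[i] - expected[i];
            chi2 += dev * dev / expected[i];
        }
        ChiSquaredDistribution dist = new ChiSquaredDistribution(expected.length - 1.d);
        double pValue = 1.d - dist.cumulativeProbability(chi2);
        System.out.println(chi2 + " " + pValue);   // 2.0 and ~0.572 for 3 degrees of freedom
    }
}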
[39/50] [abbrv] incubator-hivemall git commit: Fix process of tdprop
Posted by my...@apache.org.
Fix process of tdprop
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/144cb504
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/144cb504
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/144cb504
Branch: refs/heads/JIRA-22/pr-336
Commit: 144cb504d674d2509620ce0d315694be0f664f42
Parents: 3550fd3
Author: amaya <gi...@sapphire.in.net>
Authored: Fri Nov 18 01:58:31 2016 +0900
Committer: amaya <gi...@sapphire.in.net>
Committed: Fri Nov 18 01:58:31 2016 +0900
----------------------------------------------------------------------
.../systemtest/runner/TDSystemTestRunner.java | 16 ++++++++++------
1 file changed, 10 insertions(+), 6 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/144cb504/systemtest/src/main/java/hivemall/systemtest/runner/TDSystemTestRunner.java
----------------------------------------------------------------------
diff --git a/systemtest/src/main/java/hivemall/systemtest/runner/TDSystemTestRunner.java b/systemtest/src/main/java/hivemall/systemtest/runner/TDSystemTestRunner.java
index 6d6c85b..87dd835 100644
--- a/systemtest/src/main/java/hivemall/systemtest/runner/TDSystemTestRunner.java
+++ b/systemtest/src/main/java/hivemall/systemtest/runner/TDSystemTestRunner.java
@@ -85,16 +85,20 @@ public class TDSystemTestRunner extends SystemTestRunner {
fileUploadCommitRetryLimit = Integer.valueOf(props.getProperty("fileUploadCommitRetryLimit"));
}
- final Properties TDPorps = System.getProperties();
+ boolean fromPropertiesFile = false;
for (Map.Entry<Object, Object> e : props.entrySet()) {
- if (e.getKey().toString().startsWith("td.client.")) {
- TDPorps.setProperty(e.getKey().toString(), e.getValue().toString());
+ final String key = e.getKey().toString();
+ if (key.startsWith("td.client.")) {
+ fromPropertiesFile = true;
+ System.setProperty(key, e.getValue().toString());
}
}
- System.setProperties(TDPorps);
- client = System.getProperties().size() == TDPorps.size() ? TDClient.newClient() // use $HOME/.td/td.conf
- : TDClient.newBuilder(false).build(); // use *.properties
+ if (fromPropertiesFile) {
+ client = TDClient.newBuilder(false).build(); // use *.properties
+ } else {
+ client = TDClient.newClient(); // use $HOME/.td/td.conf
+ }
}
@Override
[30/50] [abbrv] incubator-hivemall git commit: Support implicit-Krylov-approximation-based efficient SST
Posted by my...@apache.org.
Support implicit-Krylov-approximation-based efficient SST
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/998203d5
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/998203d5
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/998203d5
Branch: refs/heads/JIRA-22/pr-356
Commit: 998203d5e8623d6282c2b187df24e4da7d41c16b
Parents: 2bfd127
Author: Takuya Kitazawa <k....@gmail.com>
Authored: Wed Sep 28 19:49:48 2016 +0900
Committer: Takuya Kitazawa <k....@gmail.com>
Committed: Wed Sep 28 19:49:48 2016 +0900
----------------------------------------------------------------------
.../anomaly/SingularSpectrumTransform.java | 103 ++++++++--
.../anomaly/SingularSpectrumTransformUDF.java | 27 +++
.../java/hivemall/utils/math/MatrixUtils.java | 203 +++++++++++++++++++
.../anomaly/SingularSpectrumTransformTest.java | 61 ++++--
.../hivemall/utils/math/MatrixUtilsTest.java | 67 ++++++
5 files changed, 434 insertions(+), 27 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/998203d5/core/src/main/java/hivemall/anomaly/SingularSpectrumTransform.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/anomaly/SingularSpectrumTransform.java b/core/src/main/java/hivemall/anomaly/SingularSpectrumTransform.java
index c964129..f9f6222 100644
--- a/core/src/main/java/hivemall/anomaly/SingularSpectrumTransform.java
+++ b/core/src/main/java/hivemall/anomaly/SingularSpectrumTransform.java
@@ -18,9 +18,11 @@
package hivemall.anomaly;
import hivemall.anomaly.SingularSpectrumTransformUDF.SingularSpectrumTransformInterface;
+import hivemall.anomaly.SingularSpectrumTransformUDF.ScoreFunction;
import hivemall.anomaly.SingularSpectrumTransformUDF.Parameters;
import hivemall.utils.collections.DoubleRingBuffer;
-import org.apache.commons.math3.linear.MatrixUtils;
+import hivemall.utils.math.MatrixUtils;
+import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.SingularValueDecomposition;
import org.apache.hadoop.hive.ql.metadata.HiveException;
@@ -28,6 +30,8 @@ import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import java.util.Arrays;
+import java.util.TreeMap;
+import java.util.Collections;
import javax.annotation.Nonnull;
@@ -37,6 +41,9 @@ final class SingularSpectrumTransform implements SingularSpectrumTransformInterf
private final PrimitiveObjectInspector oi;
@Nonnull
+ private final ScoreFunction scoreFunc;
+
+ @Nonnull
private final int window;
@Nonnull
private final int nPastWindow;
@@ -50,15 +57,22 @@ final class SingularSpectrumTransform implements SingularSpectrumTransformInterf
private final int currentOffset;
@Nonnull
private final int r;
+ @Nonnull
+ private final int k;
@Nonnull
private final DoubleRingBuffer xRing;
@Nonnull
private final double[] xSeries;
+ @Nonnull
+ private final double[] q;
+
SingularSpectrumTransform(@Nonnull Parameters params, @Nonnull PrimitiveObjectInspector oi) {
this.oi = oi;
+ this.scoreFunc = params.scoreFunc;
+
this.window = params.w;
this.nPastWindow = params.n;
this.nCurrentWindow = params.m;
@@ -66,6 +80,7 @@ final class SingularSpectrumTransform implements SingularSpectrumTransformInterf
this.currentSize = window + nCurrentWindow;
this.currentOffset = params.g;
this.r = params.r;
+ this.k = params.k;
// (w + n) past samples for the n-past-windows
// (w + m) current samples for the m-current-windows, starting from offset g
@@ -74,6 +89,18 @@ final class SingularSpectrumTransform implements SingularSpectrumTransformInterf
this.xRing = new DoubleRingBuffer(holdSampleSize);
this.xSeries = new double[holdSampleSize];
+
+ this.q = new double[window];
+ double norm = 0.d;
+ for (int i = 0; i < window; i++) {
+ this.q[i] = Math.random();
+ norm += q[i] * q[i];
+ }
+ norm = Math.sqrt(norm);
+ // normalize
+ for (int i = 0; i < window; i++) {
+ this.q[i] = q[i] / norm;
+ }
}
@Override
@@ -86,25 +113,39 @@ final class SingularSpectrumTransform implements SingularSpectrumTransformInterf
if (!xRing.isFull()) {
outScores[0] = 0.d;
} else {
- outScores[0] = computeScore();
+ // create past trajectory matrix and find its left singular vectors
+ RealMatrix H = new Array2DRowRealMatrix(new double[window][nPastWindow]);
+ for (int i = 0; i < nPastWindow; i++) {
+ H.setColumn(i, Arrays.copyOfRange(xSeries, i, i + window));
+ }
+
+ // create current trajectory matrix and find its left singular vectors
+ RealMatrix G = new Array2DRowRealMatrix(new double[window][nCurrentWindow]);
+ int currentHead = pastSize + currentOffset;
+ for (int i = 0; i < nCurrentWindow; i++) {
+ G.setColumn(i, Arrays.copyOfRange(xSeries, currentHead + i, currentHead + i + window));
+ }
+
+ switch (scoreFunc) {
+ case svd:
+ outScores[0] = computeScoreSVD(H, G);
+ break;
+ case ika:
+ outScores[0] = computeScoreIKA(H, G);
+ break;
+ default:
+ throw new IllegalStateException("Unexpected score function: " + scoreFunc);
+ }
}
}
- private double computeScore() {
- // create past trajectory matrix and find its left singular vectors
- RealMatrix H = MatrixUtils.createRealMatrix(window, nPastWindow);
- for (int i = 0; i < nPastWindow; i++) {
- H.setColumn(i, Arrays.copyOfRange(xSeries, i, i + window));
- }
+ /**
+ * Singular Value Decomposition (SVD) based naive scoring.
+ */
+ private double computeScoreSVD(@Nonnull final RealMatrix H, @Nonnull final RealMatrix G) {
SingularValueDecomposition svdH = new SingularValueDecomposition(H);
RealMatrix UT = svdH.getUT();
- // create current trajectory matrix and find its left singular vectors
- RealMatrix G = MatrixUtils.createRealMatrix(window, nCurrentWindow);
- int currentHead = pastSize + currentOffset;
- for (int i = 0; i < nCurrentWindow; i++) {
- G.setColumn(i, Arrays.copyOfRange(xSeries, currentHead + i, currentHead + i + window));
- }
SingularValueDecomposition svdG = new SingularValueDecomposition(G);
RealMatrix Q = svdG.getU();
@@ -115,4 +156,38 @@ final class SingularSpectrumTransform implements SingularSpectrumTransformInterf
return 1.d - s[0];
}
+
+ /**
+ * Implicit Krylov Approximation (IKA) based naive scoring.
+ *
+ * Number of iterations for the Power method and QR method is fixed to 1 for efficiency.
+ * This may cause failure (i.e. meaningless scores) depending on datasets and initial values.
+ *
+ */
+ private double computeScoreIKA(@Nonnull final RealMatrix H, @Nonnull final RealMatrix G) {
+ // assuming n = m = window, and keep track the left singular vector as `q`
+ double firstSingularValue = MatrixUtils.power1(G, q, 1, q, new double[window]);
+
+ RealMatrix T = new Array2DRowRealMatrix(new double[k][k]);
+ MatrixUtils.lanczosTridiagonalization(H.multiply(H.transpose()), q, T);
+
+ double[] eigvals = new double[k];
+ RealMatrix eigvecs = new Array2DRowRealMatrix(new double[k][k]);
+ MatrixUtils.tridiagonalEigen(T, 1, eigvals, eigvecs);
+
+ // tridiagonalEigen() returns unordered eigenvalues,
+ // so the top-r eigenvectors should be picked carefully
+ TreeMap<Double, Integer> map = new TreeMap<Double, Integer>(Collections.reverseOrder());
+ for (int i = 0; i < k; i++) {
+ map.put(eigvals[i], i);
+ }
+ Object[] sortedIndices = map.values().toArray();
+
+ double s = 0.d;
+ for (int i = 0; i < r; i++) {
+ double v = eigvecs.getEntry(0, (int)sortedIndices[i]);
+ s += v * v;
+ }
+ return 1.d - Math.sqrt(s);
+ }
}
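The unchanged middle of computeScoreSVD() is not visible in this hunk, so the following is only a rough, self-contained sketch of the idea rather than the committed code: the SVD-based score compares the dominant left singular subspaces of the past trajectory matrix H and the current trajectory matrix G, and reports 1 - sigma_max(Ur^T * Qr), which is ~0 when the subspaces coincide and grows after a change point. All names below are illustrative.
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.SingularValueDecomposition;

public class SstSvdScoreSketch {
    static double score(RealMatrix H, RealMatrix G, int r) {
        RealMatrix U = new SingularValueDecomposition(H).getU();   // left singular vectors, past window
        RealMatrix Q = new SingularValueDecomposition(G).getU();   // left singular vectors, current window
        RealMatrix Ur = U.getSubMatrix(0, U.getRowDimension() - 1, 0, r - 1);
        RealMatrix Qr = Q.getSubMatrix(0, Q.getRowDimension() - 1, 0, r - 1);
        double[] s = new SingularValueDecomposition(Ur.transpose().multiply(Qr)).getSingularValues();
        return 1.d - s[0];
    }

    public static void main(String[] args) {
        // tiny 4 x 3 trajectory matrices from a steady and a shifted series
        RealMatrix H = new Array2DRowRealMatrix(new double[][] {
            {1, 1, 1}, {2, 2, 2}, {1, 1, 1}, {2, 2, 2}});
        RealMatrix G = new Array2DRowRealMatrix(new double[][] {
            {5, 5, 5}, {1, 1, 1}, {5, 5, 5}, {1, 1, 1}});
        System.out.println(score(H, H, 1));   // ~0: identical subspaces
        System.out.println(score(H, G, 1));   // > 0: the dominant direction moved
    }
}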
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/998203d5/core/src/main/java/hivemall/anomaly/SingularSpectrumTransformUDF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/anomaly/SingularSpectrumTransformUDF.java b/core/src/main/java/hivemall/anomaly/SingularSpectrumTransformUDF.java
index 64b7d20..5f8633d 100644
--- a/core/src/main/java/hivemall/anomaly/SingularSpectrumTransformUDF.java
+++ b/core/src/main/java/hivemall/anomaly/SingularSpectrumTransformUDF.java
@@ -89,6 +89,8 @@ public final class SingularSpectrumTransformUDF extends UDFWithOptions {
"Number of singular vectors (i.e. principal components) [default: 3]");
opts.addOption("k", "n_dim", true,
"Number of dimensions for the Krylov subspaces [default: 5 (`2*r` if `r` is even, `2*r-1` otherwise)]");
+ opts.addOption("score", "scorefunc", true,
+ "Score function [default: svd, ika]");
opts.addOption("th", "threshold", true,
"Score threshold (inclusive) for determining change-point existence [default: -1, do not output decision]");
return opts;
@@ -105,6 +107,12 @@ public final class SingularSpectrumTransformUDF extends UDFWithOptions {
this._params.r = Primitives.parseInt(cl.getOptionValue("r"), _params.r);
this._params.k = Primitives.parseInt(
cl.getOptionValue("k"), (_params.r % 2 == 0) ? (2 * _params.r) : (2 * _params.r - 1));
+
+ this._params.scoreFunc = ScoreFunction.resolve(cl.getOptionValue("scorefunc", ScoreFunction.svd.name()));
+ if ((_params.w != _params.n || _params.w != _params.m) && _params.scoreFunc == ScoreFunction.ika) {
+ throw new UDFArgumentException("IKA-based efficient SST requires w = n = m");
+ }
+
this._params.changepointThreshold = Primitives.parseDouble(
cl.getOptionValue("th"), _params.changepointThreshold);
@@ -196,13 +204,32 @@ public final class SingularSpectrumTransformUDF extends UDFWithOptions {
int g = -30;
int r = 3;
int k = 5;
+ ScoreFunction scoreFunc = ScoreFunction.svd;
double changepointThreshold = -1.d;
Parameters() {}
+
+ void set(@Nonnull ScoreFunction func) {
+ this.scoreFunc = func;
+ }
}
public interface SingularSpectrumTransformInterface {
void update(@Nonnull Object arg, @Nonnull double[] outScores) throws HiveException;
}
+ public enum ScoreFunction {
+ svd, ika;
+
+ static ScoreFunction resolve(@Nullable final String name) {
+ if (svd.name().equalsIgnoreCase(name)) {
+ return svd;
+ } else if (ika.name().equalsIgnoreCase(name)) {
+ return ika;
+ } else {
+ throw new IllegalArgumentException("Unsupported ScoreFunction: " + name);
+ }
+ }
+ }
+
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/998203d5/core/src/main/java/hivemall/utils/math/MatrixUtils.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/math/MatrixUtils.java b/core/src/main/java/hivemall/utils/math/MatrixUtils.java
index 840df41..aaf9d4a 100644
--- a/core/src/main/java/hivemall/utils/math/MatrixUtils.java
+++ b/core/src/main/java/hivemall/utils/math/MatrixUtils.java
@@ -26,11 +26,16 @@ import org.apache.commons.math3.linear.BlockRealMatrix;
import org.apache.commons.math3.linear.DecompositionSolver;
import org.apache.commons.math3.linear.DefaultRealMatrixPreservingVisitor;
import org.apache.commons.math3.linear.LUDecomposition;
+import org.apache.commons.math3.linear.ArrayRealVector;
+import org.apache.commons.math3.linear.RealVector;
+import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealMatrixPreservingVisitor;
import org.apache.commons.math3.linear.SingularMatrixException;
import org.apache.commons.math3.linear.SingularValueDecomposition;
+import java.util.Arrays;
+
public final class MatrixUtils {
private MatrixUtils() {}
@@ -493,4 +498,202 @@ public final class MatrixUtils {
return A;
}
+ /**
+ * Find the first singular vector/value of a matrix A based on the Power method.
+ *
+ * http://www.cs.yale.edu/homes/el327/datamining2013aFiles/07_singular_value_decomposition.pdf
+ *
+ * @param A target matrix
+ * @param x0 initial vector
+ * @param nIter number of iterations for the Power method
+ * @param u 1st left singular vector
+ * @param v 1st right singular vector
+ * @return 1st singular value
+ */
+ @Nonnull
+ public static double power1(@Nonnull final RealMatrix A, @Nonnull final double[] x0, final int nIter,
+ @Nonnull final double[] u, @Nonnull final double[] v) {
+ Preconditions.checkArgument(A.getColumnDimension() == x0.length,
+ "Column size of A and length of x should be same");
+ Preconditions.checkArgument(A.getRowDimension() == u.length,
+ "Row size of A and length of u should be same");
+ Preconditions.checkArgument(x0.length == v.length, "Length of x and u should be same");
+ Preconditions.checkArgument(nIter >= 1, "Invalid number of iterations: " + nIter);
+
+ RealMatrix AtA = A.transpose().multiply(A);
+
+ RealVector x = new ArrayRealVector(x0);
+ for (int i = 0; i < nIter; i++) {
+ x = AtA.operate(x);
+ }
+
+ double xNorm = x.getNorm();
+ for (int i = 0, n = v.length; i < n; i++) {
+ v[i] = x.getEntry(i) / xNorm;
+ }
+
+ RealVector Av = new ArrayRealVector(A.operate(v));
+ double s = Av.getNorm();
+
+ for (int i = 0, n = u.length; i < n; i++) {
+ u[i] = Av.getEntry(i) / s;
+ }
+
+ return s;
+ }
+
+ /**
+ * Lanczos tridiagonalization for a symmetric matrix C to make s * s tridiagonal matrix T.
+ *
+ * http://www.cas.mcmaster.ca/~qiao/publications/spie05.pdf
+ *
+ * @param C target symmetric matrix
+ * @param a initial vector
+ * @param T result is stored here
+ */
+ @Nonnull
+ public static void lanczosTridiagonalization(@Nonnull final RealMatrix C, @Nonnull final double[] a,
+ @Nonnull final RealMatrix T) {
+ Preconditions.checkArgument(Arrays.deepEquals(C.getData(), C.transpose().getData()),
+ "Target matrix C must be a symmetric matrix");
+ Preconditions.checkArgument(C.getColumnDimension() == a.length,
+ "Column size of A and length of a should be same");
+ Preconditions.checkArgument(T.getRowDimension() == T.getColumnDimension(),
+ "T must be a square matrix");
+
+ int s = T.getRowDimension();
+
+ // initialize T with zeros
+ T.setSubMatrix(new double[s][s], 0, 0);
+
+ RealVector a0 = new ArrayRealVector(new double[a.length]);
+ RealVector r = new ArrayRealVector(a);
+
+ double beta0 = 1.d;
+
+ for (int i = 0; i < s; i++) {
+ RealVector a1 = r.mapDivide(beta0);
+ RealVector Ca1 = C.operate(a1);
+
+ double alpha1 = a1.dotProduct(Ca1);
+
+ r = Ca1.add(a1.mapMultiply(-1.d * alpha1)).add(a0.mapMultiply(-1.d * beta0));
+
+ double beta1 = r.getNorm();
+
+ T.setEntry(i, i, alpha1);
+ if (i - 1 >= 0) {
+ T.setEntry(i, i - 1, beta0);
+ }
+ if (i + 1 < s) {
+ T.setEntry(i, i + 1, beta1);
+ }
+
+ a0 = a1.copy();
+ beta0 = beta1;
+ }
+ }
+
+ /**
+ * QR decomposition for a tridiagonal matrix T.
+ *
+ * https://gist.github.com/lightcatcher/8118181
+ * http://www.ericmart.in/blog/optimizing_julia_tridiag_qr
+ *
+ * @param T target tridiagonal matrix
+ * @param R output matrix for R which is the same shape as T
+ * @param Qt output matrix for Q.T which is the same shape an T
+ */
+ @Nonnull
+ public static void tridiagonalQR(@Nonnull final RealMatrix T,
+ @Nonnull final RealMatrix R, @Nonnull final RealMatrix Qt) {
+ int n = T.getRowDimension();
+ Preconditions.checkArgument(n == R.getRowDimension() && n == R.getColumnDimension(),
+ "T and R must be the same shape");
+ Preconditions.checkArgument(n == Qt.getRowDimension() && n == Qt.getColumnDimension(),
+ "T and Qt must be the same shape");
+
+ // initial R = T
+ R.setSubMatrix(T.getData(), 0, 0);
+
+ // initial Qt = identity
+ Qt.setSubMatrix(new double[n][n], 0, 0);
+ for (int i = 0; i < n; i++) {
+ Qt.setEntry(i, i, 1);
+ }
+
+ for (int i = 0; i < n - 1; i++) {
+ // Householder projection for a vector x
+ // https://en.wikipedia.org/wiki/Householder_transformation
+ RealVector x = T.getSubMatrix(i, i + 1, i, i).getColumnVector(0);
+
+ double x0 = x.getEntry(0);
+ double sign = 0.d;
+ if (x0 < 0.d) {
+ sign = -1.d;
+ } else if (x0 > 0.d) {
+ sign = 1.d;
+ }
+
+ x.setEntry(0, x0 + sign * x.getNorm());
+ x = x.unitVector();
+
+ RealMatrix subR = R.getSubMatrix(i, i + 1, 0, n - 1);
+ R.setSubMatrix(subR.subtract(x.outerProduct(subR.preMultiply(x)).scalarMultiply(2)).getData(), i, 0);
+
+ RealMatrix subQt = Qt.getSubMatrix(i, i + 1, 0, n - 1);
+ Qt.setSubMatrix(subQt.subtract(x.outerProduct(subQt.preMultiply(x)).scalarMultiply(2)).getData(), i, 0);
+ }
+ }
+
+ /**
+ * Find eigenvalues and eigenvectors of given tridiagonal matrix T.
+ *
+ * http://web.csulb.edu/~tgao/math423/s94.pdf
+ * http://stats.stackexchange.com/questions/20643/finding-matrix-eigenvectors-using-qr-decomposition
+ *
+ * @param T target tridiagonal matrix
+ * @param nIter number of iterations for the QR method
+ * @param eigvals eigenvalues are stored here
+ * @param eigvecs eigenvectors are stored here
+ */
+ @Nonnull
+ public static void tridiagonalEigen(@Nonnull final RealMatrix T, @Nonnull final int nIter,
+ @Nonnull final double[] eigvals, @Nonnull final RealMatrix eigvecs) {
+ Preconditions.checkArgument(Arrays.deepEquals(T.getData(), T.transpose().getData()),
+ "Target matrix T must be a symmetric (tridiagonal) matrix");
+ Preconditions.checkArgument(eigvecs.getRowDimension() == eigvecs.getColumnDimension(),
+ "eigvecs must be a square matrix");
+ Preconditions.checkArgument(T.getRowDimension() == eigvecs.getRowDimension(),
+ "T and eigvecs must be the same shape");
+ Preconditions.checkArgument(eigvals.length == eigvecs.getRowDimension(),
+ "Number of eigenvalues and eigenvectors must be same");
+
+ int nEig = eigvals.length;
+
+ // initialize eigvecs as an identity matrix
+ eigvecs.setSubMatrix(new double[nEig][nEig], 0, 0);
+ for (int i = 0; i < nEig; i++) {
+ eigvecs.setEntry(i, i, 1);
+ }
+
+ RealMatrix T_ = T.copy();
+
+ for (int i = 0; i < nIter; i++) {
+ // QR decomposition for the tridiagonal matrix T
+ RealMatrix R = new Array2DRowRealMatrix(new double[nEig][nEig]);
+ RealMatrix Qt = new Array2DRowRealMatrix(new double[nEig][nEig]);
+ tridiagonalQR(T_, R, Qt);
+
+ RealMatrix Q = Qt.transpose();
+ T_ = R.multiply(Q);
+ eigvecs.setSubMatrix(eigvecs.multiply(Q).getData(), 0, 0);
+ }
+
+ // diagonal elements correspond to the eigenvalues
+ for (int i = 0; i < nEig; i++) {
+ eigvals[i] = T_.getEntry(i, i);
+ }
+ }
+
}
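A minimal standalone check of the new eigen routines, assuming hivemall.utils.math.MatrixUtils from this commit is on the classpath and that 100 unshifted QR sweeps are enough for this small matrix (both are assumptions, not part of the commit): run tridiagonalEigen() on a 3 x 3 symmetric tridiagonal matrix and compare the (unordered) eigenvalues with commons-math3's EigenDecomposition.
import hivemall.utils.math.MatrixUtils;

import java.util.Arrays;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.EigenDecomposition;
import org.apache.commons.math3.linear.RealMatrix;

public class TridiagonalEigenCheck {
    public static void main(String[] args) {
        RealMatrix T = new Array2DRowRealMatrix(new double[][] {
            {2, 1, 0},
            {1, 2, 1},
            {0, 1, 2}});

        double[] eigvals = new double[3];
        RealMatrix eigvecs = new Array2DRowRealMatrix(new double[3][3]);
        MatrixUtils.tridiagonalEigen(T, 100, eigvals, eigvecs);   // 100 QR sweeps

        double[] expected = new EigenDecomposition(T).getRealEigenvalues();
        Arrays.sort(eigvals);    // tridiagonalEigen() returns unordered eigenvalues
        Arrays.sort(expected);
        System.out.println(Arrays.toString(eigvals));    // expected to approach [0.586, 2.0, 3.414]
        System.out.println(Arrays.toString(expected));
    }
}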
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/998203d5/core/src/test/java/hivemall/anomaly/SingularSpectrumTransformTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/anomaly/SingularSpectrumTransformTest.java b/core/src/test/java/hivemall/anomaly/SingularSpectrumTransformTest.java
index d4f119f..44d114d 100644
--- a/core/src/test/java/hivemall/anomaly/SingularSpectrumTransformTest.java
+++ b/core/src/test/java/hivemall/anomaly/SingularSpectrumTransformTest.java
@@ -17,6 +17,7 @@
*/
package hivemall.anomaly;
+import hivemall.anomaly.SingularSpectrumTransformUDF.ScoreFunction;
import hivemall.anomaly.SingularSpectrumTransformUDF.Parameters;
import java.io.BufferedReader;
@@ -37,8 +38,45 @@ public class SingularSpectrumTransformTest {
private static final boolean DEBUG = false;
@Test
- public void testSST() throws IOException, HiveException {
+ public void testSVDSST() throws IOException, HiveException {
+ int numChangepoints = detectSST(ScoreFunction.svd, 0.95d);
+ Assert.assertTrue("#changepoints SHOULD be greater than 0: " + numChangepoints,
+ numChangepoints > 0);
+ Assert.assertTrue("#changepoints SHOULD be less than 5: " + numChangepoints,
+ numChangepoints < 5);
+ }
+
+ @Test
+ public void testIKASST() throws IOException, HiveException {
+ int numChangepoints = detectSST(ScoreFunction.ika, 0.65d);
+ Assert.assertTrue("#changepoints SHOULD be greater than 0: " + numChangepoints,
+ numChangepoints > 0);
+ Assert.assertTrue("#changepoints SHOULD be less than 5: " + numChangepoints,
+ numChangepoints < 5);
+ }
+
+ @Test
+ public void testSVDTwitterData() throws IOException, HiveException {
+ int numChangepoints = detectTwitterData(ScoreFunction.svd, 0.005d);
+ Assert.assertTrue("#changepoints SHOULD be greater than 0: " + numChangepoints,
+ numChangepoints > 0);
+ Assert.assertTrue("#changepoints SHOULD be less than 5: " + numChangepoints,
+ numChangepoints < 5);
+ }
+
+ @Test
+ public void testIKATwitterData() throws IOException, HiveException {
+ int numChangepoints = detectTwitterData(ScoreFunction.ika, 0.0175d);
+ Assert.assertTrue("#changepoints SHOULD be greater than 0: " + numChangepoints,
+ numChangepoints > 0);
+ Assert.assertTrue("#changepoints SHOULD be less than 5: " + numChangepoints,
+ numChangepoints < 5);
+ }
+
+ private static int detectSST(@Nonnull final ScoreFunction scoreFunc,
+ @Nonnull final double threshold) throws IOException, HiveException {
Parameters params = new Parameters();
+ params.set(scoreFunc);
PrimitiveObjectInspector oi = PrimitiveObjectInspectorFactory.javaDoubleObjectInspector;
SingularSpectrumTransform sst = new SingularSpectrumTransform(params, oi);
double[] outScores = new double[1];
@@ -51,19 +89,18 @@ public class SingularSpectrumTransformTest {
double x = Double.parseDouble(line);
sst.update(x, outScores);
printf("%f %f%n", x, outScores[0]);
- if (outScores[0] > 0.95d) {
+ if (outScores[0] > threshold) {
numChangepoints++;
}
}
- Assert.assertTrue("#changepoints SHOULD be greater than 0: " + numChangepoints,
- numChangepoints > 0);
- Assert.assertTrue("#changepoints SHOULD be less than 5: " + numChangepoints,
- numChangepoints < 5);
+
+ return numChangepoints;
}
- @Test
- public void testTwitterData() throws IOException, HiveException {
+ private static int detectTwitterData(@Nonnull final ScoreFunction scoreFunc,
+ @Nonnull final double threshold) throws IOException, HiveException {
Parameters params = new Parameters();
+ params.set(scoreFunc);
PrimitiveObjectInspector oi = PrimitiveObjectInspectorFactory.javaDoubleObjectInspector;
SingularSpectrumTransform sst = new SingularSpectrumTransform(params, oi);
double[] outScores = new double[1];
@@ -76,15 +113,13 @@ public class SingularSpectrumTransformTest {
double x = Double.parseDouble(line);
sst.update(x, outScores);
printf("%d %f %f%n", i, x, outScores[0]);
- if (outScores[0] > 0.005d) {
+ if (outScores[0] > threshold) {
numChangepoints++;
}
i++;
}
- Assert.assertTrue("#changepoints SHOULD be greater than 0: " + numChangepoints,
- numChangepoints > 0);
- Assert.assertTrue("#changepoints SHOULD be less than 5: " + numChangepoints,
- numChangepoints < 5);
+
+ return numChangepoints;
}
private static void println(String msg) {
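The refactoring above reduces both tests to one pattern: feed a univariate series into SingularSpectrumTransform and count how many scores exceed a threshold. A hedged usage sketch on a synthetic level-shift series follows, assuming only the classes and calls already exercised in the test (the package declaration mirrors the test's package; the series, threshold, and class name are illustrative):

package hivemall.anomaly;

import hivemall.anomaly.SingularSpectrumTransformUDF.Parameters;
import hivemall.anomaly.SingularSpectrumTransformUDF.ScoreFunction;

import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;

public class SstUsageSketch {
    public static void main(String[] args) throws Exception {
        Parameters params = new Parameters();
        params.set(ScoreFunction.svd); // or ScoreFunction.ika
        PrimitiveObjectInspector oi = PrimitiveObjectInspectorFactory.javaDoubleObjectInspector;
        SingularSpectrumTransform sst = new SingularSpectrumTransform(params, oi);

        final double threshold = 0.95d; // same threshold as testSVDSST
        final double[] outScores = new double[1];
        int numChangepoints = 0;
        for (int t = 0; t < 400; t++) {
            // synthetic series with a single level shift at t = 200
            double x = (t < 200 ? 0.d : 10.d) + Math.random() * 0.1d;
            sst.update(x, outScores);
            if (outScores[0] > threshold) {
                numChangepoints++;
            }
        }
        System.out.println("#changepoints: " + numChangepoints);
    }
}

Which threshold is sensible depends on the score function, which is why the refactored tests pass 0.95 for svd but 0.65 for ika.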
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/998203d5/core/src/test/java/hivemall/utils/math/MatrixUtilsTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/utils/math/MatrixUtilsTest.java b/core/src/test/java/hivemall/utils/math/MatrixUtilsTest.java
index bc960ec..b5a5e74 100644
--- a/core/src/test/java/hivemall/utils/math/MatrixUtilsTest.java
+++ b/core/src/test/java/hivemall/utils/math/MatrixUtilsTest.java
@@ -19,6 +19,7 @@ package hivemall.utils.math;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.commons.math3.linear.SingularValueDecomposition;
import org.junit.Assert;
import org.junit.Test;
@@ -205,4 +206,70 @@ public class MatrixUtilsTest {
Assert.assertArrayEquals(expected, actual);
}
+ @Test
+ public void testPower1() {
+ RealMatrix A = new Array2DRowRealMatrix(new double[][] {new double[] {1, 2, 3}, new double[] {4, 5, 6}});
+
+ double[] x = new double[3];
+ x[0] = Math.random();
+ x[1] = Math.random();
+ x[2] = Math.random();
+
+ double[] u = new double[2];
+ double[] v = new double[3];
+
+ double s = MatrixUtils.power1(A, x, 2, u, v);
+
+ SingularValueDecomposition svdA = new SingularValueDecomposition(A);
+
+ Assert.assertArrayEquals(svdA.getU().getColumn(0), u, 0.001d);
+ Assert.assertArrayEquals(svdA.getV().getColumn(0), v, 0.001d);
+ Assert.assertEquals(svdA.getSingularValues()[0], s, 0.001d);
+ }
+
+ @Test
+ public void testLanczosTridiagonalization() {
+ // Symmetric matrix
+ RealMatrix C = new Array2DRowRealMatrix(new double[][] {
+ new double[] {1, 2, 3, 4}, new double[] {2, 1, 4, 3},
+ new double[] {3, 4, 1, 2}, new double[] {4, 3, 2, 1}});
+
+ // naive initial vector
+ double[] a = new double[] {1, 1, 1, 1};
+
+ RealMatrix actual = new Array2DRowRealMatrix(new double[4][4]);
+ MatrixUtils.lanczosTridiagonalization(C, a, actual);
+
+ RealMatrix expected = new Array2DRowRealMatrix(new double[][] {
+ new double[] {40, 60, 0, 0}, new double[] {60, 10, 120, 0},
+ new double[] {0, 120, 10, 120}, new double[] {0, 0, 120, 10}});
+
+ Assert.assertEquals(expected, actual);
+ }
+
+ @Test
+ public void testTridiagonalEigen() {
+ // Tridiagonal Matrix
+ RealMatrix T = new Array2DRowRealMatrix(new double[][] {
+ new double[] {40, 60, 0, 0}, new double[] {60, 10, 120, 0},
+ new double[] {0, 120, 10, 120}, new double[] {0, 0, 120, 10}});
+
+ double[] eigvals = new double[4];
+ RealMatrix eigvecs = new Array2DRowRealMatrix(new double[4][4]);
+
+ MatrixUtils.tridiagonalEigen(T, 2, eigvals, eigvecs);
+
+ RealMatrix actual = eigvecs.multiply(eigvecs.transpose());
+
+ RealMatrix expected = new Array2DRowRealMatrix(new double[4][4]);
+ for (int i = 0; i < 4; i++) {
+ expected.setEntry(i, i, 1);
+ }
+
+ for (int i = 0; i < 4; i++) {
+ for (int j = 0; j < 4; j++) {
+ Assert.assertEquals(expected.getEntry(i, j), actual.getEntry(i, j), 0.001d);
+ }
+ }
+ }
}
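testPower1 above validates MatrixUtils.power1 against a full SVD. For reference, here is a self-contained sketch of the same power-iteration idea, estimating the leading singular value of A by repeatedly applying A^T A, using commons-math3 only; the class and helper names are made up:

import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.SingularValueDecomposition;

public class PowerIterationSketch {
    public static void main(String[] args) {
        RealMatrix A = new Array2DRowRealMatrix(new double[][] {{1, 2, 3}, {4, 5, 6}});

        // power iteration on A^T A converges to the leading right singular vector v
        double[] v = {Math.random(), Math.random(), Math.random()};
        RealMatrix AtA = A.transpose().multiply(A);
        for (int i = 0; i < 100; i++) {
            v = normalize(AtA.operate(v));
        }
        double[] Av = A.operate(v);
        double sigma = norm(Av); // leading singular value; u = Av / sigma is the left vector

        SingularValueDecomposition svd = new SingularValueDecomposition(A);
        System.out.printf("sigma: %.6f (svd: %.6f)%n", sigma, svd.getSingularValues()[0]);
    }

    private static double norm(double[] x) {
        double s = 0.d;
        for (double xi : x) {
            s += xi * xi;
        }
        return Math.sqrt(s);
    }

    private static double[] normalize(double[] x) {
        final double n = norm(x);
        final double[] y = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            y[i] = x[i] / n;
        }
        return y;
    }
}

The left singular vector follows as u = A v / sigma, which is what the u output argument of power1 is compared against in the test above.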
[48/50] [abbrv] incubator-hivemall git commit: Merge branch
'feature/systemtest' of https://github.com/amaya382/hivemall into
JIRA-22/pr-336
Posted by my...@apache.org.
Merge branch 'feature/systemtest' of https://github.com/amaya382/hivemall into JIRA-22/pr-336
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/075f9348
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/075f9348
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/075f9348
Branch: refs/heads/JIRA-22/pr-336
Commit: 075f93485e41ac8f26451f68b0f38737849b04a7
Parents: 72d6a62 798ec6a
Author: myui <yu...@gmail.com>
Authored: Fri Dec 2 16:55:34 2016 +0900
Committer: myui <yu...@gmail.com>
Committed: Fri Dec 2 16:55:34 2016 +0900
----------------------------------------------------------------------
pom.xml | 1 +
systemtest/README.md | 211 +++++++++++
systemtest/pom.xml | 105 ++++++
.../java/com/klarna/hiverunner/Extractor.java | 33 ++
.../hivemall/systemtest/MsgpackConverter.java | 114 ++++++
.../exception/QueryExecutionException.java | 27 ++
.../systemtest/model/CreateTableHQ.java | 49 +++
.../hivemall/systemtest/model/DropTableHQ.java | 27 ++
.../main/java/hivemall/systemtest/model/HQ.java | 161 ++++++++
.../java/hivemall/systemtest/model/HQBase.java | 22 ++
.../hivemall/systemtest/model/InsertHQ.java | 47 +++
.../java/hivemall/systemtest/model/RawHQ.java | 30 ++
.../java/hivemall/systemtest/model/TableHQ.java | 30 ++
.../hivemall/systemtest/model/TableListHQ.java | 23 ++
.../model/UploadFileAsNewTableHQ.java | 35 ++
.../hivemall/systemtest/model/UploadFileHQ.java | 57 +++
.../model/UploadFileToExistingHQ.java | 28 ++
.../model/lazy/LazyMatchingResource.java | 63 ++++
.../systemtest/runner/HiveSystemTestRunner.java | 142 ++++++++
.../systemtest/runner/SystemTestCommonInfo.java | 46 +++
.../systemtest/runner/SystemTestRunner.java | 337 +++++++++++++++++
.../systemtest/runner/SystemTestTeam.java | 183 ++++++++++
.../systemtest/runner/TDSystemTestRunner.java | 363 +++++++++++++++++++
.../main/java/hivemall/systemtest/utils/IO.java | 83 +++++
.../resources/hivemall/hiverunner.properties | 6 +
.../src/test/resources/hivemall/td.properties | 13 +
26 files changed, 2236 insertions(+)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/075f9348/pom.xml
----------------------------------------------------------------------
[29/50] [abbrv] incubator-hivemall git commit: mod SNR for corner
cases
Posted by my...@apache.org.
mod SNR for corner cases
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/4cfa4e5a
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/4cfa4e5a
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/4cfa4e5a
Branch: refs/heads/JIRA-22/pr-385
Commit: 4cfa4e5ac15a6535b187c23616c205696a1cd13b
Parents: 8e2842c
Author: amaya <gi...@sapphire.in.net>
Authored: Wed Sep 28 18:26:01 2016 +0900
Committer: amaya <gi...@sapphire.in.net>
Committed: Wed Sep 28 18:29:28 2016 +0900
----------------------------------------------------------------------
.../ftvec/selection/SignalNoiseRatioUDAF.java | 48 +++++--
.../selection/SignalNoiseRatioUDAFTest.java | 135 ++++++++++++++++++-
2 files changed, 167 insertions(+), 16 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/4cfa4e5a/core/src/main/java/hivemall/ftvec/selection/SignalNoiseRatioUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/ftvec/selection/SignalNoiseRatioUDAF.java b/core/src/main/java/hivemall/ftvec/selection/SignalNoiseRatioUDAF.java
index b7b9126..507aefa 100644
--- a/core/src/main/java/hivemall/ftvec/selection/SignalNoiseRatioUDAF.java
+++ b/core/src/main/java/hivemall/ftvec/selection/SignalNoiseRatioUDAF.java
@@ -21,7 +21,6 @@ package hivemall.ftvec.selection;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.hadoop.WritableUtils;
import hivemall.utils.lang.Preconditions;
-import org.apache.commons.math3.util.FastMath;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
@@ -193,7 +192,7 @@ public class SignalNoiseRatioUDAF extends AbstractGenericUDAFResolver {
int clazz = -1;
for (int i = 0; i < nClasses; i++) {
- int label = PrimitiveObjectInspectorUtils.getInt(labels.get(i), labelOI);
+ final int label = PrimitiveObjectInspectorUtils.getInt(labels.get(i), labelOI);
if (label == 1 && clazz == -1) {
clazz = i;
} else if (label == 1) {
@@ -255,6 +254,12 @@ public class SignalNoiseRatioUDAF extends AbstractGenericUDAFResolver {
for (int i = 0; i < nClasses; i++) {
final long n = myAgg.ns[i];
final long m = PrimitiveObjectInspectorUtils.getLong(ns.get(i), nOI);
+
+ // no need to merge class `i`
+ if (m == 0) {
+ continue;
+ }
+
final List means = meansOI.getList(meanss.get(i));
final List variances = variancesOI.getList(variancess.get(i));
@@ -266,10 +271,19 @@ public class SignalNoiseRatioUDAF extends AbstractGenericUDAFResolver {
final double varianceN = myAgg.variancess[i][j];
final double varianceM = PrimitiveObjectInspectorUtils.getDouble(
variances.get(j), varianceOI);
- myAgg.meanss[i][j] = (n * meanN + m * meanM) / (double) (n + m);
- myAgg.variancess[i][j] = (varianceN * (n - 1) + varianceM * (m - 1) + FastMath.pow(
- meanN - meanM, 2) * n * m / (n + m))
- / (n + m - 1);
+
+ if (n == 0) {
+ // only assign `other` into `myAgg`
+ myAgg.meanss[i][j] = meanM;
+ myAgg.variancess[i][j] = varianceM;
+ } else {
+ // merge by Chan's method
+ // http://i.stanford.edu/pub/cstr/reports/cs/tr/79/773/CS-TR-79-773.pdf
+ myAgg.meanss[i][j] = (n * meanN + m * meanM) / (double) (n + m);
+ myAgg.variancess[i][j] = (varianceN * (n - 1) + varianceM * (m - 1) + Math.pow(
+ meanN - meanM, 2) * n * m / (n + m))
+ / (n + m - 1);
+ }
}
}
}
@@ -302,25 +316,33 @@ public class SignalNoiseRatioUDAF extends AbstractGenericUDAFResolver {
// calc SNR between classes for each feature
final double[] result = new double[nFeatures];
- final double[] sds = new double[nClasses]; // memo
+ final double[] sds = new double[nClasses]; // for memoization
for (int i = 0; i < nFeatures; i++) {
- sds[0] = FastMath.sqrt(myAgg.variancess[0][i]);
+ sds[0] = Math.sqrt(myAgg.variancess[0][i]);
for (int j = 1; j < nClasses; j++) {
- sds[j] = FastMath.sqrt(myAgg.variancess[j][i]);
- if (Double.isNaN(sds[j])) {
+ sds[j] = Math.sqrt(myAgg.variancess[j][i]);
+ // `ns[j] == 0` means no feature entry belongs to class `j`, skip
+ if (myAgg.ns[j] == 0) {
continue;
}
for (int k = 0; k < j; k++) {
- if (Double.isNaN(sds[k])) {
+ // avoid comparing classes that have only a single entry
+ if (myAgg.ns[k] == 0 || (myAgg.ns[j] == 1 && myAgg.ns[k] == 1)) {
continue;
}
- result[i] += FastMath.abs(myAgg.meanss[j][i] - myAgg.meanss[k][i])
+
+ // SUM(snr) GROUP BY feature
+ final double snr = Math.abs(myAgg.meanss[j][i] - myAgg.meanss[k][i])
/ (sds[j] + sds[k]);
+ // if `NaN` (i.e., the diff between the means and both sds are zero, in other words all related values are equal),
+ // regard feature `i` as meaningless between classes `j` and `k` and skip it
+ if (!Double.isNaN(snr)) {
+ result[i] += snr; // accept `Infinity`
+ }
}
}
}
- // SUM(snr) GROUP BY feature
return WritableUtils.toWritableList(result);
}
}
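The merge path above combines per-class running statistics from two partial aggregates with the pairwise update of Chan et al. (the Stanford TR linked in the comment). A standalone sketch of that update for a single column follows, with illustrative names and data, checked against a direct computation over the concatenated values:

public class ChanMergeSketch {

    /** Merge (count, mean, sample variance) of two partitions with Chan et al.'s update. */
    static double[] merge(long n, double meanN, double varN, long m, double meanM, double varM) {
        final double mean = (n * meanN + m * meanM) / (double) (n + m);
        final double var = (varN * (n - 1) + varM * (m - 1)
                + Math.pow(meanN - meanM, 2) * n * m / (n + m)) / (n + m - 1);
        return new double[] {n + m, mean, var};
    }

    public static void main(String[] args) {
        final double[] a = {2.d, 4.d, 4.d};       // partition 1
        final double[] b = {4.d, 5.d, 5.d, 7.d};  // partition 2

        double[] merged = merge(a.length, mean(a), var(a), b.length, mean(b), var(b));

        final double[] all = {2.d, 4.d, 4.d, 4.d, 5.d, 5.d, 7.d};
        System.out.printf("merged: mean=%.6f var=%.6f%n", merged[1], merged[2]);
        System.out.printf("direct: mean=%.6f var=%.6f%n", mean(all), var(all));
    }

    static double mean(double[] x) {
        double s = 0.d;
        for (double xi : x) {
            s += xi;
        }
        return s / x.length;
    }

    /** Unbiased sample variance (divides by n-1), matching the UDAF's convention. */
    static double var(double[] x) {
        final double m = mean(x);
        double s = 0.d;
        for (double xi : x) {
            s += (xi - m) * (xi - m);
        }
        return s / (x.length - 1);
    }
}

The m == 0 and n == 0 guards added in the diff avoid degenerate cases (e.g. n + m - 1 == 0 when merging an empty partition with a single entry) by skipping or copying the non-empty side through instead of applying the formula.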
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/4cfa4e5a/core/src/test/java/hivemall/ftvec/selection/SignalNoiseRatioUDAFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/ftvec/selection/SignalNoiseRatioUDAFTest.java b/core/src/test/java/hivemall/ftvec/selection/SignalNoiseRatioUDAFTest.java
index 56a01d0..a4744d9 100644
--- a/core/src/test/java/hivemall/ftvec/selection/SignalNoiseRatioUDAFTest.java
+++ b/core/src/test/java/hivemall/ftvec/selection/SignalNoiseRatioUDAFTest.java
@@ -68,7 +68,7 @@ public class SignalNoiseRatioUDAFTest {
}
@SuppressWarnings("unchecked")
- final List<DoubleWritable> resultObj = (ArrayList<DoubleWritable>) evaluator.terminate(agg);
+ final List<DoubleWritable> resultObj = (List<DoubleWritable>) evaluator.terminate(agg);
final int size = resultObj.size();
final double[] result = new double[size];
for (int i = 0; i < size; i++) {
@@ -82,7 +82,7 @@ public class SignalNoiseRatioUDAFTest {
}
@Test
- public void snrMultipleClass() throws Exception {
+ public void snrMultipleClassNormalCase() throws Exception {
// this test is based on *subset* of iris data set
final SignalNoiseRatioUDAF snr = new SignalNoiseRatioUDAF();
final ObjectInspector[] OIs = new ObjectInspector[] {
@@ -111,7 +111,7 @@ public class SignalNoiseRatioUDAFTest {
}
@SuppressWarnings("unchecked")
- final List<DoubleWritable> resultObj = (ArrayList<DoubleWritable>) evaluator.terminate(agg);
+ final List<DoubleWritable> resultObj = (List<DoubleWritable>) evaluator.terminate(agg);
final int size = resultObj.size();
final double[] result = new double[size];
for (int i = 0; i < size; i++) {
@@ -125,6 +125,135 @@ public class SignalNoiseRatioUDAFTest {
}
@Test
+ public void snrMultipleClassCornerCase0() throws Exception {
+ final SignalNoiseRatioUDAF snr = new SignalNoiseRatioUDAF();
+ final ObjectInspector[] OIs = new ObjectInspector[] {
+ ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector),
+ ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableIntObjectInspector)};
+ final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator evaluator = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator) snr.getEvaluator(new SimpleGenericUDAFParameterInfo(
+ OIs, false, false));
+ evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, OIs);
+ final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer agg = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer) evaluator.getNewAggregationBuffer();
+ evaluator.reset(agg);
+
+ // all c0[0] and c1[0] are equal
+ // all c1[1] and c2[1] are equal
+ // all c*[2] are equal
+ // all c*[3] are different
+ final double[][] features = new double[][] { {3.5, 1.4, 0.3, 5.1}, {3.5, 1.5, 0.3, 5.2},
+ {3.5, 4.5, 0.3, 7.d}, {3.5, 4.5, 0.3, 6.4}, {3.3, 4.5, 0.3, 6.3}};
+
+ final int[][] labels = new int[][] { {1, 0, 0}, {1, 0, 0}, // class `0`
+ {0, 1, 0}, {0, 1, 0}, // class `1`
+ {0, 0, 1}}; // class `2`, only single entry
+
+ for (int i = 0; i < features.length; i++) {
+ final List<IntWritable> labelList = new ArrayList<IntWritable>();
+ for (int label : labels[i]) {
+ labelList.add(new IntWritable(label));
+ }
+ evaluator.iterate(agg, new Object[] {WritableUtils.toWritableList(features[i]),
+ labelList});
+ }
+
+ @SuppressWarnings("unchecked")
+ final List<DoubleWritable> resultObj = (List<DoubleWritable>) evaluator.terminate(agg);
+ final int size = resultObj.size();
+ final double[] result = new double[size];
+ for (int i = 0; i < size; i++) {
+ result[i] = resultObj.get(i).get();
+ }
+
+ final double[] answer = new double[] {Double.POSITIVE_INFINITY, 121.99999999999989, 0.d,
+ 28.761904761904734};
+
+ Assert.assertArrayEquals(answer, result, 1e-5);
+ }
+
+ @Test
+ public void snrMultipleClassCornerCase1() throws Exception {
+ final SignalNoiseRatioUDAF snr = new SignalNoiseRatioUDAF();
+ final ObjectInspector[] OIs = new ObjectInspector[] {
+ ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector),
+ ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableIntObjectInspector)};
+ final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator evaluator = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator) snr.getEvaluator(new SimpleGenericUDAFParameterInfo(
+ OIs, false, false));
+ evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, OIs);
+ final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer agg = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer) evaluator.getNewAggregationBuffer();
+ evaluator.reset(agg);
+
+ final double[][] features = new double[][] { {5.1, 3.5, 1.4, 0.2}, {4.9, 3.d, 1.4, 0.2},
+ {7.d, 3.2, 4.7, 1.4}, {6.3, 3.3, 6.d, 2.5}, {6.4, 3.2, 4.5, 1.5}};
+
+ // has multiple single entries
+ final int[][] labels = new int[][] { {1, 0, 0}, {1, 0, 0}, {1, 0, 0}, // class `0`
+ {0, 1, 0}, // class `1`, only single entry
+ {0, 0, 1}}; // class `2`, only single entry
+
+ for (int i = 0; i < features.length; i++) {
+ final List<IntWritable> labelList = new ArrayList<IntWritable>();
+ for (int label : labels[i]) {
+ labelList.add(new IntWritable(label));
+ }
+ evaluator.iterate(agg, new Object[] {WritableUtils.toWritableList(features[i]),
+ labelList});
+ }
+
+ @SuppressWarnings("unchecked")
+ final List<DoubleWritable> resultObj = (List<DoubleWritable>) evaluator.terminate(agg);
+ final List<Double> result = new ArrayList<Double>();
+ for (DoubleWritable dw : resultObj) {
+ result.add(dw.get());
+ }
+
+ Assert.assertFalse(result.contains(Double.POSITIVE_INFINITY));
+ }
+
+ @Test
+ public void snrMultipleClassCornerCase2() throws Exception {
+ final SignalNoiseRatioUDAF snr = new SignalNoiseRatioUDAF();
+ final ObjectInspector[] OIs = new ObjectInspector[] {
+ ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector),
+ ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableIntObjectInspector)};
+ final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator evaluator = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator) snr.getEvaluator(new SimpleGenericUDAFParameterInfo(
+ OIs, false, false));
+ evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, OIs);
+ final SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer agg = (SignalNoiseRatioUDAF.SignalNoiseRatioUDAFEvaluator.SignalNoiseRatioAggregationBuffer) evaluator.getNewAggregationBuffer();
+ evaluator.reset(agg);
+
+ // all [0] are equal
+ // all [1] are equal *each class*
+ final double[][] features = new double[][] { {1.d, 1.d, 1.4, 0.2}, {1.d, 1.d, 1.4, 0.2},
+ {1.d, 2.d, 4.7, 1.4}, {1.d, 2.d, 4.5, 1.5}, {1.d, 3.d, 6.d, 2.5},
+ {1.d, 3.d, 5.1, 1.9}};
+
+ final int[][] labels = new int[][] { {1, 0, 0}, {1, 0, 0}, {0, 1, 0}, {0, 1, 0}, {0, 0, 1},
+ {0, 0, 1}};
+
+ for (int i = 0; i < features.length; i++) {
+ final List<IntWritable> labelList = new ArrayList<IntWritable>();
+ for (int label : labels[i]) {
+ labelList.add(new IntWritable(label));
+ }
+ evaluator.iterate(agg, new Object[] {WritableUtils.toWritableList(features[i]),
+ labelList});
+ }
+
+ @SuppressWarnings("unchecked")
+ final List<DoubleWritable> resultObj = (List<DoubleWritable>) evaluator.terminate(agg);
+ final int size = resultObj.size();
+ final double[] result = new double[size];
+ for (int i = 0; i < size; i++) {
+ result[i] = resultObj.get(i).get();
+ }
+
+ final double[] answer = new double[] {0.d, Double.POSITIVE_INFINITY, 42.94949495,
+ 33.80952381};
+
+ Assert.assertArrayEquals(answer, result, 1e-5);
+ }
+
+ @Test
public void shouldFail0() throws Exception {
expectedException.expect(UDFArgumentException.class);
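For each feature, terminate() in the UDAF above sums |mean_j - mean_k| / (sd_j + sd_k) over every pair of classes (j, k), skipping empty classes, singleton-vs-singleton pairs, and NaN contributions while keeping Infinity. A simplified re-implementation of just that scoring step follows, for illustration only; the class, method, and toy data are made up and do not use the UDAF's aggregation buffer:

public class SnrSketch {

    /** Per-feature sum of |mean_j - mean_k| / (sd_j + sd_k) over class pairs (j, k). */
    static double[] snr(long[] ns, double[][] means, double[][] vars) {
        final int nClasses = ns.length;
        final int nFeatures = means[0].length;
        final double[] result = new double[nFeatures];
        for (int i = 0; i < nFeatures; i++) {
            for (int j = 1; j < nClasses; j++) {
                if (ns[j] == 0) {
                    continue; // no entry belongs to class j
                }
                for (int k = 0; k < j; k++) {
                    if (ns[k] == 0 || (ns[j] == 1 && ns[k] == 1)) {
                        continue; // skip empty classes and singleton-vs-singleton pairs
                    }
                    final double snr = Math.abs(means[j][i] - means[k][i])
                            / (Math.sqrt(vars[j][i]) + Math.sqrt(vars[k][i]));
                    if (!Double.isNaN(snr)) {
                        result[i] += snr; // Infinity is kept; NaN (all values equal) is skipped
                    }
                }
            }
        }
        return result;
    }

    public static void main(String[] args) {
        final long[] ns = {2, 2, 1};
        final double[][] means = {{1.d, 10.d}, {1.d, 20.d}, {5.d, 30.d}};
        final double[][] vars = {{0.d, 2.d}, {0.d, 2.d}, {0.d, 0.d}};
        for (double score : snr(ns, means, vars)) {
            System.out.println(score);
        }
    }
}

On this toy data the first feature scores Infinity because two classes with zero variance still have different means, which is exactly the behaviour the corner-case tests above pin down.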
[18/50] [abbrv] incubator-hivemall git commit: fix chi2
Posted by my...@apache.org.
fix chi2
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/b8cf3968
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/b8cf3968
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/b8cf3968
Branch: refs/heads/JIRA-22/pr-385
Commit: b8cf39684496f2511e59294041d443b9438394a9
Parents: abbf549
Author: amaya <gi...@sapphire.in.net>
Authored: Wed Sep 21 15:02:12 2016 +0900
Committer: amaya <gi...@sapphire.in.net>
Committed: Wed Sep 21 16:23:42 2016 +0900
----------------------------------------------------------------------
core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/b8cf3968/core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java b/core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java
index 951aeeb..70f0316 100644
--- a/core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java
+++ b/core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java
@@ -102,7 +102,7 @@ public class ChiSquareUDF extends GenericUDF {
// explode and transpose matrix
for (int i = 0; i < nClasses; i++) {
final Object observedObjRow = observedObj.get(i);
- final Object expectedObjRow = observedObj.get(i);
+ final Object expectedObjRow = expectedObj.get(i);
Preconditions.checkNotNull(observedObjRow);
Preconditions.checkNotNull(expectedObjRow);
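The one-line fix above makes ChiSquareUDF read the expected counts from expectedObj instead of reusing the observed row. For reference, a hedged sketch of the statistic this revolves around, chi2 = sum((observed - expected)^2 / expected) with a p-value from the chi-squared distribution, using commons-math3; the counts and the class name are illustrative and not taken from the UDF:

import org.apache.commons.math3.distribution.ChiSquaredDistribution;

public class ChiSquareSketch {
    public static void main(String[] args) {
        // one feature: observed vs. expected counts over three classes
        final double[] observed = {16.d, 18.d, 16.d};
        final double[] expected = {16.67d, 16.67d, 16.67d};

        double chi2 = 0.d;
        for (int i = 0; i < observed.length; i++) {
            final double diff = observed[i] - expected[i];
            chi2 += diff * diff / expected[i]; // sum of (o - e)^2 / e
        }

        // p-value with (#classes - 1) degrees of freedom
        ChiSquaredDistribution dist = new ChiSquaredDistribution(observed.length - 1);
        double pValue = 1.d - dist.cumulativeProbability(chi2);

        System.out.printf("chi2=%.4f p=%.4f%n", chi2, pValue);
    }
}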
[43/50] [abbrv] incubator-hivemall git commit: Merge branch
'feature/feature_selection' of https://github.com/amaya382/hivemall into
feature_selection
Posted by my...@apache.org.
Merge branch 'feature/feature_selection' of
https://github.com/amaya382/hivemall into feature_selection
# Conflicts:
# core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
# core/src/main/java/hivemall/utils/math/StatsUtils.java
# spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/GroupedDataEx.scala
# spark/spark-1.6/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
# spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallGroupedDataset.scala
# spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/67ba9631
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/67ba9631
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/67ba9631
Branch: refs/heads/JIRA-22/pr-385
Commit: 67ba9631af3c231b7abd145134d17237b6aca0a5
Parents: 69496fa ce4a489
Author: myui <yu...@gmail.com>
Authored: Mon Nov 21 18:19:45 2016 +0900
Committer: myui <yu...@gmail.com>
Committed: Mon Nov 21 18:19:45 2016 +0900
----------------------------------------------------------------------
.../hivemall/ftvec/selection/ChiSquareUDF.java | 155 ++++++++
.../ftvec/selection/SignalNoiseRatioUDAF.java | 349 +++++++++++++++++++
.../hivemall/tools/array/SelectKBestUDF.java | 143 ++++++++
.../tools/matrix/TransposeAndDotUDAF.java | 213 +++++++++++
.../java/hivemall/utils/hadoop/HiveUtils.java | 22 +-
.../java/hivemall/utils/math/StatsUtils.java | 91 +++++
.../ftvec/selection/ChiSquareUDFTest.java | 80 +++++
.../selection/SignalNoiseRatioUDAFTest.java | 348 ++++++++++++++++++
.../tools/array/SelectKBeatUDFTest.java | 65 ++++
.../tools/matrix/TransposeAndDotUDAFTest.java | 58 +++
resources/ddl/define-all-as-permanent.hive | 20 ++
resources/ddl/define-all.hive | 20 ++
resources/ddl/define-all.spark | 20 ++
resources/ddl/define-udfs.td.hql | 4 +
.../apache/spark/sql/hive/GroupedDataEx.scala | 21 ++
.../org/apache/spark/sql/hive/HivemallOps.scala | 18 +
.../spark/sql/hive/HivemallOpsSuite.scala | 100 ++++++
.../spark/sql/hive/HivemallGroupedDataset.scala | 25 ++
.../org/apache/spark/sql/hive/HivemallOps.scala | 20 ++
.../spark/sql/hive/HivemallOpsSuite.scala | 103 ++++++
20 files changed, 1873 insertions(+), 2 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/67ba9631/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
----------------------------------------------------------------------
diff --cc core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
index d8b1aef,c752188..8188b7a
--- a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
+++ b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
@@@ -242,10 -240,16 +242,20 @@@ public final class HiveUtils
return category == Category.LIST;
}
+ public static boolean isMapOI(@Nonnull final ObjectInspector oi) {
+ return oi.getCategory() == Category.MAP;
+ }
+
+ public static boolean isNumberListOI(@Nonnull final ObjectInspector oi) {
+ return isListOI(oi)
+ && isNumberOI(((ListObjectInspector) oi).getListElementObjectInspector());
+ }
+
+ public static boolean isNumberListListOI(@Nonnull final ObjectInspector oi) {
+ return isListOI(oi)
+ && isNumberListOI(((ListObjectInspector) oi).getListElementObjectInspector());
+ }
+
public static boolean isPrimitiveTypeInfo(@Nonnull TypeInfo typeInfo) {
return typeInfo.getCategory() == ObjectInspector.Category.PRIMITIVE;
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/67ba9631/core/src/main/java/hivemall/utils/math/StatsUtils.java
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/67ba9631/spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/GroupedDataEx.scala
----------------------------------------------------------------------
diff --cc spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/GroupedDataEx.scala
index fd4da64,2482c62..8f78a7f
--- a/spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/GroupedDataEx.scala
+++ b/spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/GroupedDataEx.scala
@@@ -267,13 -266,25 +267,34 @@@ final class GroupedDataEx protected[sql
}
/**
+ * @see hivemall.ftvec.trans.OnehotEncodingUDAF
+ */
+ def onehot_encoding(features: String*): DataFrame = {
+ val udaf = HiveUDAFFunction(
+ new HiveFunctionWrapper("hivemall.ftvec.trans.OnehotEncodingUDAF"),
+ features.map(df.col(_).expr),
+ isUDAFBridgeRequired = false)
++
++ /**
+ * @see hivemall.ftvec.selection.SignalNoiseRatioUDAF
+ */
+ def snr(X: String, Y: String): DataFrame = {
+ val udaf = HiveUDAFFunction(
+ new HiveFunctionWrapper("hivemall.ftvec.selection.SignalNoiseRatioUDAF"),
+ Seq(X, Y).map(df.col(_).expr),
+ isUDAFBridgeRequired = false)
+ .toAggregateExpression()
+ toDF(Seq(Alias(udaf, udaf.prettyString)()))
+ }
+
+ /**
+ * @see hivemall.tools.matrix.TransposeAndDotUDAF
+ */
+ def transpose_and_dot(X: String, Y: String): DataFrame = {
+ val udaf = HiveUDAFFunction(
+ new HiveFunctionWrapper("hivemall.tools.matrix.TransposeAndDotUDAF"),
+ Seq(X, Y).map(df.col(_).expr),
+ isUDAFBridgeRequired = false)
.toAggregateExpression()
toDF(Seq(Alias(udaf, udaf.prettyString)()))
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/67ba9631/spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/67ba9631/spark/spark-1.6/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
----------------------------------------------------------------------
diff --cc spark/spark-1.6/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
index 901056d,c7016c0..c231105
--- a/spark/spark-1.6/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
+++ b/spark/spark-1.6/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
@@@ -534,30 -570,63 +575,89 @@@ final class HivemallOpsSuite extends Hi
assert(row4(0).getDouble(1) ~== 0.25)
}
+ test("user-defined aggregators for ftvec.trans") {
+ import hiveContext.implicits._
+
+ val df0 = Seq((1, "cat", "mammal", 9), (1, "dog", "mammal", 10), (1, "human", "mammal", 10),
+ (1, "seahawk", "bird", 101), (1, "wasp", "insect", 3), (1, "wasp", "insect", 9),
+ (1, "cat", "mammal", 101), (1, "dog", "mammal", 1), (1, "human", "mammal", 9))
+ .toDF("col0", "cat1", "cat2", "cat3")
+
+ val row00 = df0.groupby($"col0").onehot_encoding("cat1")
+ val row01 = df0.groupby($"col0").onehot_encoding("cat1", "cat2", "cat3")
+
+ val result000 = row00.collect()(0).getAs[Row](1).getAs[Map[String, Int]](0)
+ val result01 = row01.collect()(0).getAs[Row](1)
+ val result010 = result01.getAs[Map[String, Int]](0)
+ val result011 = result01.getAs[Map[String, Int]](1)
+ val result012 = result01.getAs[Map[String, Int]](2)
+
+ assert(result000.keySet === Set("seahawk", "cat", "human", "wasp", "dog"))
+ assert(result000.values.toSet === Set(1, 2, 3, 4, 5))
+ assert(result010.keySet === Set("seahawk", "cat", "human", "wasp", "dog"))
+ assert(result010.values.toSet === Set(1, 2, 3, 4, 5))
+ assert(result011.keySet === Set("bird", "insect", "mammal"))
+ assert(result011.values.toSet === Set(6, 7, 8))
+ assert(result012.keySet === Set(1, 3, 9, 10, 101))
+ assert(result012.values.toSet === Set(9, 10, 11, 12, 13))
++
+ test("user-defined aggregators for ftvec.selection") {
+ import hiveContext.implicits._
+
+ // see also hivemall.ftvec.selection.SignalNoiseRatioUDAFTest
+ // binary class
+ // +-----------------+-------+
+ // | features | class |
+ // +-----------------+-------+
+ // | 5.1,3.5,1.4,0.2 | 0 |
+ // | 4.9,3.0,1.4,0.2 | 0 |
+ // | 4.7,3.2,1.3,0.2 | 0 |
+ // | 7.0,3.2,4.7,1.4 | 1 |
+ // | 6.4,3.2,4.5,1.5 | 1 |
+ // | 6.9,3.1,4.9,1.5 | 1 |
+ // +-----------------+-------+
+ val df0 = Seq(
+ (1, Seq(5.1, 3.5, 1.4, 0.2), Seq(1, 0)), (1, Seq(4.9, 3.0, 1.4, 0.2), Seq(1, 0)),
+ (1, Seq(4.7, 3.2, 1.3, 0.2), Seq(1, 0)), (1, Seq(7.0, 3.2, 4.7, 1.4), Seq(0, 1)),
+ (1, Seq(6.4, 3.2, 4.5, 1.5), Seq(0, 1)), (1, Seq(6.9, 3.1, 4.9, 1.5), Seq(0, 1)))
+ .toDF("c0", "arg0", "arg1")
+ val row0 = df0.groupby($"c0").snr("arg0", "arg1").collect
+ (row0(0).getAs[Seq[Double]](1), Seq(4.38425236, 0.26390002, 15.83984511, 26.87005769))
+ .zipped
+ .foreach((actual, expected) => assert(actual ~== expected))
+
+ // multiple class
+ // +-----------------+-------+
+ // | features | class |
+ // +-----------------+-------+
+ // | 5.1,3.5,1.4,0.2 | 0 |
+ // | 4.9,3.0,1.4,0.2 | 0 |
+ // | 7.0,3.2,4.7,1.4 | 1 |
+ // | 6.4,3.2,4.5,1.5 | 1 |
+ // | 6.3,3.3,6.0,2.5 | 2 |
+ // | 5.8,2.7,5.1,1.9 | 2 |
+ // +-----------------+-------+
+ val df1 = Seq(
+ (1, Seq(5.1, 3.5, 1.4, 0.2), Seq(1, 0, 0)), (1, Seq(4.9, 3.0, 1.4, 0.2), Seq(1, 0, 0)),
+ (1, Seq(7.0, 3.2, 4.7, 1.4), Seq(0, 1, 0)), (1, Seq(6.4, 3.2, 4.5, 1.5), Seq(0, 1, 0)),
+ (1, Seq(6.3, 3.3, 6.0, 2.5), Seq(0, 0, 1)), (1, Seq(5.8, 2.7, 5.1, 1.9), Seq(0, 0, 1)))
+ .toDF("c0", "arg0", "arg1")
+ val row1 = df1.groupby($"c0").snr("arg0", "arg1").collect
+ (row1(0).getAs[Seq[Double]](1), Seq(8.43181818, 1.32121212, 42.94949495, 33.80952381))
+ .zipped
+ .foreach((actual, expected) => assert(actual ~== expected))
+ }
+
+ test("user-defined aggregators for tools.matrix") {
+ import hiveContext.implicits._
+
+ // | 1 2 3 |T | 5 6 7 |
+ // | 3 4 5 | * | 7 8 9 |
+ val df0 = Seq((1, Seq(1, 2, 3), Seq(5, 6, 7)), (1, Seq(3, 4, 5), Seq(7, 8, 9)))
+ .toDF("c0", "arg0", "arg1")
+
+ // if checkAnswer is used here, it fails for some reason (maybe a type issue?); it works on spark-2.0
+ assert(df0.groupby($"c0").transpose_and_dot("arg0", "arg1").collect() ===
+ Seq(Row(1, Seq(Seq(26.0, 30.0, 34.0), Seq(38.0, 44.0, 50.0), Seq(50.0, 58.0, 66.0)))))
}
}
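transpose_and_dot aggregates X^T * Y over the grouped rows, i.e. the sum of outer products of each paired row of X and Y, which is what the "T *" comment in the test sketches. A plain-Java illustration that reproduces the expected matrix from the test above (the class name is made up):

public class TransposeAndDotSketch {
    public static void main(String[] args) {
        // same rows as the test: X = {(1,2,3), (3,4,5)}, Y = {(5,6,7), (7,8,9)}
        final double[][] X = {{1, 2, 3}, {3, 4, 5}};
        final double[][] Y = {{5, 6, 7}, {7, 8, 9}};

        // X^T * Y as the sum over rows of the outer product of x_r and y_r
        final double[][] result = new double[X[0].length][Y[0].length];
        for (int r = 0; r < X.length; r++) {
            for (int i = 0; i < X[r].length; i++) {
                for (int j = 0; j < Y[r].length; j++) {
                    result[i][j] += X[r][i] * Y[r][j];
                }
            }
        }

        for (double[] row : result) {
            System.out.println(java.util.Arrays.toString(row));
        }
        // prints [26.0, 30.0, 34.0], [38.0, 44.0, 50.0], [50.0, 58.0, 66.0]
    }
}

This is the same quantity the spark-2.0 suite later verifies with checkAnswer.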
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/67ba9631/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallGroupedDataset.scala
----------------------------------------------------------------------
diff --cc spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallGroupedDataset.scala
index 8ac7185,0000000..73757f6
mode 100644,000000..100644
--- a/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallGroupedDataset.scala
+++ b/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallGroupedDataset.scala
@@@ -1,277 -1,0 +1,302 @@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.hive
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.Dataset
+import org.apache.spark.sql.RelationalGroupedDataset
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.Aggregate
+import org.apache.spark.sql.catalyst.plans.logical.Pivot
+import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper
+import org.apache.spark.sql.types._
+
+/**
+ * Groups the [[DataFrame]] using the specified columns, so we can run aggregation on them.
+ *
+ * @groupname ensemble
+ * @groupname ftvec.trans
+ * @groupname evaluation
+ */
+final class HivemallGroupedDataset(groupBy: RelationalGroupedDataset) {
+
+ /**
+ * @see hivemall.ensemble.bagging.VotedAvgUDAF
+ * @group ensemble
+ */
+ def voted_avg(weight: String): DataFrame = {
+ // checkType(weight, NumericType)
+ val udaf = HiveUDAFFunction(
+ "voted_avg",
+ new HiveFunctionWrapper("hivemall.ensemble.bagging.WeightVotedAvgUDAF"),
+ Seq(weight).map(df.col(_).expr),
+ isUDAFBridgeRequired = true)
+ .toAggregateExpression()
+ toDF((Alias(udaf, udaf.prettyName)() :: Nil).toSeq)
+ }
+
+ /**
+ * @see hivemall.ensemble.bagging.WeightVotedAvgUDAF
+ * @group ensemble
+ */
+ def weight_voted_avg(weight: String): DataFrame = {
+ // checkType(weight, NumericType)
+ val udaf = HiveUDAFFunction(
+ "weight_voted_avg",
+ new HiveFunctionWrapper("hivemall.ensemble.bagging.WeightVotedAvgUDAF"),
+ Seq(weight).map(df.col(_).expr),
+ isUDAFBridgeRequired = true)
+ .toAggregateExpression()
+ toDF((Alias(udaf, udaf.prettyName)() :: Nil).toSeq)
+ }
+
+ /**
+ * @see hivemall.ensemble.ArgminKLDistanceUDAF
+ * @group ensemble
+ */
+ def argmin_kld(weight: String, conv: String): DataFrame = {
+ // checkType(weight, NumericType)
+ // checkType(conv, NumericType)
+ val udaf = HiveUDAFFunction(
+ "argmin_kld",
+ new HiveFunctionWrapper("hivemall.ensemble.ArgminKLDistanceUDAF"),
+ Seq(weight, conv).map(df.col(_).expr),
+ isUDAFBridgeRequired = true)
+ .toAggregateExpression()
+ toDF((Alias(udaf, udaf.prettyName)() :: Nil).toSeq)
+ }
+
+ /**
+ * @see hivemall.ensemble.MaxValueLabelUDAF
+ * @group ensemble
+ */
+ def max_label(score: String, label: String): DataFrame = {
+ // checkType(score, NumericType)
+ checkType(label, StringType)
+ val udaf = HiveUDAFFunction(
+ "max_label",
+ new HiveFunctionWrapper("hivemall.ensemble.MaxValueLabelUDAF"),
+ Seq(score, label).map(df.col(_).expr),
+ isUDAFBridgeRequired = true)
+ .toAggregateExpression()
+ toDF((Alias(udaf, udaf.prettyName)() :: Nil).toSeq)
+ }
+
+ /**
+ * @see hivemall.ensemble.MaxRowUDAF
+ * @group ensemble
+ */
+ def maxrow(score: String, label: String): DataFrame = {
+ // checkType(score, NumericType)
+ checkType(label, StringType)
+ val udaf = HiveUDAFFunction(
+ "maxrow",
+ new HiveFunctionWrapper("hivemall.ensemble.MaxRowUDAF"),
+ Seq(score, label).map(df.col(_).expr),
+ isUDAFBridgeRequired = false)
+ .toAggregateExpression()
+ toDF((Alias(udaf, udaf.prettyName)() :: Nil).toSeq)
+ }
+
+ /**
+ * @see hivemall.smile.tools.RandomForestEnsembleUDAF
+ * @group ensemble
+ */
+ def rf_ensemble(predict: String): DataFrame = {
+ // checkType(predict, NumericType)
+ val udaf = HiveUDAFFunction(
+ "rf_ensemble",
+ new HiveFunctionWrapper("hivemall.smile.tools.RandomForestEnsembleUDAF"),
+ Seq(predict).map(df.col(_).expr),
+ isUDAFBridgeRequired = true)
+ .toAggregateExpression()
+ toDF((Alias(udaf, udaf.prettyName)() :: Nil).toSeq)
+ }
+
+ /**
+ * @see hivemall.ftvec.trans.OnehotEncodingUDAF
+ * @group ftvec.trans
+ */
+ def onehot_encoding(cols: String*): DataFrame = {
+ val udaf = HiveUDAFFunction(
+ "onehot_encoding",
+ new HiveFunctionWrapper("hivemall.ftvec.trans.OnehotEncodingUDAF"),
+ cols.map(df.col(_).expr),
+ isUDAFBridgeRequired = false)
+ .toAggregateExpression()
+ toDF(Seq(Alias(udaf, udaf.prettyName)()))
+ }
+
+ /**
+ * @see hivemall.evaluation.MeanAbsoluteErrorUDAF
+ * @group evaluation
+ */
+ def mae(predict: String, target: String): DataFrame = {
+ checkType(predict, FloatType)
+ checkType(target, FloatType)
+ val udaf = HiveUDAFFunction(
+ "mae",
+ new HiveFunctionWrapper("hivemall.evaluation.MeanAbsoluteErrorUDAF"),
+ Seq(predict, target).map(df.col(_).expr),
+ isUDAFBridgeRequired = true)
+ .toAggregateExpression()
+ toDF((Alias(udaf, udaf.prettyName)() :: Nil).toSeq)
+ }
+
+ /**
+ * @see hivemall.evaluation.MeanSquareErrorUDAF
+ * @group evaluation
+ */
+ def mse(predict: String, target: String): DataFrame = {
+ checkType(predict, FloatType)
+ checkType(target, FloatType)
+ val udaf = HiveUDAFFunction(
+ "mse",
+ new HiveFunctionWrapper("hivemall.evaluation.MeanSquaredErrorUDAF"),
+ Seq(predict, target).map(df.col(_).expr),
+ isUDAFBridgeRequired = true)
+ .toAggregateExpression()
+ toDF((Alias(udaf, udaf.prettyName)() :: Nil).toSeq)
+ }
+
+ /**
+ * @see hivemall.evaluation.RootMeanSquareErrorUDAF
+ * @group evaluation
+ */
+ def rmse(predict: String, target: String): DataFrame = {
+ checkType(predict, FloatType)
+ checkType(target, FloatType)
+ val udaf = HiveUDAFFunction(
+ "rmse",
+ new HiveFunctionWrapper("hivemall.evaluation.RootMeanSquaredErrorUDAF"),
+ Seq(predict, target).map(df.col(_).expr),
+ isUDAFBridgeRequired = true)
+ .toAggregateExpression()
+ toDF((Alias(udaf, udaf.prettyName)() :: Nil).toSeq)
+ }
+
+ /**
+ * @see hivemall.evaluation.FMeasureUDAF
+ * @group evaluation
+ */
+ def f1score(predict: String, target: String): DataFrame = {
+ // checkType(target, ArrayType(IntegerType))
+ // checkType(predict, ArrayType(IntegerType))
+ val udaf = HiveUDAFFunction(
+ "f1score",
+ new HiveFunctionWrapper("hivemall.evaluation.FMeasureUDAF"),
+ Seq(predict, target).map(df.col(_).expr),
+ isUDAFBridgeRequired = true)
+ .toAggregateExpression()
+ toDF((Alias(udaf, udaf.prettyName)() :: Nil).toSeq)
+ }
+
+ /**
+ * [[RelationalGroupedDataset]] has the three values as private fields, so, to inject Hivemall
+ * aggregate functions, we fetch them via Java Reflections.
+ */
+ private val df = getPrivateField[DataFrame]("org$apache$spark$sql$RelationalGroupedDataset$$df")
+ private val groupingExprs = getPrivateField[Seq[Expression]]("groupingExprs")
+ private val groupType = getPrivateField[RelationalGroupedDataset.GroupType]("groupType")
+
+ private def getPrivateField[T](name: String): T = {
+ val field = groupBy.getClass.getDeclaredField(name)
+ field.setAccessible(true)
+ field.get(groupBy).asInstanceOf[T]
+ }
+
+ private def toDF(aggExprs: Seq[Expression]): DataFrame = {
+ val aggregates = if (df.sqlContext.conf.dataFrameRetainGroupColumns) {
+ groupingExprs ++ aggExprs
+ } else {
+ aggExprs
+ }
+
+ val aliasedAgg = aggregates.map(alias)
+
+ groupType match {
+ case RelationalGroupedDataset.GroupByType =>
+ Dataset.ofRows(
+ df.sparkSession, Aggregate(groupingExprs, aliasedAgg, df.logicalPlan))
+ case RelationalGroupedDataset.RollupType =>
+ Dataset.ofRows(
+ df.sparkSession, Aggregate(Seq(Rollup(groupingExprs)), aliasedAgg, df.logicalPlan))
+ case RelationalGroupedDataset.CubeType =>
+ Dataset.ofRows(
+ df.sparkSession, Aggregate(Seq(Cube(groupingExprs)), aliasedAgg, df.logicalPlan))
+ case RelationalGroupedDataset.PivotType(pivotCol, values) =>
+ val aliasedGrps = groupingExprs.map(alias)
+ Dataset.ofRows(
+ df.sparkSession, Pivot(aliasedGrps, pivotCol, values, aggExprs, df.logicalPlan))
+ }
+ }
+
+ private def alias(expr: Expression): NamedExpression = expr match {
+ case u: UnresolvedAttribute => UnresolvedAlias(u)
+ case expr: NamedExpression => expr
+ case expr: Expression => Alias(expr, expr.prettyName)()
+ }
+
+ private def checkType(colName: String, expected: DataType) = {
+ val dataType = df.resolve(colName).dataType
+ if (dataType != expected) {
+ throw new AnalysisException(
+ s""""$colName" must be $expected, however it is $dataType""")
+ }
+ }
+}
+
+object HivemallGroupedDataset {
+
+ /**
+ * Implicitly inject the [[HivemallGroupedDataset]] into [[RelationalGroupedDataset]].
+ */
+ implicit def relationalGroupedDatasetToHivemallOne(
+ groupBy: RelationalGroupedDataset): HivemallGroupedDataset = {
+ new HivemallGroupedDataset(groupBy)
++
++ /**
++ * @see hivemall.ftvec.selection.SignalNoiseRatioUDAF
++ */
++ def snr(X: String, Y: String): DataFrame = {
++ val udaf = HiveUDAFFunction(
++ "snr",
++ new HiveFunctionWrapper("hivemall.ftvec.selection.SignalNoiseRatioUDAF"),
++ Seq(X, Y).map(df.col(_).expr),
++ isUDAFBridgeRequired = false)
++ .toAggregateExpression()
++ toDF(Seq(Alias(udaf, udaf.prettyName)()))
++ }
++
++ /**
++ * @see hivemall.tools.matrix.TransposeAndDotUDAF
++ */
++ def transpose_and_dot(X: String, Y: String): DataFrame = {
++ val udaf = HiveUDAFFunction(
++ "transpose_and_dot",
++ new HiveFunctionWrapper("hivemall.tools.matrix.TransposeAndDotUDAF"),
++ Seq(X, Y).map(df.col(_).expr),
++ isUDAFBridgeRequired = false)
++ .toAggregateExpression()
++ toDF(Seq(Alias(udaf, udaf.prettyName)()))
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/67ba9631/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/67ba9631/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
----------------------------------------------------------------------
diff --cc spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
index a093e07,8446677..8bea975
--- a/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
+++ b/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
@@@ -1,31 -1,28 +1,37 @@@
/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
*
- * http://www.apache.org/licenses/LICENSE-2.0
+ * http://www.apache.org/licenses/LICENSE-2.0
*
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
*/
-
package org.apache.spark.sql.hive
+import org.apache.spark.sql.{AnalysisException, Column, Row}
+import org.apache.spark.sql.functions
+import org.apache.spark.sql.hive.HivemallGroupedDataset._
+import org.apache.spark.sql.hive.HivemallOps._
+import org.apache.spark.sql.hive.HivemallUtils._
+import org.apache.spark.sql.types._
+import org.apache.spark.test.{HivemallFeatureQueryTest, TestUtils, VectorQueryTest}
+import org.apache.spark.test.TestDoubleWrapper._
+ import org.apache.spark.sql.hive.HivemallOps._
+ import org.apache.spark.sql.hive.HivemallUtils._
+ import org.apache.spark.sql.types._
+ import org.apache.spark.sql.{AnalysisException, Column, Row, functions}
+ import org.apache.spark.test.TestDoubleWrapper._
+ import org.apache.spark.test.{HivemallFeatureQueryTest, TestUtils, VectorQueryTest}
final class HivemallOpsWithFeatureSuite extends HivemallFeatureQueryTest {
@@@ -636,30 -685,63 +681,88 @@@
assert(row4(0).getDouble(1) ~== 0.25)
}
+ test("user-defined aggregators for ftvec.trans") {
+ import hiveContext.implicits._
+
+ val df0 = Seq((1, "cat", "mammal", 9), (1, "dog", "mammal", 10), (1, "human", "mammal", 10),
+ (1, "seahawk", "bird", 101), (1, "wasp", "insect", 3), (1, "wasp", "insect", 9),
+ (1, "cat", "mammal", 101), (1, "dog", "mammal", 1), (1, "human", "mammal", 9))
+ .toDF("col0", "cat1", "cat2", "cat3")
+ val row00 = df0.groupBy($"col0").onehot_encoding("cat1")
+ val row01 = df0.groupBy($"col0").onehot_encoding("cat1", "cat2", "cat3")
+
+ val result000 = row00.collect()(0).getAs[Row](1).getAs[Map[String, Int]](0)
+ val result01 = row01.collect()(0).getAs[Row](1)
+ val result010 = result01.getAs[Map[String, Int]](0)
+ val result011 = result01.getAs[Map[String, Int]](1)
+ val result012 = result01.getAs[Map[String, Int]](2)
+
+ assert(result000.keySet === Set("seahawk", "cat", "human", "wasp", "dog"))
+ assert(result000.values.toSet === Set(1, 2, 3, 4, 5))
+ assert(result010.keySet === Set("seahawk", "cat", "human", "wasp", "dog"))
+ assert(result010.values.toSet === Set(1, 2, 3, 4, 5))
+ assert(result011.keySet === Set("bird", "insect", "mammal"))
+ assert(result011.values.toSet === Set(6, 7, 8))
+ assert(result012.keySet === Set(1, 3, 9, 10, 101))
+ assert(result012.values.toSet === Set(9, 10, 11, 12, 13))
++
+ test("user-defined aggregators for ftvec.selection") {
+ import hiveContext.implicits._
+
+ // see also hivemall.ftvec.selection.SignalNoiseRatioUDAFTest
+ // binary class
+ // +-----------------+-------+
+ // | features | class |
+ // +-----------------+-------+
+ // | 5.1,3.5,1.4,0.2 | 0 |
+ // | 4.9,3.0,1.4,0.2 | 0 |
+ // | 4.7,3.2,1.3,0.2 | 0 |
+ // | 7.0,3.2,4.7,1.4 | 1 |
+ // | 6.4,3.2,4.5,1.5 | 1 |
+ // | 6.9,3.1,4.9,1.5 | 1 |
+ // +-----------------+-------+
+ val df0 = Seq(
+ (1, Seq(5.1, 3.5, 1.4, 0.2), Seq(1, 0)), (1, Seq(4.9, 3.0, 1.4, 0.2), Seq(1, 0)),
+ (1, Seq(4.7, 3.2, 1.3, 0.2), Seq(1, 0)), (1, Seq(7.0, 3.2, 4.7, 1.4), Seq(0, 1)),
+ (1, Seq(6.4, 3.2, 4.5, 1.5), Seq(0, 1)), (1, Seq(6.9, 3.1, 4.9, 1.5), Seq(0, 1)))
+ .toDF("c0", "arg0", "arg1")
+ val row0 = df0.groupby($"c0").snr("arg0", "arg1").collect
+ (row0(0).getAs[Seq[Double]](1), Seq(4.38425236, 0.26390002, 15.83984511, 26.87005769))
+ .zipped
+ .foreach((actual, expected) => assert(actual ~== expected))
+
+ // multiple class
+ // +-----------------+-------+
+ // | features | class |
+ // +-----------------+-------+
+ // | 5.1,3.5,1.4,0.2 | 0 |
+ // | 4.9,3.0,1.4,0.2 | 0 |
+ // | 7.0,3.2,4.7,1.4 | 1 |
+ // | 6.4,3.2,4.5,1.5 | 1 |
+ // | 6.3,3.3,6.0,2.5 | 2 |
+ // | 5.8,2.7,5.1,1.9 | 2 |
+ // +-----------------+-------+
+ val df1 = Seq(
+ (1, Seq(5.1, 3.5, 1.4, 0.2), Seq(1, 0, 0)), (1, Seq(4.9, 3.0, 1.4, 0.2), Seq(1, 0, 0)),
+ (1, Seq(7.0, 3.2, 4.7, 1.4), Seq(0, 1, 0)), (1, Seq(6.4, 3.2, 4.5, 1.5), Seq(0, 1, 0)),
+ (1, Seq(6.3, 3.3, 6.0, 2.5), Seq(0, 0, 1)), (1, Seq(5.8, 2.7, 5.1, 1.9), Seq(0, 0, 1)))
+ .toDF("c0", "arg0", "arg1")
+ val row1 = df1.groupby($"c0").snr("arg0", "arg1").collect
+ (row1(0).getAs[Seq[Double]](1), Seq(8.43181818, 1.32121212, 42.94949495, 33.80952381))
+ .zipped
+ .foreach((actual, expected) => assert(actual ~== expected))
+ }
+
+ test("user-defined aggregators for tools.matrix") {
+ import hiveContext.implicits._
+
+ // | 1 2 3 |T | 5 6 7 |
+ // | 3 4 5 | * | 7 8 9 |
+ val df0 = Seq((1, Seq(1, 2, 3), Seq(5, 6, 7)), (1, Seq(3, 4, 5), Seq(7, 8, 9)))
+ .toDF("c0", "arg0", "arg1")
+
+ checkAnswer(df0.groupby($"c0").transpose_and_dot("arg0", "arg1"),
+ Seq(Row(1, Seq(Seq(26.0, 30.0, 34.0), Seq(38.0, 44.0, 50.0), Seq(50.0, 58.0, 66.0)))))
}
}