You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hivemall.apache.org by my...@apache.org on 2018/08/06 07:42:27 UTC
incubator-hivemall git commit: [HIVEMALL-210][BUGFIX] Fix a bug in lda_predict/plsa_predict
Repository: incubator-hivemall
Updated Branches:
refs/heads/master a8a97d6e8 -> b88e9f5e0
[HIVEMALL-210][BUGFIX] Fix a bug in lda_predict/plsa_predict
## What changes were proposed in this pull request?
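Fixed a bug in lda_predict/plsa_predict where topics with an identical (duplicated) probability were [unexpectedly replaced](https://github.com/apache/incubator-hivemall/blame/a8a97d6e873d5a8a30b06f92ddc14d1ec95c2738/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java#L396): `terminate()` stored the topic distribution in a `TreeMap` keyed by probability, so a later topic with the same probability overwrote an earlier one and the earlier topic was dropped from the output. The fix sorts an array of (probability, topic) pairs instead, so every topic is kept.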
## What type of PR is it?
Bug Fix
## What is the Jira issue?
https://issues.apache.org/jira/browse/HIVEMALL-210
## How was this patch tested?
unit tests and manual tests
## Checklist
- [x] Did you apply source code formatter, i.e., `./bin/format_code.sh`, for your commit?
- [x] Did you run system tests on Hive (or Spark)?
Author: Makoto Yui <my...@apache.org>
Closes #154 from myui/HIVEMALL-210.
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/b88e9f5e
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/b88e9f5e
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/b88e9f5e
Branch: refs/heads/master
Commit: b88e9f5e0728633dc100dcbb6a701c4fee6f7268
Parents: a8a97d6
Author: Makoto Yui <my...@apache.org>
Authored: Mon Aug 6 16:42:20 2018 +0900
Committer: Makoto Yui <my...@apache.org>
Committed: Mon Aug 6 16:42:20 2018 +0900
----------------------------------------------------------------------
.../hivemall/topicmodel/LDAPredictUDAF.java | 18 ++--
.../hivemall/topicmodel/PLSAPredictUDAF.java | 22 ++---
.../hivemall/utils/struct/KeySortablePair.java | 89 ++++++++++++++++++++
.../utils/struct/ValueSortablePair.java | 85 +++++++++++++++++++
.../hivemall/topicmodel/LDAPredictUDAFTest.java | 47 +++++++++--
.../topicmodel/PLSAPredictUDAFTest.java | 47 +++++++++--
.../utils/struct/KeySortablePairTest.java | 71 ++++++++++++++++
.../utils/struct/ValueSortablePairTest.java | 71 ++++++++++++++++
8 files changed, 422 insertions(+), 28 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/b88e9f5e/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java b/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java
index 2befec1..687f20e 100644
--- a/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java
+++ b/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java
@@ -21,16 +21,16 @@ package hivemall.topicmodel;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.lang.CommandLineUtils;
import hivemall.utils.lang.Primitives;
+import hivemall.utils.struct.KeySortablePair;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
-import java.util.SortedMap;
-import java.util.TreeMap;
import javax.annotation.Nonnull;
@@ -384,20 +384,22 @@ public final class LDAPredictUDAF extends AbstractGenericUDAFResolver {
myAggr.merge(wcList, lambdaMap);
}
+ @SuppressWarnings("unchecked")
@Override
public Object terminate(@SuppressWarnings("deprecation") AggregationBuffer agg)
throws HiveException {
OnlineLDAPredictAggregationBuffer myAggr = (OnlineLDAPredictAggregationBuffer) agg;
- float[] topicDistr = myAggr.get();
- SortedMap<Float, Integer> sortedDistr =
- new TreeMap<Float, Integer>(Collections.reverseOrder());
+ final float[] topicDistr = myAggr.get();
+ final KeySortablePair<Float, Integer>[] sorted =
+ new KeySortablePair[topicDistr.length];
for (int i = 0; i < topicDistr.length; i++) {
- sortedDistr.put(topicDistr[i], i);
+ sorted[i] = new KeySortablePair<>(topicDistr[i], i);
}
+ Arrays.sort(sorted, Collections.reverseOrder());
- List<Object[]> result = new ArrayList<Object[]>();
- for (Map.Entry<Float, Integer> e : sortedDistr.entrySet()) {
+ final List<Object[]> result = new ArrayList<Object[]>(sorted.length);
+ for (KeySortablePair<Float, Integer> e : sorted) {
Object[] struct = new Object[2];
struct[0] = new IntWritable(e.getValue()); // label
struct[1] = new FloatWritable(e.getKey()); // probability
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/b88e9f5e/core/src/main/java/hivemall/topicmodel/PLSAPredictUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/topicmodel/PLSAPredictUDAF.java b/core/src/main/java/hivemall/topicmodel/PLSAPredictUDAF.java
index d9df347..414f980 100644
--- a/core/src/main/java/hivemall/topicmodel/PLSAPredictUDAF.java
+++ b/core/src/main/java/hivemall/topicmodel/PLSAPredictUDAF.java
@@ -21,16 +21,16 @@ package hivemall.topicmodel;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.lang.CommandLineUtils;
import hivemall.utils.lang.Primitives;
+import hivemall.utils.struct.KeySortablePair;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
-import java.util.SortedMap;
-import java.util.TreeMap;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
@@ -384,23 +384,25 @@ public final class PLSAPredictUDAF extends AbstractGenericUDAFResolver {
myAggr.merge(wcList, probMap);
}
+ @SuppressWarnings("unchecked")
@Override
public Object terminate(@SuppressWarnings("deprecation") AggregationBuffer agg)
throws HiveException {
PLSAPredictAggregationBuffer myAggr = (PLSAPredictAggregationBuffer) agg;
- float[] topicDistr = myAggr.get();
- SortedMap<Float, Integer> sortedDistr =
- new TreeMap<Float, Integer>(Collections.reverseOrder());
+ final float[] topicDistr = myAggr.get();
+ final KeySortablePair<Float, Integer>[] sorted =
+ new KeySortablePair[topicDistr.length];
for (int i = 0; i < topicDistr.length; i++) {
- sortedDistr.put(topicDistr[i], i);
+ sorted[i] = new KeySortablePair<>(topicDistr[i], i);
}
+ Arrays.sort(sorted, Collections.reverseOrder());
- List<Object[]> result = new ArrayList<Object[]>();
- for (Map.Entry<Float, Integer> e : sortedDistr.entrySet()) {
+ final List<Object[]> result = new ArrayList<Object[]>(sorted.length);
+ for (KeySortablePair<Float, Integer> e : sorted) {
Object[] struct = new Object[2];
- struct[0] = new IntWritable(e.getValue().intValue()); // label
- struct[1] = new FloatWritable(e.getKey().floatValue()); // probability
+ struct[0] = new IntWritable(e.getValue()); // label
+ struct[1] = new FloatWritable(e.getKey()); // probability
result.add(struct);
}
return result;
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/b88e9f5e/core/src/main/java/hivemall/utils/struct/KeySortablePair.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/struct/KeySortablePair.java b/core/src/main/java/hivemall/utils/struct/KeySortablePair.java
new file mode 100644
index 0000000..f85e7e9
--- /dev/null
+++ b/core/src/main/java/hivemall/utils/struct/KeySortablePair.java
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.struct;
+
+import hivemall.utils.lang.Preconditions;
+
+import javax.annotation.CheckForNull;
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+
+public final class KeySortablePair<K extends Comparable<K>, V>
+ implements Comparable<KeySortablePair<K, V>> {
+
+ @Nonnull
+ private final K k;
+ @Nullable
+ private final V v;
+
+ public KeySortablePair(@CheckForNull K k, @Nullable V v) {
+ this.k = Preconditions.checkNotNull(k);
+ this.v = v;
+ }
+
+ @Nonnull
+ public K getKey() {
+ return k;
+ }
+
+ @Nullable
+ public V getValue() {
+ return v;
+ }
+
+ @Override
+ public int compareTo(KeySortablePair<K, V> o) {
+ return k.compareTo(o.k);
+ }
+
+ @Override
+ public int hashCode() {
+ final int prime = 31;
+ int result = 1;
+ result = prime * result + k.hashCode();
+ result = prime * result + ((v == null) ? 0 : v.hashCode());
+ return result;
+ }
+
+ @SuppressWarnings("unchecked")
+ @Override
+ public boolean equals(Object obj) {
+ if (this == obj)
+ return true;
+ if (obj == null)
+ return false;
+ if (getClass() != obj.getClass())
+ return false;
+ KeySortablePair<K, V> other = (KeySortablePair<K, V>) obj;
+ if (!k.equals(other.k))
+ return false;
+ if (v == null) {
+ if (other.v != null)
+ return false;
+ } else if (!v.equals(other.v))
+ return false;
+ return true;
+ }
+
+ @Override
+ public String toString() {
+ return "k=" + k + ", v=" + v;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/b88e9f5e/core/src/main/java/hivemall/utils/struct/ValueSortablePair.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/struct/ValueSortablePair.java b/core/src/main/java/hivemall/utils/struct/ValueSortablePair.java
new file mode 100644
index 0000000..891764e
--- /dev/null
+++ b/core/src/main/java/hivemall/utils/struct/ValueSortablePair.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.struct;
+
+import hivemall.utils.lang.Preconditions;
+
+import javax.annotation.CheckForNull;
+import javax.annotation.Nonnull;
+
+public final class ValueSortablePair<K, V extends Comparable<V>>
+ implements Comparable<ValueSortablePair<K, V>> {
+
+ @Nonnull
+ private final K k;
+ @Nonnull
+ private final V v;
+
+ public ValueSortablePair(@CheckForNull K k, @Nonnull V v) {
+ this.k = Preconditions.checkNotNull(k);
+ this.v = Preconditions.checkNotNull(v);
+ }
+
+ @Nonnull
+ public K getKey() {
+ return k;
+ }
+
+ @Nonnull
+ public V getValue() {
+ return v;
+ }
+
+ @Override
+ public int compareTo(ValueSortablePair<K, V> o) {
+ return v.compareTo(o.v);
+ }
+
+ @Override
+ public int hashCode() {
+ final int prime = 31;
+ int result = 1;
+ result = prime * result + k.hashCode();
+ result = prime * result + v.hashCode();
+ return result;
+ }
+
+ @SuppressWarnings("unchecked")
+ @Override
+ public boolean equals(Object obj) {
+ if (this == obj)
+ return true;
+ if (obj == null)
+ return false;
+ if (getClass() != obj.getClass())
+ return false;
+ ValueSortablePair<K, V> other = (ValueSortablePair<K, V>) obj;
+ if (!k.equals(other.k))
+ return false;
+ if (!v.equals(other.v))
+ return false;
+ return true;
+ }
+
+ @Override
+ public String toString() {
+ return "k=" + k + ", v=" + v;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/b88e9f5e/core/src/test/java/hivemall/topicmodel/LDAPredictUDAFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/topicmodel/LDAPredictUDAFTest.java b/core/src/test/java/hivemall/topicmodel/LDAPredictUDAFTest.java
index bf46485..d6a728f 100644
--- a/core/src/test/java/hivemall/topicmodel/LDAPredictUDAFTest.java
+++ b/core/src/test/java/hivemall/topicmodel/LDAPredictUDAFTest.java
@@ -20,6 +20,11 @@ package hivemall.topicmodel;
import hivemall.utils.math.MathUtils;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.ql.udf.generic.SimpleGenericUDAFParameterInfo;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
@@ -27,11 +32,6 @@ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
-
-import java.util.ArrayList;
-import java.util.Map;
-import java.util.HashMap;
-
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
@@ -265,4 +265,41 @@ public class LDAPredictUDAFTest {
Assert.assertEquals(LDAUDTF.DEFAULT_TOPICS, doc2Distr.length);
Assert.assertEquals(1.d, MathUtils.sum(doc2Distr), 1E-5d);
}
+
+ @SuppressWarnings("unchecked")
+ @Test
+ public void testTerminateWithSameTopicProbability() throws Exception {
+ udaf = new LDAPredictUDAF();
+
+ inputOIs = new ObjectInspector[] {
+ PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+ PrimitiveObjectInspector.PrimitiveCategory.STRING),
+ PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+ PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
+ PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+ PrimitiveObjectInspector.PrimitiveCategory.INT),
+ PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+ PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
+ ObjectInspectorUtils.getConstantObjectInspector(
+ PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-topics 2")};
+
+ evaluator = udaf.getEvaluator(new SimpleGenericUDAFParameterInfo(inputOIs, false, false));
+
+ agg = (LDAPredictUDAF.OnlineLDAPredictAggregationBuffer) evaluator.getNewAggregationBuffer();
+
+ evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+ evaluator.reset(agg);
+
+ // Assume that none of the words in the document are in the vocabulary of the LDA model.
+ // Hence, the document should be assigned to topics #1 and #2 with equal probability (0.5 each).
+ for (int i = 0; i < 18; i++) {
+ evaluator.iterate(agg, new Object[] {words[i], 0.f, labels[i], lambdas[i]});
+ }
+
+ // The probabilities of the two topics should be the same.
+ List<Object[]> result = (List<Object[]>) evaluator.terminate(agg);
+ Assert.assertEquals(result.size(), 2);
+ Assert.assertEquals(result.get(0)[1], result.get(1)[1]);
+ }
+
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/b88e9f5e/core/src/test/java/hivemall/topicmodel/PLSAPredictUDAFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/topicmodel/PLSAPredictUDAFTest.java b/core/src/test/java/hivemall/topicmodel/PLSAPredictUDAFTest.java
index 1d364ee..e61222a 100644
--- a/core/src/test/java/hivemall/topicmodel/PLSAPredictUDAFTest.java
+++ b/core/src/test/java/hivemall/topicmodel/PLSAPredictUDAFTest.java
@@ -20,6 +20,11 @@ package hivemall.topicmodel;
import hivemall.utils.math.MathUtils;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.ql.udf.generic.SimpleGenericUDAFParameterInfo;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
@@ -27,11 +32,6 @@ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
-
-import java.util.ArrayList;
-import java.util.Map;
-import java.util.HashMap;
-
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
@@ -264,4 +264,41 @@ public class PLSAPredictUDAFTest {
Assert.assertEquals(PLSAUDTF.DEFAULT_TOPICS, doc2Distr.length);
Assert.assertEquals(1.d, MathUtils.sum(doc2Distr), 1E-5d);
}
+
+ @SuppressWarnings("unchecked")
+ @Test
+ public void testTerminateWithSameTopicProbability() throws Exception {
+ udaf = new PLSAPredictUDAF();
+
+ inputOIs = new ObjectInspector[] {
+ PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+ PrimitiveObjectInspector.PrimitiveCategory.STRING),
+ PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+ PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
+ PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+ PrimitiveObjectInspector.PrimitiveCategory.INT),
+ PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+ PrimitiveObjectInspector.PrimitiveCategory.FLOAT),
+ ObjectInspectorUtils.getConstantObjectInspector(
+ PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-topics 2")};
+
+ evaluator = udaf.getEvaluator(new SimpleGenericUDAFParameterInfo(inputOIs, false, false));
+ agg = (PLSAPredictUDAF.PLSAPredictAggregationBuffer) evaluator.getNewAggregationBuffer();
+
+ evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+ evaluator.reset(agg);
+
+ // Assume that none of the words in the document are in the vocabulary of the PLSA model.
+ // Hence, the document should be assigned to topics #1 and #2 with equal probability (0.5 each).
+ for (int i = 0; i < words.length; i++) {
+ String word = words[i];
+ evaluator.iterate(agg, new Object[] {word, 0.f, labels[i], probs[i]});
+ }
+
+ // The probabilities of the two topics should be the same.
+ List<Object[]> result = (List<Object[]>) evaluator.terminate(agg);
+ Assert.assertEquals(result.size(), 2);
+ Assert.assertEquals(result.get(0)[1], result.get(1)[1]);
+ }
+
}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/b88e9f5e/core/src/test/java/hivemall/utils/struct/KeySortablePairTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/utils/struct/KeySortablePairTest.java b/core/src/test/java/hivemall/utils/struct/KeySortablePairTest.java
new file mode 100644
index 0000000..6f3fd70
--- /dev/null
+++ b/core/src/test/java/hivemall/utils/struct/KeySortablePairTest.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.struct;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.PriorityQueue;
+
+import org.junit.Test;
+
+public class KeySortablePairTest {
+
+ @Test
+ public void testPriorityQueue() {
+ KeySortablePair<Float, Integer> v1 = new KeySortablePair<>(3.f, 1);
+ KeySortablePair<Float, Integer> v2 = new KeySortablePair<>(1.f, 2);
+ KeySortablePair<Float, Integer> v3 = new KeySortablePair<>(4.f, 3);
+ KeySortablePair<Float, Integer> v4 = new KeySortablePair<>(-1.f, 4);
+
+ PriorityQueue<KeySortablePair<Float, Integer>> pq =
+ new PriorityQueue<>(11, Collections.reverseOrder());
+ pq.add(v1);
+ pq.add(v2);
+ pq.add(v3);
+ pq.add(v4);
+
+ assertEquals(Float.valueOf(4.f), pq.poll().getKey());
+ assertEquals(Float.valueOf(3.f), pq.poll().getKey());
+ assertEquals(Float.valueOf(1.f), pq.poll().getKey());
+ assertEquals(Float.valueOf(-1.f), pq.poll().getKey());
+
+ assertTrue(pq.isEmpty());
+ }
+
+ @SuppressWarnings("unchecked")
+ @Test
+ public void testArraySort() {
+ KeySortablePair<Float, Integer> v1 = new KeySortablePair<>(3.f, 1);
+ KeySortablePair<Float, Integer> v2 = new KeySortablePair<>(1.f, 2);
+ KeySortablePair<Float, Integer> v3 = new KeySortablePair<>(4.f, 3);
+ KeySortablePair<Float, Integer> v4 = new KeySortablePair<>(-1.f, 4);
+
+ KeySortablePair<Float, Integer>[] arr = new KeySortablePair[] {v1, v2, v3, v4};
+ Arrays.sort(arr, Collections.reverseOrder());
+
+ assertEquals(Float.valueOf(4.f), arr[0].getKey());
+ assertEquals(Float.valueOf(3.f), arr[1].getKey());
+ assertEquals(Float.valueOf(1.f), arr[2].getKey());
+ assertEquals(Float.valueOf(-1.f), arr[3].getKey());
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/b88e9f5e/core/src/test/java/hivemall/utils/struct/ValueSortablePairTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/utils/struct/ValueSortablePairTest.java b/core/src/test/java/hivemall/utils/struct/ValueSortablePairTest.java
new file mode 100644
index 0000000..4829279
--- /dev/null
+++ b/core/src/test/java/hivemall/utils/struct/ValueSortablePairTest.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.utils.struct;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.PriorityQueue;
+
+import org.junit.Test;
+
+public class ValueSortablePairTest {
+
+ @Test
+ public void testPriorityQueue() {
+ ValueSortablePair<Float, Integer> v1 = new ValueSortablePair<>(1.f, -1);
+ ValueSortablePair<Float, Integer> v2 = new ValueSortablePair<>(2.f, 3);
+ ValueSortablePair<Float, Integer> v3 = new ValueSortablePair<>(3.f, 2);
+ ValueSortablePair<Float, Integer> v4 = new ValueSortablePair<>(4.f, 0);
+
+ PriorityQueue<ValueSortablePair<Float, Integer>> pq =
+ new PriorityQueue<>(11, Collections.reverseOrder());
+ pq.add(v1);
+ pq.add(v2);
+ pq.add(v3);
+ pq.add(v4);
+
+ assertEquals(3, pq.poll().getValue().intValue());
+ assertEquals(2, pq.poll().getValue().intValue());
+ assertEquals(0, pq.poll().getValue().intValue());
+ assertEquals(-1, pq.poll().getValue().intValue());
+
+ assertTrue(pq.isEmpty());
+ }
+
+ @SuppressWarnings("unchecked")
+ @Test
+ public void testArraySort() {
+ ValueSortablePair<Float, Integer> v1 = new ValueSortablePair<>(1.f, -1);
+ ValueSortablePair<Float, Integer> v2 = new ValueSortablePair<>(2.f, 3);
+ ValueSortablePair<Float, Integer> v3 = new ValueSortablePair<>(3.f, 2);
+ ValueSortablePair<Float, Integer> v4 = new ValueSortablePair<>(4.f, 0);
+
+ ValueSortablePair<Float, Integer>[] arr = new ValueSortablePair[] {v1, v2, v3, v4};
+ Arrays.sort(arr, Collections.reverseOrder());
+
+ assertEquals(3, arr[0].getValue().intValue());
+ assertEquals(2, arr[1].getValue().intValue());
+ assertEquals(0, arr[2].getValue().intValue());
+ assertEquals(-1, arr[3].getValue().intValue());
+ }
+
+}