You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hivemall.apache.org by my...@apache.org on 2019/06/25 10:31:23 UTC
[incubator-hivemall] branch master updated: [HIVEMALL-253-2]
map_roulette UDF
This is an automated email from the ASF dual-hosted git repository.
myui pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-hivemall.git
The following commit(s) were added to refs/heads/master by this push:
new d728d5a [HIVEMALL-253-2] map_roulette UDF
d728d5a is described below
commit d728d5a4b624564f003ccf7010340de67ecea713
Author: Solodye <xi...@163.com>
AuthorDate: Tue Jun 25 19:31:02 2019 +0900
[HIVEMALL-253-2] map_roulette UDF
revise #192
Author: Makoto Yui <my...@apache.org>
Closes #193 from myui/HIVEMALL-253-2.
---
.../java/hivemall/tools/map/MapRouletteUDF.java | 255 +++++++++++++++++++++
.../main/java/hivemall/utils/hadoop/HiveUtils.java | 15 +-
.../hivemall/tools/map/MapRouletteUDFTest.java | 214 +++++++++++++++++
docs/gitbook/misc/generic_funcs.md | 40 ++++
resources/ddl/define-all-as-permanent.hive | 4 +
resources/ddl/define-all.hive | 3 +
resources/ddl/define-all.spark | 3 +
7 files changed, 529 insertions(+), 5 deletions(-)
diff --git a/core/src/main/java/hivemall/tools/map/MapRouletteUDF.java b/core/src/main/java/hivemall/tools/map/MapRouletteUDF.java
new file mode 100644
index 0000000..0729d15
--- /dev/null
+++ b/core/src/main/java/hivemall/tools/map/MapRouletteUDF.java
@@ -0,0 +1,255 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.tools.map;
+
+import static hivemall.utils.lang.StringUtils.join;
+
+import hivemall.utils.hadoop.HiveUtils;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Random;
+
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.udf.UDFType;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
+import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
+
+import com.clearspring.analytics.util.Preconditions;
+
+/**
+ * The map_roulette returns a map key based on weighted random sampling of map values.
+ */
+// @formatter:off
+@Description(name = "map_roulette",
+ value = "_FUNC_(Map<K, number> map [, (const) int/bigint seed])"
+ + " - Returns a map key based on weighted random sampling of map values."
+ + " Average of values is used for null values",
+ extended = "-- `map_roulette(map<key, number> [, integer seed])` returns key by weighted random selection\n" +
+ "SELECT \n" +
+ " map_roulette(to_map(a, b)) -- 25% Tom, 21% Zhang, 54% Wang\n" +
+ "FROM ( -- see https://issues.apache.org/jira/browse/HIVE-17406\n" +
+ " select 'Wang' as a, 54 as b\n" +
+ " union all\n" +
+ " select 'Zhang' as a, 21 as b\n" +
+ " union all\n" +
+ " select 'Tom' as a, 25 as b\n" +
+ ") tmp;\n" +
+ "> Wang\n" +
+ "\n" +
+ "-- Weight random selection with using filling nulls with the average value\n" +
+ "SELECT\n" +
+ " map_roulette(map(1, 0.5, 'Wang', null)), -- 50% Wang, 50% 1\n" +
+ " map_roulette(map(1, 0.5, 'Wang', null, 'Zhang', null)) -- 1/3 Wang, 1/3 1, 1/3 Zhang\n" +
+ ";\n" +
+ "\n" +
+ "-- NULL will be returned if every key is null\n" +
+ "SELECT \n" +
+ " map_roulette(map()),\n" +
+ " map_roulette(map(null, null, null, null));\n" +
+ "> NULL NULL\n" +
+ "\n" +
+ "-- Return NULL if all weights are zero\n" +
+ "SELECT\n" +
+ " map_roulette(map(1, 0)),\n" +
+ " map_roulette(map(1, 0, '5', 0))\n" +
+ ";\n" +
+ "> NULL NULL\n" +
+ "\n" +
+ "-- map_roulette does not support non-numeric weights or negative weights.\n" +
+ "SELECT map_roulette(map('Wong', 'A string', 'Zhao', 2));\n" +
+ "> HiveException: Error evaluating map_roulette(map('Wong':'A string','Zhao':2))\n" +
+ "SELECT map_roulette(map('Wong', 'A string', 'Zhao', 2));\n" +
+ "> UDFArgumentException: Map value must be greather than or equals to zero: -2")
+// @formatter:on
+@UDFType(deterministic = false, stateful = false) // it is false because it return value base on probability
+public final class MapRouletteUDF extends GenericUDF {
+
+ private transient MapObjectInspector mapOI;
+ private transient PrimitiveObjectInspector valueOI;
+ @Nullable
+ private transient PrimitiveObjectInspector seedOI;
+
+ @Nullable
+ private transient Random _rand;
+
+ @Override
+ public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
+ if (argOIs.length != 1 && argOIs.length != 2) {
+ throw new UDFArgumentLengthException(
+ "Expected exactly one argument for map_roulette: " + argOIs.length);
+ }
+ if (argOIs[0].getCategory() != ObjectInspector.Category.MAP) {
+ throw new UDFArgumentTypeException(0,
+ "Only map type argument is accepted but got " + argOIs[0].getTypeName());
+ }
+
+ this.mapOI = HiveUtils.asMapOI(argOIs[0]);
+ this.valueOI = HiveUtils.asDoubleCompatibleOI(mapOI.getMapValueObjectInspector());
+
+ if (argOIs.length == 2) {
+ ObjectInspector argOI1 = argOIs[1];
+ if (HiveUtils.isIntegerOI(argOI1) == false) {
+ throw new UDFArgumentException(
+ "The second argument of map_roulette must be integer type: "
+ + argOI1.getTypeName());
+ }
+ if (ObjectInspectorUtils.isConstantObjectInspector(argOI1)) {
+ long seed = HiveUtils.getAsConstLong(argOI1);
+ this._rand = new Random(seed); // fixed seed
+ } else {
+ this.seedOI = HiveUtils.asLongCompatibleOI(argOI1);
+ }
+ } else {
+ this._rand = new Random(); // random seed
+ }
+
+ return mapOI.getMapKeyObjectInspector();
+ }
+
+ @Nullable
+ @Override
+ public Object evaluate(DeferredObject[] arguments) throws HiveException {
+ Random rand = _rand;
+ if (rand == null) {
+ Object arg1 = arguments[1].get();
+ if (arg1 == null) {
+ rand = new Random();
+ } else {
+ long seed = HiveUtils.getLong(arg1, seedOI);
+ rand = new Random(seed);
+ }
+ }
+
+ Map<Object, Double> input = getObjectDoubleMap(arguments[0], mapOI, valueOI);
+ if (input == null) {
+ return null;
+ }
+
+ return rouletteWheelSelection(input, rand);
+ }
+
+ @Nullable
+ private static Map<Object, Double> getObjectDoubleMap(@Nonnull final DeferredObject argument,
+ @Nonnull final MapObjectInspector mapOI,
+ @Nonnull final PrimitiveObjectInspector valueOI) throws HiveException {
+ final Map<?, ?> m = mapOI.getMap(argument.get());
+ if (m == null) {
+ return null;
+ }
+ final int size = m.size();
+ if (size == 0) {
+ return null;
+ }
+
+ final Map<Object, Double> result = new HashMap<>(size);
+ double sum = 0.d;
+ int cnt = 0;
+ for (Map.Entry<?, ?> entry : m.entrySet()) {
+ Object key = entry.getKey();
+ if (key == null) {
+ continue;
+ }
+ Object value = entry.getValue();
+ if (value == null) {
+ continue;
+ }
+ final double v = PrimitiveObjectInspectorUtils.convertPrimitiveToDouble(value, valueOI);
+ if (v < 0.d) {
+ throw new UDFArgumentException(
+ "Map value must be greather than or equals to zero: " + entry.getValue());
+ }
+
+ result.put(key, Double.valueOf(v));
+ sum += v;
+ cnt++;
+ }
+
+ if (result.isEmpty()) {
+ return null;
+ }
+
+ if (result.size() < m.size()) {
+ // fillna with the avg value
+ final Double avg = Double.valueOf(sum / cnt);
+ for (Map.Entry<?, ?> entry : m.entrySet()) {
+ Object key = entry.getKey();
+ if (key == null) {
+ continue;
+ }
+ if (entry.getValue() == null) {
+ result.put(key, avg);
+ }
+ }
+ }
+
+ return result;
+ }
+
+ /**
+ * Roulette Wheel Selection.
+ *
+ * See https://www.obitko.com/tutorials/genetic-algorithms/selection.php
+ */
+ @Nullable
+ private static Object rouletteWheelSelection(@Nonnull final Map<Object, Double> m,
+ @Nonnull final Random rnd) {
+ Preconditions.checkArgument(m.isEmpty() == false);
+
+ // 1. calculate sum
+ double sum = 0.d;
+ for (Double v : m.values()) {
+ sum += v.doubleValue();
+ }
+
+ // 2. Generate random number from interval r=[0,sum)
+ double r = rnd.nextDouble() * sum;
+
+ // 3. Go through the population and sum weight from 0 - sum s.
+ // When the sum s is greater then r, stop and return the element.
+ double s = 0.d;
+ for (Map.Entry<Object, Double> e : m.entrySet()) {
+ Object k = e.getKey();
+ double v = e.getValue().doubleValue();
+ s += v;
+ if (s > r) {
+ return k;
+ }
+ }
+
+ return null;
+ }
+
+ @Override
+ public String getDisplayString(String[] children) {
+ return "map_roulette(" + join(children, ',') + ")";
+ }
+
+}
diff --git a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
index e42d1b6..9379c1e 100644
--- a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
+++ b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
@@ -482,6 +482,13 @@ public final class HiveUtils {
return PrimitiveObjectInspectorUtils.getInt(o, oi);
}
+    /**
+     * Reads a primitive value as a {@code long} via its object inspector; a {@code null} value
+     * yields {@code 0L}.
+     */
+    public static long getLong(@Nullable Object o, @Nonnull PrimitiveObjectInspector oi) {
+        if (o != null) {
+            return PrimitiveObjectInspectorUtils.getLong(o, oi);
+        }
+        return 0L;
+    }
+
@SuppressWarnings("unchecked")
@Nullable
public static <T extends Writable> T getConstValue(@Nonnull final ObjectInspector oi)
@@ -1054,9 +1061,8 @@ public final class HiveUtils {
case DECIMAL:
break;
default:
- throw new UDFArgumentTypeException(0,
- "Only numeric or string type arguments are accepted but " + argOI.getTypeName()
- + " is passed.");
+ throw new UDFArgumentTypeException(0, "Only floating point number is accepted but "
+ + argOI.getTypeName() + " is passed.");
}
return oi;
}
@@ -1080,8 +1086,7 @@ public final class HiveUtils {
break;
default:
throw new UDFArgumentTypeException(0,
- "Only numeric or string type arguments are accepted but " + argOI.getTypeName()
- + " is passed.");
+ "Only numeric argument is accepted but " + argOI.getTypeName() + " is passed.");
}
return oi;
}
diff --git a/core/src/test/java/hivemall/tools/map/MapRouletteUDFTest.java b/core/src/test/java/hivemall/tools/map/MapRouletteUDFTest.java
new file mode 100644
index 0000000..7df1b82
--- /dev/null
+++ b/core/src/test/java/hivemall/tools/map/MapRouletteUDFTest.java
@@ -0,0 +1,214 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.tools.map;
+
+import hivemall.TestUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Unit test for {@link hivemall.tools.map.MapRouletteUDF}
+ */
+public class MapRouletteUDFTest {
+
+    /** Number of spins in the statistical tests; large enough that the rank order is stable. */
+    private static final int NUM_SPINS = 1000000;
+
+    /**
+     * Tom, Jerry, Amy, Wong and Zhao join a roulette. Zhao has the highest weight and should win
+     * most often. Jerry (0.2) should win less than Zhao but more than the others, and Amy (0.01)
+     * should win least. A fixed seed keeps the test reproducible.
+     */
+    @Test
+    public void testRoulette() throws HiveException, IOException {
+        MapRouletteUDF udf = newUDF(43L);
+        Map<String, Double> m = new HashMap<>();
+        m.put("Tom", 0.18); // 3rd
+        m.put("Jerry", 0.2); // 2nd
+        m.put("Amy", 0.01); // 5th
+        m.put("Wong", 0.1); // 4th
+        m.put("Zhao", 0.5); // 1st
+
+        List<Map.Entry<Object, Integer>> ranking = sortByCount(spin(udf, m, NUM_SPINS));
+        Assert.assertEquals("Zhao", ranking.get(ranking.size() - 1).getKey().toString());
+        Assert.assertEquals("Jerry", ranking.get(ranking.size() - 2).getKey().toString());
+        Assert.assertEquals("Amy", ranking.get(0).getKey().toString());
+
+        udf.close();
+    }
+
+    /**
+     * Tom's weight is lost (null), so the UDF uses the average of the other weights for him:
+     * (0.2+0.1+0.1+0.5)/4 = 0.225. This puts Tom second, behind Zhao.
+     */
+    @Test
+    public void testRouletteFillNA() throws HiveException, IOException {
+        MapRouletteUDF udf = newUDF(43L);
+        Map<String, Double> m = new HashMap<>();
+        m.put("Tom", null); // (0.2+0.1+0.1+0.5)/4=0.225
+        m.put("Jerry", 0.2);
+        m.put("Amy", 0.1);
+        m.put("Wong", 0.1);
+        m.put("Zhao", 0.5);
+
+        List<Map.Entry<Object, Integer>> ranking = sortByCount(spin(udf, m, NUM_SPINS));
+        Assert.assertEquals("Zhao", ranking.get(ranking.size() - 1).getKey().toString());
+        Assert.assertEquals("Tom", ranking.get(ranking.size() - 2).getKey().toString());
+
+        udf.close();
+    }
+
+    @Test
+    public void testSerialization() throws HiveException, IOException {
+        Map<String, Double> m = new HashMap<>();
+        m.put("Tom", 0.1);
+        m.put("Jerry", 0.2);
+        m.put("Amy", 0.1);
+        m.put("Wong", 0.1);
+        m.put("Zhao", null);
+
+        TestUtils.testGenericUDFSerialization(MapRouletteUDF.class,
+            new ObjectInspector[] {stringToDoubleMapOI()}, new Object[] {m});
+    }
+
+    @Test
+    public void testEmptyMapAndAllNullMap() throws HiveException, IOException {
+        MapRouletteUDF udf = newUDF();
+        Map<Object, Double> m = new HashMap<>();
+        Assert.assertNull(udf.evaluate(args(m)));
+        m.put(null, null);
+        Assert.assertNull(udf.evaluate(args(m)));
+
+        udf.close();
+    }
+
+    @Test
+    public void testAllNullValues() throws HiveException, IOException {
+        MapRouletteUDF udf = newUDF();
+        Map<String, Double> m = new HashMap<>();
+        m.put("Wang", null);
+        m.put("Zhang", null);
+        // no non-null value to average, so nothing can be selected
+        Assert.assertNull(udf.evaluate(args(m)));
+
+        udf.close();
+    }
+
+    @Test
+    public void testOnlyOne() throws HiveException, IOException {
+        MapRouletteUDF udf = newUDF();
+        Map<String, Double> m = new HashMap<>();
+        m.put("One", 324.6);
+        Assert.assertEquals("One", udf.evaluate(args(m)));
+
+        udf.close();
+    }
+
+    @Test
+    public void testSeed() throws HiveException, IOException {
+        MapRouletteUDF udf = newUDF(43L);
+        Map<String, Double> m = new HashMap<>();
+        m.put("One", 0.7);
+        Assert.assertEquals("One", udf.evaluate(args(m)));
+
+        udf.close();
+    }
+
+    @Test
+    public void testSeedReproducible() throws HiveException, IOException {
+        Map<String, Double> m = new HashMap<>();
+        m.put("One", 0.3);
+        m.put("Two", 0.3);
+        m.put("Three", 0.4);
+        GenericUDF.DeferredObject[] arguments = args(m);
+
+        MapRouletteUDF udf1 = newUDF(43L);
+        MapRouletteUDF udf2 = newUDF(43L);
+        for (int i = 0; i < 100; i++) {
+            Assert.assertEquals(udf1.evaluate(arguments), udf2.evaluate(arguments));
+        }
+
+        udf1.close();
+        udf2.close();
+    }
+
+    @Test
+    public void testNonConstantSeed() throws HiveException, IOException {
+        MapRouletteUDF udf = new MapRouletteUDF();
+        udf.initialize(new ObjectInspector[] {stringToDoubleMapOI(),
+                PrimitiveObjectInspectorFactory.javaLongObjectInspector});
+        Map<String, Double> m = new HashMap<>();
+        m.put("One", 0.3);
+        m.put("Two", 0.3);
+        m.put("Three", 0.4);
+
+        // the same per-row seed always yields the same key
+        GenericUDF.DeferredObject[] seeded = new GenericUDF.DeferredObject[] {
+                new GenericUDF.DeferredJavaObject(m), new GenericUDF.DeferredJavaObject(7L)};
+        Object first = udf.evaluate(seeded);
+        Assert.assertNotNull(first);
+        for (int i = 0; i < 10; i++) {
+            Assert.assertEquals(first, udf.evaluate(seeded));
+        }
+
+        // a NULL seed falls back to an unseeded generator
+        GenericUDF.DeferredObject[] unseeded = new GenericUDF.DeferredObject[] {
+                new GenericUDF.DeferredJavaObject(m), new GenericUDF.DeferredJavaObject(null)};
+        Assert.assertNotNull(udf.evaluate(unseeded));
+
+        udf.close();
+    }
+
+    @Test
+    public void testZeroValues() throws HiveException, IOException {
+        MapRouletteUDF udf = newUDF();
+        Map<String, Double> m = new HashMap<>();
+        m.put("One", 0.d);
+        m.put("Two", 0.d);
+        Assert.assertNull(udf.evaluate(args(m)));
+
+        udf.close();
+    }
+
+    @Test(expected = HiveException.class)
+    public void testNegativeValue() throws HiveException, IOException {
+        MapRouletteUDF udf = newUDF();
+        try {
+            Map<String, Double> m = new HashMap<>();
+            m.put("Wong", -2.d);
+            m.put("Zhao", 2.d);
+            udf.evaluate(args(m));
+        } finally {
+            udf.close();
+        }
+    }
+
+    @Test(expected = HiveException.class)
+    public void testTooManyArguments() throws HiveException {
+        new MapRouletteUDF().initialize(new ObjectInspector[] {stringToDoubleMapOI(),
+                PrimitiveObjectInspectorFactory.javaLongObjectInspector,
+                PrimitiveObjectInspectorFactory.javaLongObjectInspector});
+    }
+
+    /** Object inspector for a {@code map<string,double>} argument. */
+    private static ObjectInspector stringToDoubleMapOI() {
+        return ObjectInspectorFactory.getStandardMapObjectInspector(
+            PrimitiveObjectInspectorFactory.javaStringObjectInspector,
+            PrimitiveObjectInspectorFactory.javaDoubleObjectInspector);
+    }
+
+    /** Creates a UDF initialized with a single map argument (random seed). */
+    private static MapRouletteUDF newUDF() throws HiveException {
+        MapRouletteUDF udf = new MapRouletteUDF();
+        udf.initialize(new ObjectInspector[] {stringToDoubleMapOI()});
+        return udf;
+    }
+
+    /** Creates a UDF initialized with a map argument and a constant seed. */
+    private static MapRouletteUDF newUDF(long seed) throws HiveException {
+        MapRouletteUDF udf = new MapRouletteUDF();
+        udf.initialize(new ObjectInspector[] {stringToDoubleMapOI(),
+                ObjectInspectorUtils.getConstantObjectInspector(
+                    PrimitiveObjectInspectorFactory.javaLongObjectInspector, seed)});
+        return udf;
+    }
+
+    /** Wraps a map as the single deferred argument of the UDF. */
+    private static GenericUDF.DeferredObject[] args(Map<?, ?> m) {
+        return new GenericUDF.DeferredObject[] {new GenericUDF.DeferredJavaObject(m)};
+    }
+
+    /**
+     * Evaluates {@code udf} on {@code weights} {@code spins} times and counts how often each key is
+     * returned. Every key of {@code weights} starts with a count of zero.
+     */
+    private static Map<Object, Integer> spin(MapRouletteUDF udf, Map<String, Double> weights,
+            int spins) throws HiveException {
+        Map<Object, Integer> counts = new HashMap<>();
+        for (String key : weights.keySet()) {
+            counts.put(key, 0);
+        }
+        GenericUDF.DeferredObject[] arguments = args(weights);
+        for (int i = 0; i < spins; i++) {
+            Object key = udf.evaluate(arguments);
+            Assert.assertNotNull(key);
+            counts.put(key, counts.get(key) + 1);
+        }
+        return counts;
+    }
+
+    /** Returns the entries of {@code counts} sorted by ascending count. */
+    private static List<Map.Entry<Object, Integer>> sortByCount(Map<Object, Integer> counts) {
+        List<Map.Entry<Object, Integer>> entries = new ArrayList<>(counts.entrySet());
+        Collections.sort(entries, new KvComparator());
+        return entries;
+    }
+
+    private static class KvComparator implements Comparator<Map.Entry<Object, Integer>> {
+
+        @Override
+        public int compare(Map.Entry<Object, Integer> o1, Map.Entry<Object, Integer> o2) {
+            return o1.getValue().compareTo(o2.getValue());
+        }
+    }
+}
diff --git a/docs/gitbook/misc/generic_funcs.md b/docs/gitbook/misc/generic_funcs.md
index 4f53f4d..aa4b462 100644
--- a/docs/gitbook/misc/generic_funcs.md
+++ b/docs/gitbook/misc/generic_funcs.md
@@ -491,6 +491,46 @@ This page describes a list of useful Hivemall generic functions. See also a [lis
> [{"key":"one","value":1},{"key":"two","value":2}]
```
+- `map_roulette(Map<K, number> map [, (const) int/bigint seed])` - Returns a map key based on weighted random sampling of map values. Average of values is used for null values
+ ```sql
+ -- `map_roulette(map<key, number> [, integer seed])` returns key by weighted random selection
+ SELECT
+ map_roulette(to_map(a, b)) -- 25% Tom, 21% Zhang, 54% Wang
+ FROM ( -- see https://issues.apache.org/jira/browse/HIVE-17406
+ select 'Wang' as a, 54 as b
+ union all
+ select 'Zhang' as a, 21 as b
+ union all
+ select 'Tom' as a, 25 as b
+ ) tmp;
+ > Wang
+
+ -- Weight random selection with using filling nulls with the average value
+ SELECT
+ map_roulette(map(1, 0.5, 'Wang', null)), -- 50% Wang, 50% 1
+ map_roulette(map(1, 0.5, 'Wang', null, 'Zhang', null)) -- 1/3 Wang, 1/3 1, 1/3 Zhang
+ ;
+
+ -- NULL will be returned if every key is null
+ SELECT
+ map_roulette(map()),
+ map_roulette(map(null, null, null, null));
+ > NULL NULL
+
+ -- Return NULL if all weights are zero
+ SELECT
+ map_roulette(map(1, 0)),
+ map_roulette(map(1, 0, '5', 0))
+ ;
+ > NULL NULL
+
+ -- map_roulette does not support non-numeric weights or negative weights.
+ SELECT map_roulette(map('Wong', 'A string', 'Zhao', 2));
+ > HiveException: Error evaluating map_roulette(map('Wong':'A string','Zhao':2))
+ SELECT map_roulette(map('Wong', -2, 'Zhao', 2));
+ > UDFArgumentException: Map value must be greather than or equals to zero: -2
+ ```
+
- `map_tail_n(map SRC, int N)` - Returns the last N elements from a sorted array of SRC
- `merge_maps(Map x)` - Returns a map which contains the union of an aggregation of maps. Note that an existing value of a key can be replaced with the other duplicate key entry.
diff --git a/resources/ddl/define-all-as-permanent.hive b/resources/ddl/define-all-as-permanent.hive
index ff20c8c..17797a8 100644
--- a/resources/ddl/define-all-as-permanent.hive
+++ b/resources/ddl/define-all-as-permanent.hive
@@ -518,6 +518,9 @@ CREATE FUNCTION map_get as 'hivemall.tools.map.MapGetUDF' USING JAR '${hivemall_
DROP FUNCTION IF EXISTS map_key_values;
CREATE FUNCTION map_key_values as 'hivemall.tools.map.MapKeyValuesUDF' USING JAR '${hivemall_jar}';
+DROP FUNCTION IF EXISTS map_roulette;
+CREATE FUNCTION map_roulette as 'hivemall.tools.map.MapRouletteUDF' USING JAR '${hivemall_jar}';
+
---------------------
-- list functions --
---------------------
@@ -880,3 +883,4 @@ CREATE FUNCTION xgboost_predict AS 'hivemall.xgboost.tools.XGBoostPredictUDTF' U
DROP FUNCTION xgboost_multiclass_predict;
CREATE FUNCTION xgboost_multiclass_predict AS 'hivemall.xgboost.tools.XGBoostMulticlassPredictUDTF' USING JAR '${hivemall_jar}';
+
diff --git a/resources/ddl/define-all.hive b/resources/ddl/define-all.hive
index 0495113..04e8915 100644
--- a/resources/ddl/define-all.hive
+++ b/resources/ddl/define-all.hive
@@ -510,6 +510,9 @@ create temporary function map_get as 'hivemall.tools.map.MapGetUDF';
drop temporary function if exists map_key_values;
create temporary function map_key_values as 'hivemall.tools.map.MapKeyValuesUDF';
+drop temporary function if exists map_roulette;
+create temporary function map_roulette as 'hivemall.tools.map.MapRouletteUDF';
+
---------------------
-- list functions --
---------------------
diff --git a/resources/ddl/define-all.spark b/resources/ddl/define-all.spark
index feadbbf..19f01bc 100644
--- a/resources/ddl/define-all.spark
+++ b/resources/ddl/define-all.spark
@@ -509,6 +509,9 @@ sqlContext.sql("CREATE TEMPORARY FUNCTION map_get AS 'hivemall.tools.map.MapGetU
sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS map_key_values")
sqlContext.sql("CREATE TEMPORARY FUNCTION map_key_values AS 'hivemall.tools.map.MapKeyValuesUDF'")
+sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS map_roulette")
+sqlContext.sql("CREATE TEMPORARY FUNCTION map_roulette AS 'hivemall.tools.map.MapRouletteUDF'")
+
/**
* List functions
*/