You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hivemall.apache.org by my...@apache.org on 2019/06/25 10:31:23 UTC

[incubator-hivemall] branch master updated: [HIVEMALL-253-2] map_roulette UDF

This is an automated email from the ASF dual-hosted git repository.

myui pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-hivemall.git


The following commit(s) were added to refs/heads/master by this push:
     new d728d5a  [HIVEMALL-253-2] map_roulette UDF
d728d5a is described below

commit d728d5a4b624564f003ccf7010340de67ecea713
Author: Solodye <xi...@163.com>
AuthorDate: Tue Jun 25 19:31:02 2019 +0900

    [HIVEMALL-253-2] map_roulette UDF
    
    revise #192
    
    Author: Makoto Yui <my...@apache.org>
    
    Closes #193 from myui/HIVEMALL-253-2.
---
 .../java/hivemall/tools/map/MapRouletteUDF.java    | 255 +++++++++++++++++++++
 .../main/java/hivemall/utils/hadoop/HiveUtils.java |  15 +-
 .../hivemall/tools/map/MapRouletteUDFTest.java     | 214 +++++++++++++++++
 docs/gitbook/misc/generic_funcs.md                 |  40 ++++
 resources/ddl/define-all-as-permanent.hive         |   4 +
 resources/ddl/define-all.hive                      |   3 +
 resources/ddl/define-all.spark                     |   3 +
 7 files changed, 529 insertions(+), 5 deletions(-)

diff --git a/core/src/main/java/hivemall/tools/map/MapRouletteUDF.java b/core/src/main/java/hivemall/tools/map/MapRouletteUDF.java
new file mode 100644
index 0000000..0729d15
--- /dev/null
+++ b/core/src/main/java/hivemall/tools/map/MapRouletteUDF.java
@@ -0,0 +1,255 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.tools.map;
+
+import static hivemall.utils.lang.StringUtils.join;
+
+import hivemall.utils.hadoop.HiveUtils;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Random;
+
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.udf.UDFType;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
+import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
+
+import com.clearspring.analytics.util.Preconditions;
+
+/**
+ * The map_roulette returns a map key based on weighted random sampling of map values.
+ */
+// @formatter:off
+@Description(name = "map_roulette",
+        value = "_FUNC_(Map<K, number> map [, (const) int/bigint seed])"
+                + " - Returns a map key based on weighted random sampling of map values."
+                + " Average of values is used for null values",
+        extended = "-- `map_roulette(map<key, number> [, integer seed])` returns key by weighted random selection\n" + 
+                "SELECT \n" + 
+                "  map_roulette(to_map(a, b)) -- 25% Tom, 21% Zhang, 54% Wang\n" + 
+                "FROM ( -- see https://issues.apache.org/jira/browse/HIVE-17406\n" + 
+                "  select 'Wang' as a, 54 as b\n" + 
+                "  union all\n" + 
+                "  select 'Zhang' as a, 21 as b\n" + 
+                "  union all\n" + 
+                "  select 'Tom' as a, 25 as b\n" + 
+                ") tmp;\n" + 
+                "> Wang\n" + 
+                "\n" + 
+                "-- Weight random selection with using filling nulls with the average value\n" + 
+                "SELECT\n" + 
+                "  map_roulette(map(1, 0.5, 'Wang', null)), -- 50% Wang, 50% 1\n" + 
+                "  map_roulette(map(1, 0.5, 'Wang', null, 'Zhang', null)) -- 1/3 Wang, 1/3 1, 1/3 Zhang\n" + 
+                ";\n" + 
+                "\n" + 
+                "-- NULL will be returned if every key is null\n" + 
+                "SELECT \n" + 
+                "  map_roulette(map()),\n" + 
+                "  map_roulette(map(null, null, null, null));\n" + 
+                "> NULL    NULL\n" + 
+                "\n" + 
+                "-- Return NULL if all weights are zero\n" + 
+                "SELECT\n" + 
+                "  map_roulette(map(1, 0)),\n" + 
+                "  map_roulette(map(1, 0, '5', 0))\n" + 
+                ";\n" + 
+                "> NULL    NULL\n" + 
+                "\n" + 
+                "-- map_roulette does not support non-numeric weights or negative weights.\n" + 
+                "SELECT map_roulette(map('Wong', 'A string', 'Zhao', 2));\n" + 
+                "> HiveException: Error evaluating map_roulette(map('Wong':'A string','Zhao':2))\n" + 
+                "SELECT map_roulette(map('Wong', 'A string', 'Zhao', 2));\n" + 
+                "> UDFArgumentException: Map value must be greather than or equals to zero: -2")
+// @formatter:on
+@UDFType(deterministic = false, stateful = false) // it is false because it return value base on probability
+public final class MapRouletteUDF extends GenericUDF {
+
+    private transient MapObjectInspector mapOI;
+    private transient PrimitiveObjectInspector valueOI;
+    @Nullable
+    private transient PrimitiveObjectInspector seedOI;
+
+    @Nullable
+    private transient Random _rand;
+
+    @Override
+    public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
+        if (argOIs.length != 1 && argOIs.length != 2) {
+            throw new UDFArgumentLengthException(
+                "Expected exactly one argument for map_roulette: " + argOIs.length);
+        }
+        if (argOIs[0].getCategory() != ObjectInspector.Category.MAP) {
+            throw new UDFArgumentTypeException(0,
+                "Only map type argument is accepted but got " + argOIs[0].getTypeName());
+        }
+
+        this.mapOI = HiveUtils.asMapOI(argOIs[0]);
+        this.valueOI = HiveUtils.asDoubleCompatibleOI(mapOI.getMapValueObjectInspector());
+
+        if (argOIs.length == 2) {
+            ObjectInspector argOI1 = argOIs[1];
+            if (HiveUtils.isIntegerOI(argOI1) == false) {
+                throw new UDFArgumentException(
+                    "The second argument of map_roulette must be integer type: "
+                            + argOI1.getTypeName());
+            }
+            if (ObjectInspectorUtils.isConstantObjectInspector(argOI1)) {
+                long seed = HiveUtils.getAsConstLong(argOI1);
+                this._rand = new Random(seed); // fixed seed
+            } else {
+                this.seedOI = HiveUtils.asLongCompatibleOI(argOI1);
+            }
+        } else {
+            this._rand = new Random(); // random seed
+        }
+
+        return mapOI.getMapKeyObjectInspector();
+    }
+
+    @Nullable
+    @Override
+    public Object evaluate(DeferredObject[] arguments) throws HiveException {
+        Random rand = _rand;
+        if (rand == null) {
+            Object arg1 = arguments[1].get();
+            if (arg1 == null) {
+                rand = new Random();
+            } else {
+                long seed = HiveUtils.getLong(arg1, seedOI);
+                rand = new Random(seed);
+            }
+        }
+
+        Map<Object, Double> input = getObjectDoubleMap(arguments[0], mapOI, valueOI);
+        if (input == null) {
+            return null;
+        }
+
+        return rouletteWheelSelection(input, rand);
+    }
+
+    @Nullable
+    private static Map<Object, Double> getObjectDoubleMap(@Nonnull final DeferredObject argument,
+            @Nonnull final MapObjectInspector mapOI,
+            @Nonnull final PrimitiveObjectInspector valueOI) throws HiveException {
+        final Map<?, ?> m = mapOI.getMap(argument.get());
+        if (m == null) {
+            return null;
+        }
+        final int size = m.size();
+        if (size == 0) {
+            return null;
+        }
+
+        final Map<Object, Double> result = new HashMap<>(size);
+        double sum = 0.d;
+        int cnt = 0;
+        for (Map.Entry<?, ?> entry : m.entrySet()) {
+            Object key = entry.getKey();
+            if (key == null) {
+                continue;
+            }
+            Object value = entry.getValue();
+            if (value == null) {
+                continue;
+            }
+            final double v = PrimitiveObjectInspectorUtils.convertPrimitiveToDouble(value, valueOI);
+            if (v < 0.d) {
+                throw new UDFArgumentException(
+                    "Map value must be greather than or equals to zero: " + entry.getValue());
+            }
+
+            result.put(key, Double.valueOf(v));
+            sum += v;
+            cnt++;
+        }
+
+        if (result.isEmpty()) {
+            return null;
+        }
+
+        if (result.size() < m.size()) {
+            // fillna with the avg value
+            final Double avg = Double.valueOf(sum / cnt);
+            for (Map.Entry<?, ?> entry : m.entrySet()) {
+                Object key = entry.getKey();
+                if (key == null) {
+                    continue;
+                }
+                if (entry.getValue() == null) {
+                    result.put(key, avg);
+                }
+            }
+        }
+
+        return result;
+    }
+
+    /**
+     * Roulette Wheel Selection.
+     * 
+     * See https://www.obitko.com/tutorials/genetic-algorithms/selection.php
+     */
+    @Nullable
+    private static Object rouletteWheelSelection(@Nonnull final Map<Object, Double> m,
+            @Nonnull final Random rnd) {
+        Preconditions.checkArgument(m.isEmpty() == false);
+
+        // 1. calculate sum
+        double sum = 0.d;
+        for (Double v : m.values()) {
+            sum += v.doubleValue();
+        }
+
+        // 2. Generate random number from interval r=[0,sum)
+        double r = rnd.nextDouble() * sum;
+
+        // 3. Go through the population, accumulating the weights into a running sum s.
+        //    When the sum s is greater than r, stop and return the element.
+        double s = 0.d;
+        for (Map.Entry<Object, Double> e : m.entrySet()) {
+            Object k = e.getKey();
+            double v = e.getValue().doubleValue();
+            s += v;
+            if (s > r) {
+                return k;
+            }
+        }
+
+        return null;
+    }
+
+    @Override
+    public String getDisplayString(String[] children) {
+        return "map_roulette(" + join(children, ',') + ")";
+    }
+
+}
diff --git a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
index e42d1b6..9379c1e 100644
--- a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
+++ b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
@@ -482,6 +482,13 @@ public final class HiveUtils {
         return PrimitiveObjectInspectorUtils.getInt(o, oi);
     }
 
+    public static long getLong(@Nullable Object o, @Nonnull PrimitiveObjectInspector oi) {
+        if (o == null) {
+            return 0L;
+        }
+        return PrimitiveObjectInspectorUtils.getLong(o, oi);
+    }
+
     @SuppressWarnings("unchecked")
     @Nullable
     public static <T extends Writable> T getConstValue(@Nonnull final ObjectInspector oi)
@@ -1054,9 +1061,8 @@ public final class HiveUtils {
             case DECIMAL:
                 break;
             default:
-                throw new UDFArgumentTypeException(0,
-                    "Only numeric or string type arguments are accepted but " + argOI.getTypeName()
-                            + " is passed.");
+                throw new UDFArgumentTypeException(0, "Only floating point number is accepted but "
+                        + argOI.getTypeName() + " is passed.");
         }
         return oi;
     }
@@ -1080,8 +1086,7 @@ public final class HiveUtils {
                 break;
             default:
                 throw new UDFArgumentTypeException(0,
-                    "Only numeric or string type arguments are accepted but " + argOI.getTypeName()
-                            + " is passed.");
+                    "Only numeric argument is accepted but " + argOI.getTypeName() + " is passed.");
         }
         return oi;
     }
diff --git a/core/src/test/java/hivemall/tools/map/MapRouletteUDFTest.java b/core/src/test/java/hivemall/tools/map/MapRouletteUDFTest.java
new file mode 100644
index 0000000..7df1b82
--- /dev/null
+++ b/core/src/test/java/hivemall/tools/map/MapRouletteUDFTest.java
@@ -0,0 +1,214 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.tools.map;
+
+import hivemall.TestUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Unit test for {@link hivemall.tools.map.MapRouletteUDF}
+ */
+public class MapRouletteUDFTest {
+
+    /**
+     * Tom, Jerry, Amy, Wong, and Zhao join a roulette with weights 0.18, 0.2, 0.01, 0.1, and 0.5.
+     * Zhao has the highest weight, so he has the best chance to win; Jerry has the second highest
+     * weight and Amy the lowest. After 1000000 rounds of roulette, Zhao is expected to win the
+     * most, Jerry the second most, and Amy the least.
+     */
+    @Test
+    public void testRoulette() throws HiveException, IOException {
+        MapRouletteUDF udf = new MapRouletteUDF();
+        udf.initialize(new ObjectInspector[] {ObjectInspectorFactory.getStandardMapObjectInspector(
+            PrimitiveObjectInspectorFactory.javaStringObjectInspector,
+            PrimitiveObjectInspectorFactory.javaDoubleObjectInspector)});
+        Map<Object, Integer> solve = new HashMap<>();
+        solve.put("Tom", 0);
+        solve.put("Jerry", 0);
+        solve.put("Amy", 0);
+        solve.put("Wong", 0);
+        solve.put("Zhao", 0);
+        int T = 1000000;
+        while (T-- > 0) {
+            Map<String, Double> m = new HashMap<>();
+            m.put("Tom", 0.18); // 3rd
+            m.put("Jerry", 0.2); // 2nd
+            m.put("Amy", 0.01); // 5th
+            m.put("Wong", 0.1); // 4th
+            m.put("Zhao", 0.5); // 1st
+            GenericUDF.DeferredObject[] arguments =
+                    new GenericUDF.DeferredObject[] {new GenericUDF.DeferredJavaObject(m)};
+            Object key = udf.evaluate(arguments);
+            solve.put(key, solve.get(key) + 1);
+        }
+        List<Map.Entry<Object, Integer>> solveList = new ArrayList<>(solve.entrySet());
+        Collections.sort(solveList, new KvComparator());
+        Object highestSolve = solveList.get(solveList.size() - 1).getKey();
+        Assert.assertEquals("Zhao", highestSolve.toString());
+        Object secondarySolve = solveList.get(solveList.size() - 2).getKey();
+        Assert.assertEquals("Jerry", secondarySolve.toString());
+        Object worseSolve = solveList.get(0).getKey();
+        Assert.assertEquals("Amy", worseSolve.toString());
+
+        udf.close();
+    }
+
+    @Test
+    public void testRouletteFillNA() throws HiveException, IOException {
+        MapRouletteUDF udf = new MapRouletteUDF();
+        udf.initialize(new ObjectInspector[] {ObjectInspectorFactory.getStandardMapObjectInspector(
+            PrimitiveObjectInspectorFactory.javaStringObjectInspector,
+            PrimitiveObjectInspectorFactory.javaDoubleObjectInspector)});
+        Map<Object, Integer> solve = new HashMap<>();
+        solve.put("Tom", 0);
+        solve.put("Jerry", 0);
+        solve.put("Amy", 0);
+        solve.put("Wong", 0);
+        solve.put("Zhao", 0);
+        int T = 1000000;
+        while (T-- > 0) {
+            Map<String, Double> m = new HashMap<>();
+            m.put("Tom", null); // (0.2+0.1+0.1+0.5)/4=0.225
+            m.put("Jerry", 0.2);
+            m.put("Amy", 0.1);
+            m.put("Wong", 0.1);
+            m.put("Zhao", 0.5);
+            GenericUDF.DeferredObject[] arguments =
+                    new GenericUDF.DeferredObject[] {new GenericUDF.DeferredJavaObject(m)};
+            Object key = udf.evaluate(arguments);
+            solve.put(key, solve.get(key) + 1);
+        }
+        List<Map.Entry<Object, Integer>> solveList = new ArrayList<>(solve.entrySet());
+        Collections.sort(solveList, new KvComparator());
+        Object highestSolve = solveList.get(solveList.size() - 1).getKey();
+        Assert.assertEquals("Zhao", highestSolve.toString());
+        Object secondarySolve = solveList.get(solveList.size() - 2).getKey();
+        Assert.assertEquals("Tom", secondarySolve.toString());
+
+        udf.close();
+    }
+
+    private static class KvComparator implements Comparator<Map.Entry<Object, Integer>> {
+
+        @Override
+        public int compare(Map.Entry<Object, Integer> o1, Map.Entry<Object, Integer> o2) {
+            return o1.getValue().compareTo(o2.getValue());
+        }
+    }
+
+    @Test
+    public void testSerialization() throws HiveException, IOException {
+        Map<String, Double> m = new HashMap<>();
+        m.put("Tom", 0.1);
+        m.put("Jerry", 0.2);
+        m.put("Amy", 0.1);
+        m.put("Wong", 0.1);
+        m.put("Zhao", null);
+
+        TestUtils.testGenericUDFSerialization(MapRouletteUDF.class,
+            new ObjectInspector[] {ObjectInspectorFactory.getStandardMapObjectInspector(
+                PrimitiveObjectInspectorFactory.javaStringObjectInspector,
+                PrimitiveObjectInspectorFactory.javaDoubleObjectInspector)},
+            new Object[] {m});
+        byte[] serialized = TestUtils.serializeObjectByKryo(new MapRouletteUDFTest());
+        TestUtils.deserializeObjectByKryo(serialized, MapRouletteUDFTest.class);
+    }
+
+    @Test
+    public void testEmptyMapAndAllNullMap() throws HiveException, IOException {
+        MapRouletteUDF udf = new MapRouletteUDF();
+        Map<Object, Double> m = new HashMap<>();
+        udf.initialize(new ObjectInspector[] {ObjectInspectorFactory.getStandardMapObjectInspector(
+            PrimitiveObjectInspectorFactory.javaStringObjectInspector,
+            PrimitiveObjectInspectorFactory.javaDoubleObjectInspector)});
+        GenericUDF.DeferredObject[] arguments =
+                new GenericUDF.DeferredObject[] {new GenericUDF.DeferredJavaObject(m)};
+        Assert.assertNull(udf.evaluate(arguments));
+        m.put(null, null);
+        arguments = new GenericUDF.DeferredObject[] {new GenericUDF.DeferredJavaObject(m)};
+        Assert.assertNull(udf.evaluate(arguments));
+
+        udf.close();
+    }
+
+    @Test
+    public void testOnlyOne() throws HiveException, IOException {
+        MapRouletteUDF udf = new MapRouletteUDF();
+        Map<String, Double> m = new HashMap<>();
+        udf.initialize(new ObjectInspector[] {ObjectInspectorFactory.getStandardMapObjectInspector(
+            PrimitiveObjectInspectorFactory.javaStringObjectInspector,
+            PrimitiveObjectInspectorFactory.javaDoubleObjectInspector)});
+        m.put("One", 324.6);
+        GenericUDF.DeferredObject[] arguments =
+                new GenericUDF.DeferredObject[] {new GenericUDF.DeferredJavaObject(m)};
+        Assert.assertEquals("One", udf.evaluate(arguments));
+
+        udf.close();
+    }
+
+    @Test
+    public void testSeed() throws HiveException, IOException {
+        MapRouletteUDF udf = new MapRouletteUDF();
+        Map<String, Double> m = new HashMap<>();
+        udf.initialize(new ObjectInspector[] {
+                ObjectInspectorFactory.getStandardMapObjectInspector(
+                    PrimitiveObjectInspectorFactory.javaStringObjectInspector,
+                    PrimitiveObjectInspectorFactory.javaDoubleObjectInspector),
+                ObjectInspectorUtils.getConstantObjectInspector(
+                    PrimitiveObjectInspectorFactory.javaLongObjectInspector, 43L)});
+        m.put("One", 0.7);
+        GenericUDF.DeferredObject[] arguments =
+                new GenericUDF.DeferredObject[] {new GenericUDF.DeferredJavaObject(m)};
+        Assert.assertEquals("One", udf.evaluate(arguments));
+
+        udf.close();
+    }
+
+    @Test
+    public void testZeroValues() throws HiveException, IOException {
+        MapRouletteUDF udf = new MapRouletteUDF();
+        Map<String, Double> m = new HashMap<>();
+        udf.initialize(new ObjectInspector[] {ObjectInspectorFactory.getStandardMapObjectInspector(
+            PrimitiveObjectInspectorFactory.javaStringObjectInspector,
+            PrimitiveObjectInspectorFactory.javaDoubleObjectInspector)});
+        m.put("One", 0.d);
+        m.put("Two", 0.d);
+        GenericUDF.DeferredObject[] arguments =
+                new GenericUDF.DeferredObject[] {new GenericUDF.DeferredJavaObject(m)};
+        Assert.assertNull(udf.evaluate(arguments));
+
+        udf.close();
+    }
+}
diff --git a/docs/gitbook/misc/generic_funcs.md b/docs/gitbook/misc/generic_funcs.md
index 4f53f4d..aa4b462 100644
--- a/docs/gitbook/misc/generic_funcs.md
+++ b/docs/gitbook/misc/generic_funcs.md
@@ -491,6 +491,46 @@ This page describes a list of useful Hivemall generic functions. See also a [lis
   > [{"key":"one","value":1},{"key":"two","value":2}]
   ```
 
+- `map_roulette(Map<K, number> map [, (const) int/bigint seed])` - Returns a map key based on weighted random sampling of map values. Average of values is used for null values
+  ```sql
+  -- `map_roulette(map<key, number> [, integer seed])` returns key by weighted random selection
+  SELECT 
+    map_roulette(to_map(a, b)) -- 25% Tom, 21% Zhang, 54% Wang
+  FROM ( -- see https://issues.apache.org/jira/browse/HIVE-17406
+    select 'Wang' as a, 54 as b
+    union all
+    select 'Zhang' as a, 21 as b
+    union all
+    select 'Tom' as a, 25 as b
+  ) tmp;
+  > Wang
+
+  -- Weighted random selection, filling null weights with the average value
+  SELECT
+    map_roulette(map(1, 0.5, 'Wang', null)), -- 50% Wang, 50% 1
+    map_roulette(map(1, 0.5, 'Wang', null, 'Zhang', null)) -- 1/3 Wang, 1/3 1, 1/3 Zhang
+  ;
+
+  -- NULL will be returned if the map is empty or every key is null
+  SELECT 
+    map_roulette(map()),
+    map_roulette(map(null, null, null, null));
+  > NULL    NULL
+
+  -- Return NULL if all weights are zero
+  SELECT
+    map_roulette(map(1, 0)),
+    map_roulette(map(1, 0, '5', 0))
+  ;
+  > NULL    NULL
+
+  -- map_roulette does not support non-numeric weights or negative weights.
+  SELECT map_roulette(map('Wong', 'A string', 'Zhao', 2));
+  > HiveException: Error evaluating map_roulette(map('Wong':'A string','Zhao':2))
+  SELECT map_roulette(map('Wong', 1, 'Zhao', -2));
+  > UDFArgumentException: Map value must be greather than or equals to zero: -2
+  ```
+
 - `map_tail_n(map SRC, int N)` - Returns the last N elements from a sorted array of SRC
 
 - `merge_maps(Map x)` - Returns a map which contains the union of an aggregation of maps. Note that an existing value of a key can be replaced with the other duplicate key entry.
diff --git a/resources/ddl/define-all-as-permanent.hive b/resources/ddl/define-all-as-permanent.hive
index ff20c8c..17797a8 100644
--- a/resources/ddl/define-all-as-permanent.hive
+++ b/resources/ddl/define-all-as-permanent.hive
@@ -518,6 +518,9 @@ CREATE FUNCTION map_get as 'hivemall.tools.map.MapGetUDF' USING JAR '${hivemall_
 DROP FUNCTION IF EXISTS map_key_values;
 CREATE FUNCTION map_key_values as 'hivemall.tools.map.MapKeyValuesUDF' USING JAR '${hivemall_jar}';
 
+DROP FUNCTION IF EXISTS map_roulette;
+CREATE FUNCTION map_roulette as 'hivemall.tools.map.MapRouletteUDF' USING JAR '${hivemall_jar}';
+
 ---------------------
 -- list functions --
 ---------------------
@@ -880,3 +883,4 @@ CREATE FUNCTION xgboost_predict AS 'hivemall.xgboost.tools.XGBoostPredictUDTF' U
 
 DROP FUNCTION xgboost_multiclass_predict;
 CREATE FUNCTION xgboost_multiclass_predict AS 'hivemall.xgboost.tools.XGBoostMulticlassPredictUDTF' USING JAR '${hivemall_jar}';
+
diff --git a/resources/ddl/define-all.hive b/resources/ddl/define-all.hive
index 0495113..04e8915 100644
--- a/resources/ddl/define-all.hive
+++ b/resources/ddl/define-all.hive
@@ -510,6 +510,9 @@ create temporary function map_get as 'hivemall.tools.map.MapGetUDF';
 drop temporary function if exists map_key_values;
 create temporary function map_key_values as 'hivemall.tools.map.MapKeyValuesUDF';
 
+drop temporary function if exists map_roulette;
+create temporary function map_roulette as 'hivemall.tools.map.MapRouletteUDF';
+
 ---------------------
 -- list functions --
 ---------------------
diff --git a/resources/ddl/define-all.spark b/resources/ddl/define-all.spark
index feadbbf..19f01bc 100644
--- a/resources/ddl/define-all.spark
+++ b/resources/ddl/define-all.spark
@@ -509,6 +509,9 @@ sqlContext.sql("CREATE TEMPORARY FUNCTION map_get AS 'hivemall.tools.map.MapGetU
 sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS map_key_values")
 sqlContext.sql("CREATE TEMPORARY FUNCTION map_key_values AS 'hivemall.tools.map.MapKeyValuesUDF'")
 
+sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS map_roulette")
+sqlContext.sql("CREATE TEMPORARY FUNCTION map_roulette AS 'hivemall.tools.map.MapRouletteUDF'")
+
 /**
  * List functions
  */