You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hivemall.apache.org by my...@apache.org on 2019/06/10 06:51:16 UTC

[incubator-hivemall] 01/10: merged from https://github.com/Solodye/incubator-hivemall.git master ignoring pom.xml updates

This is an automated email from the ASF dual-hosted git repository.

myui pushed a commit to branch HIVEMALL-253-2
in repository https://gitbox.apache.org/repos/asf/incubator-hivemall.git

commit 6d708abbb46bc740b52bab09ba4eda943dadaf85
Author: Makoto Yui <my...@apache.org>
AuthorDate: Mon Jun 10 15:29:26 2019 +0900

    merged from https://github.com/Solodye/incubator-hivemall.git master ignoring pom.xml updates
---
 .../java/hivemall/tools/map/MapRouletteUDF.java    | 192 +++++++++++++++++++++
 .../hivemall/tools/map/MapRouletteUDFTest.java     | 148 ++++++++++++++++
 docs/gitbook/misc/generic_funcs.md                 |  38 +++-
 resources/ddl/define-all.hive                      |   3 +
 4 files changed, 380 insertions(+), 1 deletion(-)

diff --git a/core/src/main/java/hivemall/tools/map/MapRouletteUDF.java b/core/src/main/java/hivemall/tools/map/MapRouletteUDF.java
new file mode 100644
index 0000000..e69dd53
--- /dev/null
+++ b/core/src/main/java/hivemall/tools/map/MapRouletteUDF.java
@@ -0,0 +1,192 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.tools.map;
+
+import hivemall.utils.hadoop.HiveUtils;
+import org.apache.hadoop.hive.ql.exec.*;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.udf.UDFType;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
+import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
+import java.util.*;
+import static hivemall.HivemallConstants.*;
+
+/**
+ * The map_roulette() can be use to do roulette, according to each map.entry 's weight.
+ * 
+ * @author Wang, Yizheng
+ */
+@Description(name = "map_roulette", value = "_FUNC_(Map<K, number> map)"
+        + " - Returns the key K which determine to its weight , the bigger weight is ,the more probability K will return. "
+        + "Number is a probability value or a positive weight")
+@UDFType(deterministic = false, stateful = false) // it is false because it return value base on probability
+public class MapRouletteUDF extends GenericUDF {
+
+    /**
+     * The map passed in saved all the value and its weight
+     *
+     * @param m A map contains a lot of item as key, with their weight as value
+     * @return The key that computer selected according to key's weight
+     */
+    private Object algorithm(Map<Object, Double> m) {
+        // normalize the weight
+        double sum = 0;
+        for (Map.Entry<Object, Double> entry : m.entrySet()) {
+            sum += entry.getValue();
+        }
+        for (Map.Entry<Object, Double> entry : m.entrySet()) {
+            entry.setValue(entry.getValue() / sum);
+        }
+
+        // sort and generate a number axis
+        List<Map.Entry<Object, Double>> entryList = new ArrayList<>(m.entrySet());
+        Collections.sort(entryList, new MapRouletteUDF.KvComparator());
+        double tmp = 0;
+        for (Map.Entry<Object, Double> entry : entryList) {
+            tmp += entry.getValue();
+            entry.setValue(tmp);
+        }
+
+        // judge last value
+        if (entryList.get(entryList.size() - 1).getValue() > 1.0) {
+            entryList.get(entryList.size() - 1).setValue(1.0);
+        }
+
+        // pick a Object base on its weight
+        double cursor = Math.random();
+        for (Map.Entry<Object, Double> entry : entryList) {
+            if (cursor < entry.getValue()) {
+                return entry.getKey();
+            }
+        }
+        return null;
+    }
+
+    private transient MapObjectInspector mapOI;
+    private transient PrimitiveObjectInspector valueOI;
+
+    @Override
+    public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
+        if (arguments.length != 1)
+            throw new UDFArgumentLengthException(
+                "Expected one arguments for map_find_max_prob: " + arguments.length);
+        if (arguments[0].getCategory() != ObjectInspector.Category.MAP) {
+            throw new UDFArgumentTypeException(0,
+                "Only map type arguments are accepted for the key but " + arguments[0].getTypeName()
+                        + " was passed as parameter 1.");
+        }
+        mapOI = HiveUtils.asMapOI(arguments[0]);
+        ObjectInspector keyOI = mapOI.getMapKeyObjectInspector();
+
+        //judge valueOI is a number
+        valueOI = (PrimitiveObjectInspector) mapOI.getMapValueObjectInspector();
+        switch (valueOI.getTypeName()) {
+            case INT_TYPE_NAME:
+            case DOUBLE_TYPE_NAME:
+            case BIGINT_TYPE_NAME:
+            case FLOAT_TYPE_NAME:
+            case SMALLINT_TYPE_NAME:
+            case TINYINT_TYPE_NAME:
+            case DECIMAL_TYPE_NAME:
+            case STRING_TYPE_NAME:
+                // Pass an empty map or a map full of {null, null} will get string type
+                // An number in string format like "3.5" also support
+                break;
+            default:
+                throw new UDFArgumentException(
+                    "Expected a number but get: " + valueOI.getTypeName());
+        }
+        return keyOI;
+    }
+
+    @Override
+    public Object evaluate(DeferredObject[] arguments) throws HiveException {
+        Map<Object, Double> input = processObjectDoubleMap(arguments[0]);
+        if (input == null) {
+            return null;
+        }
+        // handle empty map
+        if (input.isEmpty()) {
+            return null;
+        }
+        return algorithm(input);
+    }
+
+    /**
+     * Process the data passed by user.
+     * 
+     * @param argument data passed by user
+     * @return If all the value is ,
+     * @throws HiveException If get the wrong weight value like {key = "Wang", value = "Zhang"},
+     *         "Zhang" isn't a number ,this Method will throw exception when
+     *         convertPrimitiveToDouble("Zhang", valueOD)
+     */
+    private Map<Object, Double> processObjectDoubleMap(DeferredObject argument)
+            throws HiveException {
+        // get
+        Map<?, ?> m = mapOI.getMap(argument.get());
+        if (m == null) {
+            return null;
+        }
+        if (m.size() == 0) {
+            return null;
+        }
+        // convert
+        Map<Object, Double> input = new HashMap<>();
+        Double avg = 0.0;
+        for (Map.Entry<?, ?> entry : m.entrySet()) {
+            Object key = entry.getKey();
+            Double value = null;
+            if (entry.getValue() != null) {
+                value = PrimitiveObjectInspectorUtils.convertPrimitiveToDouble(entry.getValue(),
+                    valueOI);
+                if (value < 0) {
+                    throw new UDFArgumentException(entry.getValue() + " < 0");
+                }
+                avg += value;
+            }
+            input.put(key, value);
+        }
+        avg /= m.size();
+        for (Map.Entry<?, ?> entry : input.entrySet()) {
+            if (entry.getValue() == null) {
+                Object key = entry.getKey();
+                input.put(key, avg);
+            }
+        }
+        return input;
+    }
+
+    @Override
+    public String getDisplayString(String[] children) {
+        return "map_roulette(" + Arrays.toString(children) + ")";
+    }
+
+    private static class KvComparator implements Comparator<Map.Entry<Object, Double>> {
+
+        @Override
+        public int compare(Map.Entry<Object, Double> o1, Map.Entry<Object, Double> o2) {
+            return o1.getValue().compareTo(o2.getValue());
+        }
+    }
+
+}
diff --git a/core/src/test/java/hivemall/tools/map/MapRouletteUDFTest.java b/core/src/test/java/hivemall/tools/map/MapRouletteUDFTest.java
new file mode 100644
index 0000000..a7497d8
--- /dev/null
+++ b/core/src/test/java/hivemall/tools/map/MapRouletteUDFTest.java
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.tools.map;
+
+import hivemall.TestUtils;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.junit.Assert;
+import org.junit.Test;
+import java.io.IOException;
+import java.util.*;
+
+/**
+ * Unit test for {@link hivemall.tools.map.MapRouletteUDF}
+ *
+ * @author Wang, Yizheng
+ */
+public class MapRouletteUDFTest {
+
+    /**
+     * Tom, Jerry, Amy, Wong and Zhao join a roulette. Zhao has the highest weight and Jerry the
+     * second highest. Tom's weight is missing and is replaced by 0.9 / 5 = 0.18, which is below
+     * Jerry's 0.2. After 1000000 draws Zhao must have won most often, Jerry second most often.
+     * @throws HiveException if initializing or evaluating the UDF fails
+     */
+    @Test
+    public void testRoulette() throws HiveException {
+        MapRouletteUDF fmp = new MapRouletteUDF();
+        fmp.initialize(new ObjectInspector[] {ObjectInspectorFactory.getStandardMapObjectInspector(
+            PrimitiveObjectInspectorFactory.javaStringObjectInspector,
+            PrimitiveObjectInspectorFactory.javaDoubleObjectInspector)});
+        // The weights never change, so one map and one argument array serve every draw.
+        Map<Object, Double> m = new HashMap<>();
+        m.put("Tom", null);
+        m.put("Jerry", 0.2);
+        m.put("Amy", 0.1);
+        m.put("Wong", 0.1);
+        m.put("Zhao", 0.5);
+        GenericUDF.DeferredObject[] arguments =
+                new GenericUDF.DeferredObject[] {new GenericUDF.DeferredJavaObject(m)};
+        Map<Object, Integer> solve = new HashMap<>();
+        for (Object player : m.keySet()) {
+            solve.put(player, 0);
+        }
+        for (int round = 0; round < 1000000; round++) {
+            Object key = fmp.evaluate(arguments);
+            solve.put(key, solve.get(key) + 1);
+        }
+        // Ascending by number of wins: the winner is last, the runner-up second to last.
+        List<Map.Entry<Object, Integer>> solveList = new ArrayList<>(solve.entrySet());
+        Collections.sort(solveList, new KvComparator());
+        Object highestSolve = solveList.get(solveList.size() - 1).getKey();
+        Assert.assertEquals("Zhao", highestSolve.toString());
+        Object secondarySolve = solveList.get(solveList.size() - 2).getKey();
+        Assert.assertEquals("Jerry", secondarySolve.toString());
+    }
+
+    /** Orders map entries by ascending value. */
+    private static class KvComparator implements Comparator<Map.Entry<Object, Integer>> {
+        @Override
+        public int compare(Map.Entry<Object, Integer> o1, Map.Entry<Object, Integer> o2) {
+            return o1.getValue().compareTo(o2.getValue());
+        }
+    }
+
+    /** Runs the UDF through {@code TestUtils.testGenericUDFSerialization} with a sample map. */
+    @Test
+    public void testSerialization() throws HiveException, IOException {
+        Map<Object, Double> m = new HashMap<>();
+        m.put("Tom", 0.1);
+        m.put("Jerry", 0.2);
+        m.put("Amy", 0.1);
+        m.put("Wong", 0.1);
+        m.put("Zhao", null);
+
+        TestUtils.testGenericUDFSerialization(MapRouletteUDF.class,
+            new ObjectInspector[] {ObjectInspectorFactory.getStandardMapObjectInspector(
+                PrimitiveObjectInspectorFactory.javaStringObjectInspector,
+                PrimitiveObjectInspectorFactory.javaDoubleObjectInspector)},
+            new Object[] {m});
+        // NOTE(review): this round-trips the test class itself, not the UDF - confirm the intent
+        byte[] serialized = TestUtils.serializeObjectByKryo(new MapRouletteUDFTest());
+        TestUtils.deserializeObjectByKryo(serialized, MapRouletteUDFTest.class);
+    }
+
+    /** An empty map and a map holding only a null key with a null weight both yield null. */
+    @Test
+    public void testEmptyMapAndAllNullMap() throws HiveException {
+        MapRouletteUDF udf = new MapRouletteUDF();
+        Map<Object, Double> m = new HashMap<>();
+        udf.initialize(new ObjectInspector[] {ObjectInspectorFactory.getStandardMapObjectInspector(
+            PrimitiveObjectInspectorFactory.javaStringObjectInspector,
+            PrimitiveObjectInspectorFactory.javaDoubleObjectInspector)});
+        GenericUDF.DeferredObject[] arguments =
+                new GenericUDF.DeferredObject[] {new GenericUDF.DeferredJavaObject(m)};
+        Assert.assertNull(udf.evaluate(arguments));
+        m.put(null, null);
+        arguments = new GenericUDF.DeferredObject[] {new GenericUDF.DeferredJavaObject(m)};
+        Assert.assertNull(udf.evaluate(arguments));
+    }
+
+    /** A map with a single positive weight always yields its only key. */
+    @Test
+    public void testOnlyOne() throws HiveException {
+        MapRouletteUDF udf = new MapRouletteUDF();
+        Map<Object, Double> m = new HashMap<>();
+        udf.initialize(new ObjectInspector[] {ObjectInspectorFactory.getStandardMapObjectInspector(
+            PrimitiveObjectInspectorFactory.javaStringObjectInspector,
+            PrimitiveObjectInspectorFactory.javaDoubleObjectInspector)});
+        m.put("One", 324.6);
+        GenericUDF.DeferredObject[] arguments =
+                new GenericUDF.DeferredObject[] {new GenericUDF.DeferredJavaObject(m)};
+        Assert.assertEquals("One", udf.evaluate(arguments));
+    }
+
+    /** Weights given as numeric strings are accepted. */
+    @Test
+    public void testString() throws HiveException {
+        MapRouletteUDF udf = new MapRouletteUDF();
+        Map<Object, String> m = new HashMap<>();
+        udf.initialize(new ObjectInspector[] {ObjectInspectorFactory.getStandardMapObjectInspector(
+            PrimitiveObjectInspectorFactory.javaStringObjectInspector,
+            PrimitiveObjectInspectorFactory.javaStringObjectInspector)});
+        m.put("One", "0.7");
+        GenericUDF.DeferredObject[] arguments =
+                new GenericUDF.DeferredObject[] {new GenericUDF.DeferredJavaObject(m)};
+        Assert.assertEquals("One", udf.evaluate(arguments));
+    }
+}
diff --git a/docs/gitbook/misc/generic_funcs.md b/docs/gitbook/misc/generic_funcs.md
index 4f53f4d..328969b 100644
--- a/docs/gitbook/misc/generic_funcs.md
+++ b/docs/gitbook/misc/generic_funcs.md
@@ -539,7 +539,43 @@ This page describes a list of useful Hivemall generic functions. See also a [lis
       to_ordered_map(key, value, -100)    -- {3:"banana",4:"candy",10:"apple"} (tail-100)
   from t
   ```
-
+  
+- `map_roulette(Map<key, number> map)` - Returns a `key` drawn at random with a probability proportional to its `number` weight: the larger the weight, the more likely the key is returned. `number` is a probability value or a non-negative weight.
+  
+  We can use `map_roulette()` on a `Map<key, number>` that was built from data.
+  ```sql
+  select map_roulette(to_map(a, b)) -- 25% Tom, 21% Zhang, 54% Wang
+  from(
+      select 'Wang' as a, 54 as b
+      union
+      select 'Zhang' as a, 21 as b
+      union
+      select 'Tom' as a, 25 as b
+  )tmp;
+  ```
+  Passing an empty map, or a map whose weights are all `null`, returns `null`.
+  ```sql
+  select map_roulette(map(null, null, null, null)); -- NULL
+  select map_roulette(map()); -- NULL
+  ```
+  A `null` weight is replaced by the sum of the non-null weights divided by the total number of entries (entries with a `null` weight are counted too).
+  ```sql
+  select map_roulette(map(1, 0.5, 'Wang', null)); -- 2/3 1, 1/3 Wang (null becomes 0.5 / 2 = 0.25)
+  select map_roulette(map(1, 0.5, 'Wang', null, 'Zhang', null)); -- 60% 1, 20% Wang, 20% Zhang
+  ```
+  If all the weights are zero, `null` is returned.
+  ```sql
+  select map_roulette(map(1, 0)); -- NULL
+  select map_roulette(map(1, 0, '5', 0)); -- NULL
+  ```
+  This UDF does not support non-numeric or negative weights.
+  ```sql
+  select map_roulette(map('Wong', 'A string', 'Zhao', 2)); 
+  --Failed with exception java.io.IOException:org.apache.hadoop.hive.ql.metadata.HiveException: Error evaluating map_roulette([map('Wong':'A string','Zhao':2)])
+  select map_roulette(map('Wong', 3, 'Zhao', -2));
+  -- Failed with exception java.io.IOException:org.apache.hadoop.hive.ql.exec.UDFArgumentException: -2 < 0
+  ```
+   
 # MapReduce
 
 - `distcache_gets(filepath, key, default_value [, parseKey])` - Returns map&lt;key_type, value_type&gt;|value_type
diff --git a/resources/ddl/define-all.hive b/resources/ddl/define-all.hive
index e6f7c0b..4faaeed 100644
--- a/resources/ddl/define-all.hive
+++ b/resources/ddl/define-all.hive
@@ -507,6 +507,9 @@ create temporary function map_get as 'hivemall.tools.map.MapGetUDF';
 drop temporary function if exists map_key_values;
 create temporary function map_key_values as 'hivemall.tools.map.MapKeyValuesUDF';
 
+drop temporary function if exists map_roulette;
+create temporary function map_roulette as 'hivemall.tools.map.MapRouletteUDF';
+
 ---------------------
 -- list functions --
 ---------------------