Posted to commits@hivemall.apache.org by my...@apache.org on 2019/02/21 06:55:44 UTC
[incubator-hivemall] branch master updated: [HIVEMALL-238] Fixed from_json UDF to support top-level Map object
This is an automated email from the ASF dual-hosted git repository.
myui pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-hivemall.git
The following commit(s) were added to refs/heads/master by this push:
new 525418f [HIVEMALL-238] Fixed from_json UDF to support top-level Map object
525418f is described below
commit 525418f4edd4e3c5df273c0d902cb2d8035e3c9b
Author: Makoto Yui <my...@apache.org>
AuthorDate: Thu Feb 21 15:55:39 2019 +0900
[HIVEMALL-238] Fixed from_json UDF to support top-level Map object
## What changes were proposed in this pull request?
Fixed from_json UDF to support top-level Map object
## What type of PR is it?
Bug Fix
## What is the Jira issue?
https://issues.apache.org/jira/browse/HIVEMALL-238
## How was this patch tested?
unit tests, manual tests
## How to use this feature?
```sql
select
from_json(to_json(map('one',1,'two',2)), 'map<string,int>')
```
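The query above returns `{"one":1,"two":2}`: with a single return type and no column names, the JSON object is deserialized directly as a top-level Hive `map<string,int>` value.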
## Checklist
- [x] Did you apply source code formatter, i.e., `./bin/format_code.sh`, for your commit?
- [x] Did you run system tests on Hive (or Spark)?
Author: Makoto Yui <my...@apache.org>
Closes #184 from myui/HIVEMALL-238.
---
.../main/java/hivemall/tools/json/FromJsonUDF.java | 8 +-
.../utils/hadoop/HiveJsonStructReader.java | 415 +++++++++++++++++++++
.../java/hivemall/utils/hadoop/JsonSerdeUtils.java | 13 +-
.../hivemall/utils/hadoop/JsonSerdeUtilsTest.java | 25 +-
docs/gitbook/misc/funcs.md | 2 +-
docs/gitbook/misc/generic_funcs.md | 160 ++++++--
6 files changed, 583 insertions(+), 40 deletions(-)
diff --git a/core/src/main/java/hivemall/tools/json/FromJsonUDF.java b/core/src/main/java/hivemall/tools/json/FromJsonUDF.java
index 2a17f0c..fb4a615 100644
--- a/core/src/main/java/hivemall/tools/json/FromJsonUDF.java
+++ b/core/src/main/java/hivemall/tools/json/FromJsonUDF.java
@@ -49,6 +49,7 @@ import org.apache.hive.hcatalog.data.HCatRecordObjectInspectorFactory;
value = "_FUNC_(string jsonString, const string returnTypes [, const array<string>|const string columnNames])"
+ " - Return Hive object.",
extended = "SELECT\n" +
+ " from_json(to_json(map('one',1,'two',2)), 'map<string,int>'),\n" +
" from_json(\n" +
" '{ \"person\" : { \"name\" : \"makoto\" , \"age\" : 37 } }',\n" +
" 'struct<name:string,age:int>', \n" +
@@ -79,6 +80,7 @@ import org.apache.hive.hcatalog.data.HCatRecordObjectInspectorFactory;
" ),'array<struct<city:string>>');\n"
+ "```\n\n" +
"```\n" +
+ " {\"one\":1,\"two\":2}\n" +
" {\"name\":\"makoto\",\"age\":37}\n" +
" [0.1,1.1,2.2]\n" +
" [{\"country\":\"japan\",\"city\":\"tokyo\"},{\"country\":\"japan\",\"city\":\"osaka\"}]\n" +
@@ -171,7 +173,11 @@ public final class FromJsonUDF extends GenericUDF {
final Object result;
try {
- result = JsonSerdeUtils.deserialize(jsonString, columnNames, columnTypes);
+ if (columnNames == null && columnTypes != null && columnTypes.size() == 1) {
+ result = JsonSerdeUtils.deserialize(jsonString, columnTypes.get(0));
+ } else {
+ result = JsonSerdeUtils.deserialize(jsonString, columnNames, columnTypes);
+ }
} catch (Throwable e) {
throw new HiveException("Failed to deserialize Json: \n" + jsonString.toString() + '\n'
+ ExceptionUtils.prettyPrintStackTrace(e),
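For reference, a minimal sketch of the call path this branch enables (the driver class is hypothetical; the other names appear in this diff): with one return type and no column names, the call resolves to the new single-TypeInfo overload of `JsonSerdeUtils.deserialize`.

```java
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils;
import org.apache.hadoop.io.Text;

import hivemall.utils.hadoop.JsonSerdeUtils;

// Hypothetical driver: with one return type and no column names, the
// columnNames == null && columnTypes.size() == 1 branch above is taken.
public class TopLevelMapSketch {
    public static void main(String[] args) throws Exception {
        Text json = new Text("{\"one\":1,\"two\":2}");
        TypeInfo mapType = TypeInfoUtils.getTypeInfoFromTypeString("map<string,int>");
        // Resolves to the new single-TypeInfo overload of JsonSerdeUtils.deserialize
        Object result = JsonSerdeUtils.deserialize(json, mapType);
        System.out.println(result); // {one=1, two=2}
    }
}
```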
diff --git a/core/src/main/java/hivemall/utils/hadoop/HiveJsonStructReader.java b/core/src/main/java/hivemall/utils/hadoop/HiveJsonStructReader.java
new file mode 100644
index 0000000..8e57890
--- /dev/null
+++ b/core/src/main/java/hivemall/utils/hadoop/HiveJsonStructReader.java
@@ -0,0 +1,415 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+// This file contains code borrowed from
+// - https://github.com/apache/hive/blob/master/serde/src/java/org/apache/hadoop/hive/serde2/json/HiveJsonStructReader.java
+package hivemall.utils.hadoop;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.charset.CharacterCodingException;
+import java.sql.Date;
+import java.sql.Timestamp;
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+
+import javax.annotation.Nonnull;
+
+import org.apache.hadoop.hive.common.type.HiveChar;
+import org.apache.hadoop.hive.common.type.HiveDecimal;
+import org.apache.hadoop.hive.common.type.HiveVarchar;
+import org.apache.hadoop.hive.serde2.SerDeException;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters.Converter;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StructField;
+import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.typeinfo.BaseCharTypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils;
+import org.apache.hadoop.io.Text;
+import org.codehaus.jackson.JsonFactory;
+import org.codehaus.jackson.JsonParseException;
+import org.codehaus.jackson.JsonParser;
+import org.codehaus.jackson.JsonToken;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public final class HiveJsonStructReader {
+ private static final Logger LOG = LoggerFactory.getLogger(HiveJsonStructReader.class);
+
+ private ObjectInspector oi;
+ private JsonFactory factory;
+
+ private final Set<String> reportedUnknownFieldNames = new HashSet<>();
+
+ private boolean ignoreUnknownFields;
+ private boolean hiveColIndexParsing;
+ private boolean writeablePrimitives;
+
+ public HiveJsonStructReader(TypeInfo t) {
+ oi = TypeInfoUtils.getStandardWritableObjectInspectorFromTypeInfo(t);
+ factory = new JsonFactory();
+ }
+
+ public Object parseStruct(String text) throws JsonParseException, IOException, SerDeException {
+ JsonParser parser = factory.createJsonParser(text);
+ return parseInternal(parser);
+ }
+
+ public Object parseStruct(InputStream is)
+ throws JsonParseException, IOException, SerDeException {
+ JsonParser parser = factory.createJsonParser(is);
+ return parseInternal(parser);
+ }
+
+ private Object parseInternal(JsonParser parser) throws SerDeException {
+ try {
+ parser.nextToken();
+ Object res = parseDispatcher(parser, oi);
+ return res;
+ } catch (Exception e) {
+ String locationStr = parser.getCurrentLocation().getLineNr() + ","
+ + parser.getCurrentLocation().getColumnNr();
+ throw new SerDeException("at[" + locationStr + "]: " + e.getMessage(), e);
+ }
+ }
+
+ private Object parseDispatcher(JsonParser parser, ObjectInspector oi)
+ throws JsonParseException, IOException, SerDeException {
+
+ switch (oi.getCategory()) {
+ case PRIMITIVE:
+ return parsePrimitive(parser, (PrimitiveObjectInspector) oi);
+ case LIST:
+ return parseList(parser, (ListObjectInspector) oi);
+ case STRUCT:
+ return parseStruct(parser, (StructObjectInspector) oi);
+ case MAP:
+ return parseMap(parser, (MapObjectInspector) oi);
+ default:
+ throw new SerDeException("parsing of: " + oi.getCategory() + " is not handled");
+ }
+ }
+
+ private Object parseMap(JsonParser parser, MapObjectInspector oi)
+ throws IOException, SerDeException {
+
+ if (parser.getCurrentToken() == JsonToken.VALUE_NULL) {
+ parser.nextToken();
+ return null;
+ }
+
+ Map<Object, Object> ret = new LinkedHashMap<>();
+
+ if (parser.getCurrentToken() != JsonToken.START_OBJECT) {
+ throw new SerDeException("struct expected");
+ }
+
+ if (!(oi.getMapKeyObjectInspector() instanceof PrimitiveObjectInspector)) {
+ throw new SerDeException("map key must be a primitive");
+ }
+ PrimitiveObjectInspector keyOI = (PrimitiveObjectInspector) oi.getMapKeyObjectInspector();
+ ObjectInspector valOI = oi.getMapValueObjectInspector();
+
+ JsonToken currentToken = parser.nextToken();
+ while (currentToken != null && currentToken != JsonToken.END_OBJECT) {
+
+ if (currentToken != JsonToken.FIELD_NAME) {
+ throw new SerDeException("unexpected token: " + currentToken);
+ }
+
+ Object key = parseMapKey(parser, keyOI);
+ Object val = parseDispatcher(parser, valOI);
+ ret.put(key, val);
+
+ currentToken = parser.getCurrentToken();
+ }
+ if (currentToken != null) {
+ parser.nextToken();
+ }
+ return ret;
+
+ }
+
+ private Object parseStruct(JsonParser parser, StructObjectInspector oi)
+ throws JsonParseException, IOException, SerDeException {
+
+ Object[] ret = new Object[oi.getAllStructFieldRefs().size()];
+
+ if (parser.getCurrentToken() == JsonToken.VALUE_NULL) {
+ parser.nextToken();
+ return null;
+ }
+ if (parser.getCurrentToken() != JsonToken.START_OBJECT) {
+ throw new SerDeException("struct expected");
+ }
+ JsonToken currentToken = parser.nextToken();
+ while (currentToken != null && currentToken != JsonToken.END_OBJECT) {
+
+ switch (currentToken) {
+ case FIELD_NAME:
+ String name = parser.getCurrentName();
+ try {
+ StructField field = null;
+ try {
+ field = getStructField(oi, name);
+ } catch (RuntimeException e) {
+ if (ignoreUnknownFields) {
+ if (!reportedUnknownFieldNames.contains(name)) {
+ LOG.warn("ignoring field:" + name);
+ reportedUnknownFieldNames.add(name);
+ }
+ parser.nextToken();
+ skipValue(parser);
+ break;
+ }
+ }
+ if (field == null) {
+ throw new SerDeException("undeclared field");
+ }
+ parser.nextToken();
+ int fieldId = getStructFieldIndex(oi, field);
+ ret[fieldId] = parseDispatcher(parser, field.getFieldObjectInspector());
+ } catch (Exception e) {
+ throw new SerDeException("struct field " + name + ": " + e.getMessage(), e);
+ }
+ break;
+ default:
+ throw new SerDeException("unexpected token: " + currentToken);
+ }
+ currentToken = parser.getCurrentToken();
+ }
+ if (currentToken != null) {
+ parser.nextToken();
+ }
+ return ret;
+ }
+
+ private static int getStructFieldIndex(@Nonnull StructObjectInspector oi,
+ @Nonnull StructField field) {
+ final List<? extends StructField> fields = oi.getAllStructFieldRefs();
+ for (int i = 0, size = fields.size(); i < size; i++) {
+ StructField f = fields.get(i);
+ if (field.equals(f)) {
+ return i;
+ }
+ }
+ return -1;
+ }
+
+ private StructField getStructField(StructObjectInspector oi, String name) {
+ if (hiveColIndexParsing) {
+ int colIndex = getColIndex(name);
+ if (colIndex >= 0) {
+ return oi.getAllStructFieldRefs().get(colIndex);
+ }
+ }
+ // FIXME: linear scan inside the below method...get a map here or something..
+ return oi.getStructFieldRef(name);
+ }
+
+ Pattern internalPattern = Pattern.compile("^_col([0-9]+)$");
+
+ private int getColIndex(String internalName) {
+ // The above line should have been all the implementation that
+ // we need, but due to a bug in that impl which recognizes
+ // only single-digit columns, we need another impl here.
+ Matcher m = internalPattern.matcher(internalName);
+ if (!m.matches()) {
+ return -1;
+ } else {
+ return Integer.parseInt(m.group(1));
+ }
+ }
+
+ private static void skipValue(JsonParser parser) throws JsonParseException, IOException {
+
+ int array = 0;
+ int object = 0;
+ do {
+ JsonToken currentToken = parser.getCurrentToken();
+ if (currentToken == JsonToken.START_ARRAY) {
+ array++;
+ }
+ if (currentToken == JsonToken.END_ARRAY) {
+ array--;
+ }
+ if (currentToken == JsonToken.START_OBJECT) {
+ object++;
+ }
+ if (currentToken == JsonToken.END_OBJECT) {
+ object--;
+ }
+
+ parser.nextToken();
+
+ } while (array > 0 || object > 0);
+
+ }
+
+ private Object parseList(JsonParser parser, ListObjectInspector oi)
+ throws JsonParseException, IOException, SerDeException {
+ List<Object> ret = new ArrayList<>();
+
+ if (parser.getCurrentToken() == JsonToken.VALUE_NULL) {
+ parser.nextToken();
+ return null;
+ }
+
+ if (parser.getCurrentToken() != JsonToken.START_ARRAY) {
+ throw new SerDeException("array expected");
+ }
+ ObjectInspector eOI = oi.getListElementObjectInspector();
+ JsonToken currentToken = parser.nextToken();
+ try {
+ while (currentToken != null && currentToken != JsonToken.END_ARRAY) {
+ ret.add(parseDispatcher(parser, eOI));
+ currentToken = parser.getCurrentToken();
+ }
+ } catch (Exception e) {
+ throw new SerDeException("array: " + e.getMessage(), e);
+ }
+
+ currentToken = parser.nextToken();
+
+ return ret;
+ }
+
+ private Object parsePrimitive(JsonParser parser, PrimitiveObjectInspector oi)
+ throws SerDeException, IOException {
+ JsonToken currentToken = parser.getCurrentToken();
+ if (currentToken == null) {
+ return null;
+ }
+ try {
+ switch (parser.getCurrentToken()) {
+ case VALUE_FALSE:
+ case VALUE_TRUE:
+ case VALUE_NUMBER_INT:
+ case VALUE_NUMBER_FLOAT:
+ case VALUE_STRING:
+ return getObjectOfCorrespondingPrimitiveType(parser.getText(), oi);
+ case VALUE_NULL:
+ return null;
+ default:
+ throw new SerDeException("unexpected token type: " + currentToken);
+ }
+ } finally {
+ parser.nextToken();
+
+ }
+ }
+
+ private Object getObjectOfCorrespondingPrimitiveType(String s, PrimitiveObjectInspector oi)
+ throws IOException {
+ PrimitiveTypeInfo typeInfo = oi.getTypeInfo();
+ if (writeablePrimitives) {
+ Converter c = ObjectInspectorConverters.getConverter(
+ PrimitiveObjectInspectorFactory.javaStringObjectInspector, oi);
+ return c.convert(s);
+ }
+
+ switch (typeInfo.getPrimitiveCategory()) {
+ case INT:
+ return Integer.valueOf(s);
+ case BYTE:
+ return Byte.valueOf(s);
+ case SHORT:
+ return Short.valueOf(s);
+ case LONG:
+ return Long.valueOf(s);
+ case BOOLEAN:
+ return (s.equalsIgnoreCase("true"));
+ case FLOAT:
+ return Float.valueOf(s);
+ case DOUBLE:
+ return Double.valueOf(s);
+ case STRING:
+ return s;
+ case BINARY:
+ try {
+ String t = Text.decode(s.getBytes(), 0, s.getBytes().length);
+ return t.getBytes();
+ } catch (CharacterCodingException e) {
+ LOG.warn("Error generating json binary type from object.", e);
+ return null;
+ }
+ case DATE:
+ return Date.valueOf(s);
+ case TIMESTAMP:
+ return Timestamp.valueOf(s);
+ case DECIMAL:
+ return HiveDecimal.create(s);
+ case VARCHAR:
+ return new HiveVarchar(s, ((BaseCharTypeInfo) typeInfo).getLength());
+ case CHAR:
+ return new HiveChar(s, ((BaseCharTypeInfo) typeInfo).getLength());
+ }
+ throw new IOException(
+ "Could not convert from string to map type " + typeInfo.getTypeName());
+ }
+
+ private Object parseMapKey(JsonParser parser, PrimitiveObjectInspector oi)
+ throws SerDeException, IOException {
+ JsonToken currentToken = parser.getCurrentToken();
+ if (currentToken == null) {
+ return null;
+ }
+ try {
+ switch (parser.getCurrentToken()) {
+ case FIELD_NAME:
+ return getObjectOfCorrespondingPrimitiveType(parser.getText(), oi);
+ case VALUE_NULL:
+ return null;
+ default:
+ throw new SerDeException("unexpected token type: " + currentToken);
+ }
+ } finally {
+ parser.nextToken();
+
+ }
+ }
+
+ public void setIgnoreUnknownFields(boolean b) {
+ ignoreUnknownFields = b;
+ }
+
+ public void enableHiveColIndexParsing(boolean b) {
+ hiveColIndexParsing = b;
+ }
+
+ public void setWritablesUsage(boolean b) {
+ writeablePrimitives = b;
+ }
+
+ public ObjectInspector getObjectInspector() {
+ return oi;
+ }
+}
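A minimal usage sketch for the reader added above (all class and method names are from this file; the JSON literal and driver class are illustrative):

```java
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils;

import hivemall.utils.hadoop.HiveJsonStructReader;

public class HiveJsonStructReaderSketch {
    public static void main(String[] args) throws Exception {
        // Despite its name, parseStruct() dispatches on the category of the
        // root ObjectInspector, so top-level map/list/primitive types work
        // as well as structs.
        TypeInfo type = TypeInfoUtils.getTypeInfoFromTypeString("map<string,int>");
        HiveJsonStructReader reader = new HiveJsonStructReader(type);
        // Only affects struct parsing: undeclared fields are skipped, not fatal.
        reader.setIgnoreUnknownFields(true);
        Object parsed = reader.parseStruct("{\"one\":1,\"two\":2}");
        System.out.println(parsed); // {one=1, two=2} (a LinkedHashMap)
    }
}
```

By default primitives come back as plain Java objects; calling `setWritablesUsage(true)` switches the conversion to Hive writables via `ObjectInspectorConverters` instead.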
diff --git a/core/src/main/java/hivemall/utils/hadoop/JsonSerdeUtils.java b/core/src/main/java/hivemall/utils/hadoop/JsonSerdeUtils.java
index 562e9a4..f988184 100644
--- a/core/src/main/java/hivemall/utils/hadoop/JsonSerdeUtils.java
+++ b/core/src/main/java/hivemall/utils/hadoop/JsonSerdeUtils.java
@@ -367,10 +367,19 @@ public final class JsonSerdeUtils {
/**
* Deserialize Json array or Json primitives.
*/
+ @SuppressWarnings("unchecked")
@Nonnull
- public static <T> T deserialize(@Nonnull final Text t, @Nonnull TypeInfo columnTypes)
+ public static <T> T deserialize(@Nonnull final Text t, @Nonnull TypeInfo columnType)
throws SerDeException {
- return deserialize(t, null, Arrays.asList(columnTypes));
+ final HiveJsonStructReader reader = new HiveJsonStructReader(columnType);
+ reader.setIgnoreUnknownFields(true);
+ final Object result;
+ try {
+ result = reader.parseStruct(new FastByteArrayInputStream(t.getBytes(), t.getLength()));
+ } catch (IOException e) {
+ throw new SerDeException(e);
+ }
+ return (T) result;
}
@SuppressWarnings("unchecked")
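Since the overload accepts any top-level `TypeInfo`, nested containers work as well. A sketch mirroring the array-of-maps case exercised in the test changes below (the driver class is hypothetical):

```java
import java.util.List;

import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils;
import org.apache.hadoop.io.Text;

import hivemall.utils.hadoop.JsonSerdeUtils;

public class ArrayOfMapsSketch {
    public static void main(String[] args) throws Exception {
        // A top-level array whose elements are maps.
        Text json = new Text("[{\"one\":1,\"two\":2},{\"three\":3}]");
        TypeInfo t = TypeInfoUtils.getTypeInfoFromTypeString("array<map<string,int>>");
        List<Object> rows = JsonSerdeUtils.deserialize(json, t);
        System.out.println(rows); // [{one=1, two=2}, {three=3}]
    }
}
```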
diff --git a/core/src/test/java/hivemall/utils/hadoop/JsonSerdeUtilsTest.java b/core/src/test/java/hivemall/utils/hadoop/JsonSerdeUtilsTest.java
index a3e81d2..9107fb5 100644
--- a/core/src/test/java/hivemall/utils/hadoop/JsonSerdeUtilsTest.java
+++ b/core/src/test/java/hivemall/utils/hadoop/JsonSerdeUtilsTest.java
@@ -307,6 +307,30 @@ public class JsonSerdeUtilsTest {
Text serialized1 = JsonSerdeUtils.serialize(deserialized1,
HCatRecordObjectInspectorFactory.getStandardObjectInspectorFromTypeInfo(type1));
Assert.assertEquals(json1, serialized1);
+
+ List<Map<String, Integer>> expected2 = Arrays.<Map<String, Integer>>asList(
+ ImmutableMap.of("one", 1, "two", 2), ImmutableMap.of("three", 3));
+ Text json2 = new Text("[{\"one\":1,\"two\":2},{\"three\":3}]");
+ TypeInfo type2 = TypeInfoUtils.getTypeInfoFromTypeString("array<map<string,int>>");
+
+ List<Object> deserialized2 = JsonSerdeUtils.deserialize(json2, type2);
+ assertRecordEquals(expected2, deserialized2);
+ Text serialized2 = JsonSerdeUtils.serialize(deserialized2,
+ HCatRecordObjectInspectorFactory.getStandardObjectInspectorFromTypeInfo(type2));
+ Assert.assertEquals(json2, serialized2);
+ }
+
+ @Test
+ public void testTopLevelMap() throws Exception {
+ Map<String, Integer> expected1 = ImmutableMap.of("one", 1, "two", 2);
+ Text json1 = new Text("{\"one\":1,\"two\":2}");
+ TypeInfo type1 = TypeInfoUtils.getTypeInfoFromTypeString("map<string,int>");
+
+ Map<String, Integer> deserialized1 = JsonSerdeUtils.deserialize(json1, type1);
+ Assert.assertEquals(expected1, deserialized1);
+ Text serialized1 = JsonSerdeUtils.serialize(deserialized1,
+ HCatRecordObjectInspectorFactory.getStandardObjectInspectorFromTypeInfo(type1));
+ Assert.assertEquals(json1, serialized1);
}
@Test
@@ -331,7 +355,6 @@ public class JsonSerdeUtilsTest {
Assert.assertEquals(json2, serialized2);
}
-
private static void assertRecordEquals(@Nonnull final List<?> first,
@Nonnull final List<?> second) {
int mySz = first.size();
diff --git a/docs/gitbook/misc/funcs.md b/docs/gitbook/misc/funcs.md
index a904e71..e623572 100644
--- a/docs/gitbook/misc/funcs.md
+++ b/docs/gitbook/misc/funcs.md
@@ -494,7 +494,7 @@ This page describes a list of Hivemall functions. See also a [list of generic Hi
- `train_randomforest_classifier(array<double|string> features, int label [, const string options, const array<double> classWeights])`- Returns a relation consists of <string model_id, double model_weight, string model, array<double> var_importance, int oob_errors, int oob_tests>
-- `train_randomforest_regression(array<double|string> features, double target [, string options])` - Returns a relation consists of <int model_id, int model_type, string pred_model, array<double> var_importance, int oob_errors, int oob_tests>
+- `train_randomforest_regressor(array<double|string> features, double target [, string options])` - Returns a relation consists of <int model_id, int model_type, string model, array<double> var_importance, double oob_errors, int oob_tests>
- `guess_attribute_types(ANY, ...)` - Returns attribute types
```sql
diff --git a/docs/gitbook/misc/generic_funcs.md b/docs/gitbook/misc/generic_funcs.md
index b282731..4f53f4d 100644
--- a/docs/gitbook/misc/generic_funcs.md
+++ b/docs/gitbook/misc/generic_funcs.md
@@ -33,6 +33,19 @@ This page describes a list of useful Hivemall generic functions. See also a [lis
```
- `array_avg(array<number>)` - Returns an array<double> in which each element is the mean of a set of numbers
+ ```sql
+ WITH input as (
+ select array(1.0, 2.0, 3.0) as nums
+ UNION ALL
+ select array(2.0, 3.0, 4.0) as nums
+ )
+ select
+ array_avg(nums)
+ from
+ input;
+
+ ["1.5","2.5","3.5"]
+ ```
- `array_concat(array<ANY> x1, array<ANY> x2, ..)` - Returns a concatenated array
```sql
@@ -104,6 +117,19 @@ This page describes a list of useful Hivemall generic functions. See also a [lis
```
- `array_sum(array<number>)` - Returns an array<double> in which each element is summed up
+ ```sql
+ WITH input as (
+ select array(1.0, 2.0, 3.0) as nums
+ UNION ALL
+ select array(2.0, 3.0, 4.0) as nums
+ )
+ select
+ array_sum(nums)
+ from
+ input;
+
+ ["3.0","5.0","7.0"]
+ ```
- `array_to_str(array arr [, string sep=','])` - Convert array to string using a separator
```sql
@@ -216,6 +242,11 @@ This page describes a list of useful Hivemall generic functions. See also a [lis
```
- `to_string_array(array<ANY>)` - Returns an array of strings
+ ```sql
+ select to_string_array(array(1.0,2.0,3.0));
+
+ ["1.0","2.0","3.0"]
+ ```
- `to_ordered_list(PRIMITIVE value [, PRIMITIVE key, const string options])` - Return list of values sorted by value itself or specific key
```sql
@@ -303,6 +334,7 @@ This page describes a list of useful Hivemall generic functions. See also a [lis
- `from_json(string jsonString, const string returnTypes [, const array<string>|const string columnNames])` - Return Hive object.
```sql
SELECT
+ from_json(to_json(map('one',1,'two',2)), 'map<string,int>'),
from_json(
'{ "person" : { "name" : "makoto" , "age" : 37 } }',
'struct<name:string,age:int>',
@@ -334,6 +366,7 @@ This page describes a list of useful Hivemall generic functions. See also a [lis
```
```
+ {"one":1,"two":2}
{"name":"makoto","age":37}
[0.1,1.1,2.2]
[{"country":"japan","city":"tokyo"},{"country":"japan","city":"osaka"}]
@@ -426,38 +459,41 @@ This page describes a list of useful Hivemall generic functions. See also a [lis
{1:"one"}
```
-- `map_get_sum(map<int,float> src, array<int> keys)` - Returns sum of values that are retrieved by keys
-
-- `map_include_keys(Map<K,V> map, array<K> filteringKeys)` - Returns the filtered entries of a map having specified keys
+- `map_get(MAP<K> a, K n)` - Returns the value corresponding to the key in the map.
```sql
- SELECT map_include_keys(map(1,'one',2,'two',3,'three'),array(2,3));
- {2:"two",3:"three"}
- ```
+ Note this is a workaround for a Hive issue that non-constant expressions for map indexes are not supported.
+ See https://issues.apache.org/jira/browse/HIVE-1955
-- `map_get(Map<K> a, K n)` - Returns the value corresponding to the key in the map
- ```sql
WITH tmp as (
SELECT "one" as key
UNION ALL
SELECT "two" as key
)
- SELECT map_index(map("one",1,"two",2),key)
+ SELECT map_get(map("one",1,"two",2),key)
FROM tmp;
- 1
- 2
+ > 1
+ > 2
+ ```
+
+- `map_get_sum(map<int,float> src, array<int> keys)` - Returns sum of values that are retrieved by keys
+
+- `map_include_keys(Map<K,V> map, array<K> filteringKeys)` - Returns the filtered entries of a map having specified keys
+ ```sql
+ SELECT map_include_keys(map(1,'one',2,'two',3,'three'),array(2,3));
+ {2:"two",3:"three"}
```
-- `map_key_values(map)` - Returns a array of key-value pairs.
+- `map_key_values(MAP<K, V> map)` - Returns an array of key-value pairs in array<named_struct<key,value>>
```sql
SELECT map_key_values(map("one",1,"two",2));
- [{"key":"one","value":1},{"key":"two","value":2}]
+ > [{"key":"one","value":1},{"key":"two","value":2}]
```
- `map_tail_n(map SRC, int N)` - Returns the last N elements from a sorted array of SRC
-- `merge_maps(x)` - Returns a map which contains the union of an aggregation of maps. Note that an existing value of a key can be replaced with the other duplicate key entry.
+- `merge_maps(Map x)` - Returns a map which contains the union of an aggregation of maps. Note that an existing value of a key can be replaced with the other duplicate key entry.
```sql
SELECT
merge_maps(m)
@@ -469,6 +505,17 @@ This page describes a list of useful Hivemall generic functions. See also a [lis
```
- `to_map(key, value)` - Convert two aggregated columns into a key-value map
+ ```sql
+ WITH input as (
+ select 'aaa' as key, 111 as value
+ UNION all
+ select 'bbb' as key, 222 as value
+ )
+ select to_map(key, value)
+ from input;
+
+ > {"bbb":222,"aaa":111}
+ ```
- `to_ordered_map(key, value [, const int k|const boolean reverseOrder=false])` - Convert two aggregated columns into an ordered key-value map
```sql
@@ -514,21 +561,81 @@ This page describes a list of useful Hivemall generic functions. See also a [lis
- `infinity()` - Returns the constant representing positive infinity.
-- `is_finite(x)` - Determine if x is infinite.
+- `is_finite(x)` - Determine if x is finite.
+ ```sql
+ SELECT is_finite(333), is_finite(infinity());
+ > true false
+ ```
- `is_infinite(x)` - Determine if x is infinite.
- `is_nan(x)` - Determine if x is not-a-number.
-- `l2_norm(double xi)` - Return L2 norm of a vector which has the given values in each dimension
+- `l2_norm(double x)` - Return the L2 norm of the given input x.
+ ```sql
+ WITH input as (
+ select generate_series(1,3) as v
+ )
+ select l2_norm(v) as l2norm
+ from input;
+ > 3.7416573867739413 = sqrt(1^2+2^2+3^2)
+ ```
- `nan()` - Returns the constant representing not-a-number.
+ ```sql
+ SELECT nan(), is_nan(nan());
+ > NaN true
+ ```
- `sigmoid(x)` - Returns 1.0 / (1.0 + exp(-x))
+ ```sql
+ WITH input as (
+ SELECT 3.0 as x
+ UNION ALL
+ SELECT -3.0 as x
+ )
+ select
+ 1.0 / (1.0 + exp(-x)),
+ sigmoid(x)
+ from
+ input;
+ > 0.04742587317756678 0.04742587357759476
+ > 0.9525741268224334 0.9525741338729858
+ ```
+
+# Vector/Matrix
+
+- `transpose_and_dot(array<number> X, array<number> Y)` - Returns dot(X.T, Y) as array<array<double>>, shape = (X.#cols, Y.#cols)
+ ```sql
+ WITH input as (
+ select array(1.0, 2.0, 3.0, 4.0) as x, array(1, 2) as y
+ UNION ALL
+ select array(2.0, 3.0, 4.0, 5.0) as x, array(1, 2) as y
+ )
+ select
+ transpose_and_dot(x, y) as xy,
+ transpose_and_dot(y, x) as yx
+ from
+ input;
+
+ > [["3.0","6.0"],["5.0","10.0"],["7.0","14.0"],["9.0","18.0"]] [["3.0","5.0","7.0","9.0"],["6.0","10.0","14.0","18.0"]]
+
+ ```
+
+- `vector_add(array<NUMBER> x, array<NUMBER> y)` - Perform vector ADD operation.
+ ```sql
+ SELECT vector_add(array(1.0,2.0,3.0), array(2, 3, 4));
+ [3.0,5.0,7.0]
+ ```
-# Matrix
+- `vector_dot(array<NUMBER> x, array<NUMBER> y)` - Performs vector dot product.
+ ```sql
+ SELECT vector_dot(array(1.0,2.0,3.0),array(2.0,3.0,4.0));
+ 20
-- `transpose_and_dot(array<number> matrix0_row, array<number> matrix1_row)` - Returns dot(matrix0.T, matrix1) as array<array<double>>, shape = (matrix0.#cols, matrix1.#cols)
+ SELECT vector_dot(array(1.0,2.0,3.0),2);
+ [2.0,4.0,6.0]
+ ```
# Sanity Checks
@@ -600,23 +707,6 @@ This page describes a list of useful Hivemall generic functions. See also a [lis
6.0
```
-# Vector
-
-- `vector_add(array<NUMBER> x, array<NUMBER> y)` - Perform vector ADD operation.
- ```sql
- SELECT vector_add(array(1.0,2.0,3.0), array(2, 3, 4));
- [3.0,5.0,7.0]
- ```
-
-- `vector_dot(array<NUMBER> x, array<NUMBER> y)` - Performs vector dot product.
- ```sql
- SELECT vector_dot(array(1.0,2.0,3.0),array(2.0,3.0,4.0));
- 20
-
- SELECT vector_dot(array(1.0,2.0,3.0),2);
- [2.0,4.0,6.0]
- ```
-
# Others
- `convert_label(const int|const float)` - Convert from -1|1 to 0.0f|1.0f, or from 0.0f|1.0f to -1|1