Posted to issues@hivemall.apache.org by nzw0301 <gi...@git.apache.org> on 2017/09/21 02:48:50 UTC
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
GitHub user nzw0301 opened a pull request:
https://github.com/apache/incubator-hivemall/pull/116
[WIP][HIVEMALL-118] word2vec
## What changes were proposed in this pull request?
Add a new algorithm: skip-gram with negative sampling (a.k.a. word2vec)
## What type of PR is it?
Improvement
## What is the Jira issue?
https://issues.apache.org/jira/browse/HIVEMALL-118
## How was this patch tested?
Manual tests on EMR
## How to use this feature?
Please see `word2vec.md`
## Checklist
- [x] Did you apply source code formatter, i.e., `mvn formatter:format`, for your commit?
You can merge this pull request into a Git repository by running:
$ git pull https://github.com/nzw0301/incubator-hivemall skipgram
Alternatively you can review and apply these changes as the patch at:
https://github.com/apache/incubator-hivemall/pull/116.patch
To close this pull request, make a commit to your master/trunk branch
with (at least) the following in the commit message:
This closes #116
----
commit f19186fe8eff3de0400cc318c8c876fc69dbc766
Author: Kento NOZAWA <k_...@klis.tsukuba.ac.jp>
Date: 2017-09-13T12:50:35Z
Init docs for word2vec
commit e9a76093efd1803dd63a11697abc68e7ae78cbb8
Author: Kento NOZAWA <k_...@klis.tsukuba.ac.jp>
Date: 2017-09-13T12:54:33Z
Init Alias Table builder
commit b6883b00e8f543130e16caa2ad9010cfcd0e6cd9
Author: Kento NOZAWA <k_...@klis.tsukuba.ac.jp>
Date: 2017-09-13T13:02:10Z
Fix typo
commit 3ca761956638c95264a37edfb49133f35c37cdd5
Author: Kento NOZAWA <k_...@klis.tsukuba.ac.jp>
Date: 2017-09-14T02:15:51Z
Separate calias table function
commit 2588e732b1194d800005c6ac28453adf79d89278
Author: Kento NOZAWA <k_...@klis.tsukuba.ac.jp>
Date: 2017-09-14T12:26:49Z
Create skip-gram UDTF
commit 33900380a337821f2da280d0c7c9c10f2fe14565
Author: Kento NOZAWA <k_...@klis.tsukuba.ac.jp>
Date: 2017-09-15T03:35:36Z
Use float to save memory
commit a7394a04bc699e105febf86a6b9712f8daa8ff92
Author: Kento NOZAWA <k_...@klis.tsukuba.ac.jp>
Date: 2017-09-15T07:04:44Z
Update query example in docs
commit 50d5ffcffff503f1522577b7789e7745650a1ece
Author: Kento NOZAWA <k_...@klis.tsukuba.ac.jp>
Date: 2017-09-15T07:05:21Z
Update forwarding
commit b701919825b3ed8909659b77236d6522f6d51be6
Author: Kento NOZAWA <k_...@klis.tsukuba.ac.jp>
Date: 2017-09-18T12:03:55Z
Init Word2vecFeatureUDTF
commit 7f69ef66ef2849e57d06601fb9098b6bf2c2bb21
Author: Kento NOZAWA <k_...@klis.tsukuba.ac.jp>
Date: 2017-09-19T08:53:26Z
Implement skip-gramfeature UDTF
commit 48c1929b6345d144059eec44636c0194eebc5281
Author: Kento NOZAWA <k_...@klis.tsukuba.ac.jp>
Date: 2017-09-19T14:01:46Z
Update skipgram
commit 7f4abde3760cf17455d99eb6c5a7cbe8c3dc5e3c
Author: Kento NOZAWA <k_...@klis.tsukuba.ac.jp>
Date: 2017-09-19T16:05:15Z
Refactor Skipgram
commit 2e429d60ff2ae473a64cb890e67c9d8b9a7b2830
Author: Kento NOZAWA <k_...@klis.tsukuba.ac.jp>
Date: 2017-09-20T05:26:25Z
Update for query change
commit 7014e8552f81d59ab35c95d8fcf54c56c24ba2c9
Author: Kento NOZAWA <k_...@klis.tsukuba.ac.jp>
Date: 2017-09-20T08:54:51Z
Remove discard table
----
---
[GitHub] incubator-hivemall issue #116: [WIP][HIVEMALL-118] word2vec
Posted by coveralls <gi...@git.apache.org>.
Github user coveralls commented on the issue:
https://github.com/apache/incubator-hivemall/pull/116
[![Coverage Status](https://coveralls.io/builds/13366956/badge)](https://coveralls.io/builds/13366956)
Coverage decreased (-0.4%) to 40.499% when pulling **c224912606f0dfd4d7d53acb0eac3fecc016b335 on nzw0301:skipgram** into **c2b95783cf9d6fc1646a48ac928e96152eab98c6 on apache:master**.
---
[GitHub] incubator-hivemall issue #116: [WIP][HIVEMALL-118] word2vec
Posted by coveralls <gi...@git.apache.org>.
Github user coveralls commented on the issue:
https://github.com/apache/incubator-hivemall/pull/116
[![Coverage Status](https://coveralls.io/builds/13385853/badge)](https://coveralls.io/builds/13385853)
Coverage decreased (-0.5%) to 40.468% when pulling **a3ccaa8a38ecff49d23e94a8cc7db7f895181e06 on nzw0301:skipgram** into **c2b95783cf9d6fc1646a48ac928e96152eab98c6 on apache:master**.
---
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Posted by myui <gi...@git.apache.org>.
Github user myui commented on a diff in the pull request:
https://github.com/apache/incubator-hivemall/pull/116#discussion_r141544983
--- Diff: docs/gitbook/embedding/word2vec.md ---
@@ -0,0 +1,399 @@
+<!--
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+-->
+
+Word embedding is a powerful tool for many tasks,
+e.g., finding similar words,
+building feature vectors for supervised machine learning tasks, and solving word analogies
+such as `king - man + woman =~ queen`.
+In word embedding,
+each word is represented by a low-dimensional, dense vector.
+**Skip-Gram** and **Continuous Bag-of-Words** (CBoW) are the most popular algorithms for obtaining good word embeddings (a.k.a. word2vec).
+
+The papers that introduced these methods are as follows:
+
+- T. Mikolov, et al., [Distributed Representations of Words and Phrases and Their Compositionality
+](http://papers.nips.cc/paper/5021-distributed-representations-of-words-and-phrases-and-their-compositionality.pdf). NIPS, 2013.
+- T. Mikolov, et al., [Efficient Estimation of Word Representations in Vector Space](https://arxiv.org/abs/1301.3781). ICLR, 2013.
+
+Hivemall provides two types of algorithms: Skip-gram and CBoW, both with negative sampling.
+Hivemall enables you to train word2vec on sequence data such as,
+but not limited to, documents.
+This article gives usage instructions for the feature.
+
+<!-- toc -->
+
+> #### Note
+> This feature is supported from Hivemall v0.5-rc.? or later.
+
+# Prepare document data
+
+Assume that you already have a `docs` table that contains many documents in string format, each with a unique index:
+
+```sql
+select * FROM docs;
+```
+
+| docId | doc |
+|:----: |:----|
+| 0 | "Alice was beginning to get very tired of sitting by her sister on the bank ..." |
+| ... | ... |
+
+First, each document is split into words by a tokenizer function such as [`tokenize`](../misc/tokenizer.html).
+
+```sql
+drop table docs_words;
+create table docs_words as
+ select
+ docid,
+ tokenize(doc, true) as words
+ FROM
+ docs
+;
+```
+
+This table shows the tokenized documents.
+
+| docId | words |
+|:----: |:----|
+| 0 | ["alice", "was", "beginning", "to", "get", "very", "tired", "of", "sitting", "by", "her", "sister", "on", "the", "bank", ...] |
+| ... | ... |
+
+Then, count the frequency of each word and remove low-frequency words from the vocabulary.
+Removing low-frequency words is an optional preprocessing step, but it is effective for training word vectors faster.
+
+```sql
+set hivevar:mincount=5;
+
+drop table freq;
+create table freq as
+select
+ row_number() over () - 1 as wordid,
+ word,
+ freq
+from (
+ select
+ word,
+ COUNT(*) as freq
+ from
+ docs_words
+ LATERAL VIEW explode(words) lTable as word
+ group by
+ word
+) t
+where freq >= ${mincount}
+;
+```
+
+Hivemall's word2vec supports two word types: string and int.
+The string type tends to use a large amount of memory during training.
+On the other hand, the int type tends to use less memory.
+If you train on a small dataset, we recommend using the string type,
+because memory usage is negligible and the HiveQL is simpler.
+If you train on a large dataset, we recommend using the int type,
+because it saves memory during training.
+
+# Create sub-sampling table
+
+The sub-sampling table stores a sub-sampling probability per word.
+
+The sub-sampling probability of word $$w_i$$ is computed by the following equation:
+
+$$
+\begin{aligned}
+f(w_i) = \sqrt{\frac{\mathrm{sample}}{freq(w_i)/\sum freq(w)}} + \frac{\mathrm{sample}}{freq(w_i)/\sum freq(w)}
+\end{aligned}
+$$
+
+During word2vec training,
+words that are not sampled are ignored.
+This speeds up training and mitigates the imbalance between rare and frequent words by reducing the number of frequent words.
+The smaller the `sample` value,
+the fewer words are used during training.
+
+```sql
+set hivevar:sample=1e-4;
+
+drop table subsampling_table;
+create table subsampling_table as
+with stats as (
+ select
+ sum(freq) as numTrainWords
+ FROM
+ freq
+)
+select
+ l.wordid,
+ l.word,
+ sqrt(${sample}/(l.freq/r.numTrainWords)) + ${sample}/(l.freq/r.numTrainWords) as p
+from
+ freq l
+cross join
+ stats r
+;
+```
+
+```sql
+select * FROM subsampling_table order by p;
+```
+
+| wordid | word | p |
+|:----: | :----: |:----:|
+| 48645 | the | 0.04013665|
+| 11245 | of | 0.052463654|
+| 16368 | and | 0.06555538|
+| 61938 | 00 | 0.068162076|
+| 19977 | in | 0.071441144|
+| 83599 | 0 | 0.07528994|
+| 95017 | a | 0.07559573|
+| 1225 | to | 0.07953133|
+| 37062 | 0000 | 0.08779001|
+| 58246 | is | 0.09049763|
+| ... | ... |... |
+
+The first row shows that about 4% of the occurrences of `the` in the documents are used during training.
+
+# Delete low frequency words and high frequency words from `docs_words`
+
+To remove uninformative words from the corpus,
+low-frequency words and high-frequency words are deleted.
+In addition, to avoid loading a long document into memory, each document is split into sub-documents.
+
+```sql
+set hivevar:maxlength=1500;
+SET hivevar:seed=31;
+
+drop table train_docs;
+create table train_docs as
+ with docs_exploded as (
+ select
+ docid,
+ word,
+ pos % ${maxlength} as pos,
+ pos div ${maxlength} as splitid,
+ rand(${seed}) as rnd
+ from
+ docs_words LATERAL VIEW posexplode(words) t as pos, word
+ )
+select
+ l.docid,
+ -- to_ordered_list(l.word, l.pos) as words
+ to_ordered_list(r2.wordid, l.pos) as words
+from
+ docs_exploded l
+ LEFT SEMI join freq r on (l.word = r.word)
+ join subsampling_table r2 on (l.word = r2.word)
+where
+ r2.p > l.rnd
+group by
+ l.docid, l.splitid
+;
+```
+
+If you store string words in the `train_docs` table,
+please replace `to_ordered_list(r2.wordid, l.pos) as words` with `to_ordered_list(l.word, l.pos) as words`.
+
+# Create negative sampling table
+
+Negative sampling is an approximation of the [softmax function](https://en.wikipedia.org/wiki/Softmax_function).
+Here, `negative_table` is used to store the word sampling probabilities for negative sampling.
+`z` is a hyperparameter of the noise distribution for negative sampling.
+During word2vec training,
--- End diff --
Line break is not needed. Line break after `,` is unreasonable (elsewhere as well).
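
For readers following the quoted document, the sub-sampling formula and the `r2.p > l.rnd` filter in the diff above can be sketched in plain Java. This is only an illustration under stated assumptions: the class and method names (`SubsamplingProbability`, `keepProbability`, `keep`) are hypothetical and do not appear in the pull request.

```java
// Hypothetical helper, not part of the PR: mirrors the formula quoted in word2vec.md.
final class SubsamplingProbability {
    private SubsamplingProbability() {}

    /**
     * Computes sqrt(sample / ratio) + sample / ratio, where ratio = freq / totalFreq.
     * The value can exceed 1.0, in which case the word is always kept.
     */
    public static double keepProbability(long freq, long totalFreq, double sample) {
        if (freq <= 0 || totalFreq <= 0) {
            throw new IllegalArgumentException("freq and totalFreq must be positive");
        }
        final double ratio = (double) freq / totalFreq;
        return Math.sqrt(sample / ratio) + sample / ratio;
    }

    /** Same comparison as `r2.p > l.rnd` in the query: keep when p exceeds a uniform draw. */
    public static boolean keep(double p, double uniformDraw) {
        return p > uniformDraw;
    }
}
```

With `sample = 1e-4`, a word that makes up 1% of the corpus gets `sqrt(0.01) + 0.01 = 0.11`, so roughly one occurrence in nine survives the filter.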
---
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Posted by myui <gi...@git.apache.org>.
Github user myui commented on a diff in the pull request:
https://github.com/apache/incubator-hivemall/pull/116#discussion_r141542877
--- Diff: core/src/main/java/hivemall/embedding/AbstractWord2VecModel.java ---
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.embedding;
+
+import hivemall.math.random.PRNG;
+import hivemall.math.random.RandomNumberGeneratorFactory;
+import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+import java.util.List;
+
+public abstract class AbstractWord2VecModel {
+ // cached sigmoid function parameters
+ protected static final int MAX_SIGMOID = 6;
+ protected static final int SIGMOID_TABLE_SIZE = 1000;
+ protected float[] sigmoidTable;
+
+
--- End diff --
remove unnecessary blank line
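
For context on the cached sigmoid fields quoted above: the body of `initSigmoidTable()` is not shown in this thread, so the following is only a sketch of how such a table is commonly built, modeled on the precomputed table in the reference C implementation of word2vec, which uses the same constants (6 and 1000). The class name `SigmoidTable` and the `lookup` method are assumptions for illustration.

```java
// Hypothetical sketch, not the PR's code: a precomputed sigmoid lookup table.
final class SigmoidTable {
    static final int MAX_SIGMOID = 6;
    static final int SIGMOID_TABLE_SIZE = 1000;

    private SigmoidTable() {}

    /** Entry i holds sigmoid(x) for x evenly spaced over [-MAX_SIGMOID, MAX_SIGMOID). */
    public static float[] build() {
        final float[] table = new float[SIGMOID_TABLE_SIZE];
        for (int i = 0; i < SIGMOID_TABLE_SIZE; i++) {
            final double x = ((double) i / SIGMOID_TABLE_SIZE * 2.0 - 1.0) * MAX_SIGMOID;
            table[i] = (float) (1.0 / (1.0 + Math.exp(-x)));
        }
        return table;
    }

    /** Saturates outside the table range; otherwise maps x to the nearest lower entry. */
    public static float lookup(float[] table, float x) {
        if (x >= MAX_SIGMOID) {
            return 1f;
        }
        if (x <= -MAX_SIGMOID) {
            return 0f;
        }
        final int idx = (int) (((double) x + MAX_SIGMOID) / (2.0 * MAX_SIGMOID) * SIGMOID_TABLE_SIZE);
        // Clamp defensively so rounding can never index past the end.
        return table[Math.min(idx, SIGMOID_TABLE_SIZE - 1)];
    }
}
```

The table trades a small approximation error for avoiding a call to `Math.exp` on every gradient update.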
---
[GitHub] incubator-hivemall issue #116: [WIP][HIVEMALL-118] word2vec
Posted by coveralls <gi...@git.apache.org>.
Github user coveralls commented on the issue:
https://github.com/apache/incubator-hivemall/pull/116
[![Coverage Status](https://coveralls.io/builds/13440173/badge)](https://coveralls.io/builds/13440173)
Coverage decreased (-0.5%) to 40.392% when pulling **c34003804b5ef33bca6abe9113de6eb01b5e94c6 on nzw0301:skipgram** into **c2b95783cf9d6fc1646a48ac928e96152eab98c6 on apache:master**.
---
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Posted by myui <gi...@git.apache.org>.
Github user myui commented on a diff in the pull request:
https://github.com/apache/incubator-hivemall/pull/116#discussion_r141546893
--- Diff: core/src/main/java/hivemall/embedding/CBoWModel.java ---
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.embedding;
+
+import hivemall.math.random.PRNG;
+import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
+
+import javax.annotation.Nonnull;
+import java.util.List;
+
+public final class CBoWModel extends AbstractWord2VecModel {
+ protected CBoWModel(final int dim, final int win, final int neg, final int iter,
+ final float startingLR, final long numTrainWords, final Int2FloatOpenHashTable S,
+ final int[] aliasWordId) {
+ super(dim, win, neg, iter, startingLR, numTrainWords, S, aliasWordId);
+ }
+
+ protected void trainOnDoc(@Nonnull final int[] doc) {
+ final int vecDim = dim;
+ final int numNegative = neg;
+ final PRNG _rnd = rnd;
+ final Int2FloatOpenHashTable _S = S;
+ final int[] _aliasWordId = aliasWordId;
+ float label, gradient;
+
+ // reuse instance
+ int windowSize, k, numContext, targetWord, inWord, positiveWord;
+
+ updateLearningRate();
+
+ int docLength = doc.length;
--- End diff --
`final int docLength`
---
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Posted by myui <gi...@git.apache.org>.
Github user myui commented on a diff in the pull request:
https://github.com/apache/incubator-hivemall/pull/116#discussion_r141546522
--- Diff: core/src/main/java/hivemall/embedding/AliasTableBuilderUDTF.java ---
@@ -0,0 +1,203 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.embedding;
+
+import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
+import hivemall.utils.collections.maps.Int2IntOpenHashTable;
+import hivemall.utils.hadoop.HiveUtils;
+
+import java.util.List;
+import java.util.ArrayList;
+import java.util.Map;
+import java.util.Queue;
+import java.util.ArrayDeque;
+
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDTF;
+import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
+import org.apache.hadoop.io.FloatWritable;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
+
+import javax.annotation.Nonnull;
+
+public final class AliasTableBuilderUDTF extends GenericUDTF {
--- End diff --
Add a Javadoc comment for the class that refers to the papers.
```
<pre>
- A. J. Walker, New Fast Method for Generating Discrete Random Numbers with Arbitrary Frequency Distributions, in Electronics Letters 10, no. 8, pp. 127-128, 1974.
- A. J. Walker, An Efficient Method for Generating Discrete Random Variables with General Distributions. ACM Transactions on Mathematical Software 3, no. 3, pp. 253-256, 1977.
</pre>
```
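
As background for the cited papers, the alias method can be sketched as follows. This is a minimal illustration using the common two-worklist (Vose-style) construction, which may differ from what `AliasTableBuilderUDTF` actually does; the class name `AliasSampler` and its methods are hypothetical.

```java
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.Random;

// Hypothetical sketch of the alias method; not the PR's implementation.
final class AliasSampler {
    private final double[] prob; // probability of returning column i itself
    private final int[] alias;   // outcome returned otherwise

    public AliasSampler(double[] weights) {
        final int n = weights.length;
        double sum = 0.0;
        for (double w : weights) {
            if (w < 0.0) {
                throw new IllegalArgumentException("weights must be non-negative");
            }
            sum += w;
        }
        if (n == 0 || sum <= 0.0) {
            throw new IllegalArgumentException("weights must have a positive sum");
        }
        prob = new double[n];
        alias = new int[n];
        final double[] scaled = new double[n];
        final Deque<Integer> small = new ArrayDeque<>();
        final Deque<Integer> large = new ArrayDeque<>();
        for (int i = 0; i < n; i++) {
            scaled[i] = weights[i] * n / sum; // average column height becomes 1.0
            if (scaled[i] < 1.0) small.push(i); else large.push(i);
        }
        // Pair each under-full column with an over-full one that tops it up.
        while (!small.isEmpty() && !large.isEmpty()) {
            final int s = small.pop();
            final int l = large.pop();
            prob[s] = scaled[s];
            alias[s] = l;
            scaled[l] = (scaled[l] + scaled[s]) - 1.0;
            if (scaled[l] < 1.0) small.push(l); else large.push(l);
        }
        // Leftovers are full columns (up to rounding error).
        while (!large.isEmpty()) { final int l = large.pop(); prob[l] = 1.0; alias[l] = l; }
        while (!small.isEmpty()) { final int s = small.pop(); prob[s] = 1.0; alias[s] = s; }
    }

    /** Draws one index in O(1): pick a column uniformly, then the column or its alias. */
    public int sample(Random rnd) {
        final int column = rnd.nextInt(prob.length);
        return rnd.nextDouble() < prob[column] ? column : alias[column];
    }
}
```

Construction is linear in the vocabulary size and each draw is constant time, which is why the method suits negative sampling, where many draws are made per training word.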
---
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Posted by nzw0301 <gi...@git.apache.org>.
Github user nzw0301 commented on a diff in the pull request:
https://github.com/apache/incubator-hivemall/pull/116#discussion_r141553131
--- Diff: core/src/main/java/hivemall/embedding/Word2VecUDTF.java ---
@@ -0,0 +1,364 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.embedding;
+
+import hivemall.UDTFWithOptions;
+import hivemall.utils.collections.IMapIterator;
+import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
+import hivemall.utils.collections.maps.OpenHashTable;
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.lang.Primitives;
+
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.Options;
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
+import org.apache.hadoop.io.FloatWritable;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+import java.util.List;
+import java.util.Arrays;
+import java.util.ArrayList;
+
+@Description(
+ name = "train_word2vec",
+ value = "_FUNC_(array<array<float | string>> negative_table, array<int | string> doc [, const string options]) - Returns a prediction model")
+public class Word2VecUDTF extends UDTFWithOptions {
+ protected transient AbstractWord2VecModel model;
+ @Nonnegative
+ private float startingLR;
+ @Nonnegative
+ private long numTrainWords;
+ private OpenHashTable<String, Integer> word2index;
+
+ @Nonnegative
+ private int dim;
+ @Nonnegative
+ private int win;
+ @Nonnegative
+ private int neg;
+ @Nonnegative
+ private int iter;
+ private boolean skipgram;
+ private boolean isStringInput;
+
+ private Int2FloatOpenHashTable S;
+ private int[] aliasWordIds;
+
+ private ListObjectInspector negativeTableOI;
+ private ListObjectInspector negativeTableElementListOI;
+ private PrimitiveObjectInspector negativeTableElementOI;
+
+ private ListObjectInspector docOI;
+ private PrimitiveObjectInspector wordOI;
+
+ @Override
+ public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
+ final int numArgs = argOIs.length;
+
+ if (numArgs != 3) {
+ throw new UDFArgumentException(getClass().getSimpleName()
+ + " takes 3 arguments: [, constant string options]: "
+ + Arrays.toString(argOIs));
+ }
+
+ processOptions(argOIs);
+
+ this.negativeTableOI = HiveUtils.asListOI(argOIs[0]);
+ this.negativeTableElementListOI = HiveUtils.asListOI(negativeTableOI.getListElementObjectInspector());
+ this.docOI = HiveUtils.asListOI(argOIs[1]);
+
+ this.isStringInput = HiveUtils.isStringListOI(argOIs[1]);
+
+ if (isStringInput) {
+ this.negativeTableElementOI = HiveUtils.asStringOI(negativeTableElementListOI.getListElementObjectInspector());
+ this.wordOI = HiveUtils.asStringOI(docOI.getListElementObjectInspector());
+ } else {
+ this.negativeTableElementOI = HiveUtils.asFloatingPointOI(negativeTableElementListOI.getListElementObjectInspector());
+ this.wordOI = HiveUtils.asIntCompatibleOI(docOI.getListElementObjectInspector());
+ }
+
+ List<String> fieldNames = new ArrayList<>();
+ List<ObjectInspector> fieldOIs = new ArrayList<>();
+
+ fieldNames.add("word");
+ fieldNames.add("i");
+ fieldNames.add("wi");
+
+ if (isStringInput) {
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
+ } else {
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+ }
+
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
+
+ this.model = null;
+ this.word2index = null;
+ this.S = null;
+ this.aliasWordIds = null;
+
+ return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
+ }
+
+ @Override
+ public void process(Object[] args) throws HiveException {
+ if (model == null) {
+ parseNegativeTable(args[0]);
+ this.model = createModel();
+ }
+
+ List<?> rawDoc = docOI.getList(args[1]);
+
+ // parse rawDoc
+ final int docLength = rawDoc.size();
+ final int[] doc = new int[docLength];
+ if (isStringInput) {
+ for (int i = 0; i < docLength; i++) {
+ doc[i] = getWordId(PrimitiveObjectInspectorUtils.getString(rawDoc.get(i), wordOI));
+ }
+ } else {
+ for (int i = 0; i < docLength; i++) {
+ doc[i] = PrimitiveObjectInspectorUtils.getInt(rawDoc.get(i), wordOI);
+ }
+ }
+
+ model.trainOnDoc(doc);
+ }
+
+ @Override
+ protected Options getOptions() {
+ Options opts = new Options();
+ opts.addOption("n", "numTrainWords", true,
+ "The number of words in the documents. It is used to update learning rate");
+ opts.addOption("dim", "dimension", true, "The number of vector dimension [default: 100]");
+ opts.addOption("win", "window", true, "Context window size [default: 5]");
+ opts.addOption("neg", "negative", true,
+ "The number of negative sampled words per word [default: 5]");
+ opts.addOption("iter", "iteration", true, "The number of iterations [default: 5]");
+ opts.addOption("model", "modelName", true,
+ "The model name of word2vec: skipgram or cbow [default: skipgram]");
+ opts.addOption(
+ "lr",
--- End diff --
I see.
Should `longOpt` remain `learningRate`, or should this field be removed?
---
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Posted by nzw0301 <gi...@git.apache.org>.
Github user nzw0301 commented on a diff in the pull request:
https://github.com/apache/incubator-hivemall/pull/116#discussion_r141551510
--- Diff: core/src/main/java/hivemall/embedding/AbstractWord2VecModel.java ---
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.embedding;
+
+import hivemall.math.random.PRNG;
+import hivemall.math.random.RandomNumberGeneratorFactory;
+import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+import java.util.List;
+
+public abstract class AbstractWord2VecModel {
+ // cached sigmoid function parameters
+ protected static final int MAX_SIGMOID = 6;
+ protected static final int SIGMOID_TABLE_SIZE = 1000;
+ protected float[] sigmoidTable;
+
+
+ @Nonnegative
+ protected int dim;
+ protected int win;
+ protected int neg;
+ protected int iter;
+
+ // learning rate parameters
+ @Nonnegative
+ protected float lr;
+ @Nonnegative
+ private float startingLR;
+ @Nonnegative
+ private long numTrainWords;
+ @Nonnegative
+ protected long wordCount;
+ @Nonnegative
+ private long lastWordCount;
+
+ protected PRNG rnd;
+
+ protected Int2FloatOpenHashTable contextWeights;
+ protected Int2FloatOpenHashTable inputWeights;
+ protected Int2FloatOpenHashTable S;
+ protected int[] aliasWordId;
+
+ protected AbstractWord2VecModel(final int dim, final int win, final int neg, final int iter,
+ final float startingLR, final long numTrainWords, final Int2FloatOpenHashTable S,
+ final int[] aliasWordId) {
+ this.win = win;
+ this.neg = neg;
+ this.iter = iter;
+ this.dim = dim;
+ this.startingLR = this.lr = startingLR;
+ this.numTrainWords = numTrainWords;
+
+ // alias sampler for negative sampling
+ this.S = S;
+ this.aliasWordId = aliasWordId;
+
+ this.wordCount = 0L;
+ this.lastWordCount = 0L;
+ this.rnd = RandomNumberGeneratorFactory.createPRNG(1001);
+
+ this.sigmoidTable = initSigmoidTable();
+
+ // TODO how to estimate size
+ this.inputWeights = new Int2FloatOpenHashTable(10578 * dim);
--- End diff --
There is no reason.
---
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Posted by myui <gi...@git.apache.org>.
Github user myui commented on a diff in the pull request:
https://github.com/apache/incubator-hivemall/pull/116#discussion_r141543945
--- Diff: core/src/main/java/hivemall/embedding/Word2VecUDTF.java ---
@@ -0,0 +1,364 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.embedding;
+
+import hivemall.UDTFWithOptions;
+import hivemall.utils.collections.IMapIterator;
+import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
+import hivemall.utils.collections.maps.OpenHashTable;
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.lang.Primitives;
+
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.Options;
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
+import org.apache.hadoop.io.FloatWritable;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+import java.util.List;
+import java.util.Arrays;
+import java.util.ArrayList;
+
+@Description(
+ name = "train_word2vec",
+ value = "_FUNC_(array<array<float | string>> negative_table, array<int | string> doc [, const string options]) - Returns a prediction model")
+public class Word2VecUDTF extends UDTFWithOptions {
+ protected transient AbstractWord2VecModel model;
+ @Nonnegative
+ private float startingLR;
+ @Nonnegative
+ private long numTrainWords;
+ private OpenHashTable<String, Integer> word2index;
+
+ @Nonnegative
+ private int dim;
+ @Nonnegative
+ private int win;
+ @Nonnegative
+ private int neg;
+ @Nonnegative
+ private int iter;
+ private boolean skipgram;
+ private boolean isStringInput;
+
+ private Int2FloatOpenHashTable S;
+ private int[] aliasWordIds;
+
+ private ListObjectInspector negativeTableOI;
+ private ListObjectInspector negativeTableElementListOI;
+ private PrimitiveObjectInspector negativeTableElementOI;
+
+ private ListObjectInspector docOI;
+ private PrimitiveObjectInspector wordOI;
+
+ @Override
+ public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
+ final int numArgs = argOIs.length;
+
+ if (numArgs != 3) {
+ throw new UDFArgumentException(getClass().getSimpleName()
+ + " takes 3 arguments: [, constant string options]: "
+ + Arrays.toString(argOIs));
+ }
+
+ processOptions(argOIs);
+
+ this.negativeTableOI = HiveUtils.asListOI(argOIs[0]);
+ this.negativeTableElementListOI = HiveUtils.asListOI(negativeTableOI.getListElementObjectInspector());
+ this.docOI = HiveUtils.asListOI(argOIs[1]);
+
+ this.isStringInput = HiveUtils.isStringListOI(argOIs[1]);
+
+ if (isStringInput) {
+ this.negativeTableElementOI = HiveUtils.asStringOI(negativeTableElementListOI.getListElementObjectInspector());
+ this.wordOI = HiveUtils.asStringOI(docOI.getListElementObjectInspector());
+ } else {
+ this.negativeTableElementOI = HiveUtils.asFloatingPointOI(negativeTableElementListOI.getListElementObjectInspector());
+ this.wordOI = HiveUtils.asIntCompatibleOI(docOI.getListElementObjectInspector());
+ }
+
+ List<String> fieldNames = new ArrayList<>();
+ List<ObjectInspector> fieldOIs = new ArrayList<>();
+
+ fieldNames.add("word");
+ fieldNames.add("i");
+ fieldNames.add("wi");
+
+ if (isStringInput) {
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
+ } else {
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+ }
+
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
+
+ this.model = null;
+ this.word2index = null;
+ this.S = null;
+ this.aliasWordIds = null;
+
+ return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
+ }
+
+ @Override
+ public void process(Object[] args) throws HiveException {
+ if (model == null) {
+ parseNegativeTable(args[0]);
+ this.model = createModel();
+ }
+
+ List<?> rawDoc = docOI.getList(args[1]);
+
+ // parse rawDoc
+ final int docLength = rawDoc.size();
+ final int[] doc = new int[docLength];
+ if (isStringInput) {
+ for (int i = 0; i < docLength; i++) {
+ doc[i] = getWordId(PrimitiveObjectInspectorUtils.getString(rawDoc.get(i), wordOI));
--- End diff --
`PrimitiveObjectInspectorUtils.getString(rawDoc.get(i), wordOI)` may return null.
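A minimal sketch of one way to guard against that, assuming null entries should simply be skipped (the PR may prefer to throw instead). `toWord` is a hypothetical stand-in for the `getString` call so that the example runs without Hive on the classpath:

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class NullSafeDocParser {
    // Hypothetical stand-in for PrimitiveObjectInspectorUtils.getString(raw, wordOI),
    // which may yield null for a null element.
    static String toWord(Object raw) {
        return raw == null ? null : raw.toString();
    }

    // Collects the non-null words of a raw document, preserving their order.
    static String[] parseWords(List<?> rawDoc) {
        List<String> words = new ArrayList<>(rawDoc.size());
        for (Object raw : rawDoc) {
            String word = toWord(raw);
            if (word == null) {
                continue; // assumption: skip null entries rather than fail
            }
            words.add(word);
        }
        return words.toArray(new String[0]);
    }

    public static void main(String[] args) {
        String[] parsed = parseWords(Arrays.asList("alice", null, "was"));
        System.out.println(Arrays.toString(parsed)); // prints [alice, was]
    }
}
```

In the UDTF itself the same check would sit inside the `isStringInput` loop, before `getWordId` is called.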
---
[GitHub] incubator-hivemall issue #116: [WIP][HIVEMALL-118] word2vec
Posted by coveralls <gi...@git.apache.org>.
Github user coveralls commented on the issue:
https://github.com/apache/incubator-hivemall/pull/116
[![Coverage Status](https://coveralls.io/builds/13370459/badge)](https://coveralls.io/builds/13370459)
Coverage decreased (-0.8%) to 40.165% when pulling **83198617bcf82634d39d715e33499f99945f2ebb on nzw0301:skipgram** into **c2b95783cf9d6fc1646a48ac928e96152eab98c6 on apache:master**.
---
[GitHub] incubator-hivemall issue #116: [WIP][HIVEMALL-118] word2vec
Posted by coveralls <gi...@git.apache.org>.
Github user coveralls commented on the issue:
https://github.com/apache/incubator-hivemall/pull/116
[![Coverage Status](https://coveralls.io/builds/13386706/badge)](https://coveralls.io/builds/13386706)
Coverage decreased (-0.8%) to 40.14% when pulling **c7cba82a2eef2b5b32e67221a20c6cdb4570643a on nzw0301:skipgram** into **c2b95783cf9d6fc1646a48ac928e96152eab98c6 on apache:master**.
---
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Posted by myui <gi...@git.apache.org>.
Github user myui commented on a diff in the pull request:
https://github.com/apache/incubator-hivemall/pull/116#discussion_r141543095
--- Diff: core/src/main/java/hivemall/embedding/AbstractWord2VecModel.java ---
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.embedding;
+
+import hivemall.math.random.PRNG;
+import hivemall.math.random.RandomNumberGeneratorFactory;
+import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+import java.util.List;
+
+public abstract class AbstractWord2VecModel {
+ // cached sigmoid function parameters
+ protected static final int MAX_SIGMOID = 6;
+ protected static final int SIGMOID_TABLE_SIZE = 1000;
+ protected float[] sigmoidTable;
+
+
+ @Nonnegative
+ protected int dim;
+ protected int win;
+ protected int neg;
+ protected int iter;
+
+ // learning rate parameters
+ @Nonnegative
+ protected float lr;
+ @Nonnegative
+ private float startingLR;
+ @Nonnegative
+ private long numTrainWords;
+ @Nonnegative
+ protected long wordCount;
+ @Nonnegative
+ private long lastWordCount;
+
+ protected PRNG rnd;
+
+ protected Int2FloatOpenHashTable contextWeights;
+ protected Int2FloatOpenHashTable inputWeights;
+ protected Int2FloatOpenHashTable S;
+ protected int[] aliasWordId;
+
+ protected AbstractWord2VecModel(final int dim, final int win, final int neg, final int iter,
--- End diff --
Add `@Nonnegative` to each constructor argument and to the caller methods.
---
[GitHub] incubator-hivemall issue #116: [WIP][HIVEMALL-118] word2vec
Posted by coveralls <gi...@git.apache.org>.
Github user coveralls commented on the issue:
https://github.com/apache/incubator-hivemall/pull/116
[![Coverage Status](https://coveralls.io/builds/13412359/badge)](https://coveralls.io/builds/13412359)
Coverage decreased (-0.6%) to 40.343% when pulling **aede5ec0cf4d01780034c0d46c456486fecc1cb3 on nzw0301:skipgram** into **c2b95783cf9d6fc1646a48ac928e96152eab98c6 on apache:master**.
---
[GitHub] incubator-hivemall issue #116: [WIP][HIVEMALL-118] word2vec
Posted by coveralls <gi...@git.apache.org>.
Github user coveralls commented on the issue:
https://github.com/apache/incubator-hivemall/pull/116
[![Coverage Status](https://coveralls.io/builds/13415337/badge)](https://coveralls.io/builds/13415337)
Coverage decreased (-0.6%) to 40.31% when pulling **4abdb8f3d2632baa9b3ead928bc8bb0283027e20 on nzw0301:skipgram** into **c2b95783cf9d6fc1646a48ac928e96152eab98c6 on apache:master**.
---
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Posted by myui <gi...@git.apache.org>.
Github user myui commented on a diff in the pull request:
https://github.com/apache/incubator-hivemall/pull/116#discussion_r141545135
--- Diff: docs/gitbook/embedding/word2vec.md ---
@@ -0,0 +1,399 @@
+<!--
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+-->
+
+Word embedding is a powerful tool for many tasks,
+e.g. finding similar words,
+building feature vectors for supervised machine learning tasks, and solving word analogies
+such as `king - man + woman =~ queen`.
+In word embedding,
+each word is represented by a low-dimensional, dense vector.
+**Skip-Gram** and **Continuous Bag-of-words** (CBoW) are the most popular algorithms for obtaining good word embeddings (a.k.a. word2vec).
+
+The following papers introduce the method:
+
+- T. Mikolov, et al., [Distributed Representations of Words and Phrases and Their Compositionality
+](http://papers.nips.cc/paper/5021-distributed-representations-of-words-and-phrases-and-their-compositionality.pdf). NIPS, 2013.
+- T. Mikolov, et al., [Efficient Estimation of Word Representations in Vector Space](https://arxiv.org/abs/1301.3781). ICLR, 2013.
+
+Hivemall provides two types of algorithms: Skip-gram and CBoW, both with negative sampling.
+Hivemall lets you train word2vec on sequence data such as,
+but not limited to, documents.
+This article explains how to use the feature.
+
+<!-- toc -->
+
+> #### Note
+> This feature is supported from Hivemall v0.5-rc.? or later.
+
+# Prepare document data
+
+Assume that you already have a `docs` table that contains many documents in string format, each with a unique index:
+
+```sql
+select * FROM docs;
+```
+
+| docId | doc |
+|:----: |:----|
+| 0 | "Alice was beginning to get very tired of sitting by her sister on the bank ..." |
+| ... | ... |
+
+First, each document is split into words by a tokenizer function such as [`tokenize`](../misc/tokenizer.html).
+
+```sql
+drop table docs_words;
+create table docs_words as
+ select
+ docid,
+ tokenize(doc, true) as words
+ FROM
+ docs
+;
+```
+
+This table shows the tokenized documents.
+
+| docId | words |
+|:----: |:----|
+| 0 | ["alice", "was", "beginning", "to", "get", "very", "tired", "of", "sitting", "by", "her", "sister", "on", "the", "bank", ...] |
+| ... | ... |
+
+Then, count the frequency of each word and remove low-frequency words from the vocabulary.
+Removing low-frequency words is an optional preprocessing step, but it speeds up training of the word vectors.
+
+```sql
+set hivevar:mincount=5;
+
+drop table freq;
+create table freq as
+select
+ row_number() over () - 1 as wordid,
+ word,
+ freq
+from (
+ select
+ word,
+ COUNT(*) as freq
+ from
+ docs_words
+ LATERAL VIEW explode(words) lTable as word
+ group by
+ word
+) t
+where freq >= ${mincount}
+;
+```
+
+Hivemall's word2vec supports two word types: string and int.
+The string type tends to use a large amount of memory during training,
+whereas the int type tends to use less.
+If you train on a small dataset, we recommend the string type,
+because memory usage is negligible and the HiveQL is simpler.
+If you train on a large dataset, we recommend the int type,
+because it saves memory during training.
+
+# Create sub-sampling table
+
+The sub-sampling table stores a sub-sampling probability for each word.
+
+The sub-sampling probability of word $$w_i$$ is computed by the following equation:
+
+$$
+\begin{aligned}
+f(w_i) = \sqrt{\frac{\mathrm{sample}}{freq(w_i)/\sum freq(w)}} + \frac{\mathrm{sample}}{freq(w_i)/\sum freq(w)}
+\end{aligned}
+$$
+
+During word2vec training,
+words that are not sub-sampled are ignored.
+This speeds up training and reduces the imbalance between rare and frequent words by dropping occurrences of frequent words.
+The smaller the `sample` value,
+the fewer words are used during training.
+
+```sql
+set hivevar:sample=1e-4;
+
+drop table subsampling_table;
+create table subsampling_table as
+with stats as (
+ select
+ sum(freq) as numTrainWords
+ FROM
+ freq
+)
+select
+ l.wordid,
+ l.word,
+ sqrt(${sample}/(l.freq/r.numTrainWords)) + ${sample}/(l.freq/r.numTrainWords) as p
+from
+ freq l
+cross join
+ stats r
+;
+```
+
+```sql
+select * FROM subsampling_table order by p;
+```
+
+| wordid | word | p |
+|:----: | :----: |:----:|
+| 48645 | the | 0.04013665|
+| 11245 | of | 0.052463654|
+| 16368 | and | 0.06555538|
+| 61938 | 00 | 0.068162076|
+| 19977 | in | 0.071441144|
+| 83599 | 0 | 0.07528994|
+| 95017 | a | 0.07559573|
+| 1225 | to | 0.07953133|
+| 37062 | 0000 | 0.08779001|
+| 58246 | is | 0.09049763|
+| ... | ... |... |
+
+The first row shows that about 4% of the occurrences of `the` in the documents are used during training.
+
+# Delete low frequency words and high frequency words from `docs_words`
+
+To remove uninformative words from the corpus,
+low-frequency words and a fraction of the occurrences of high-frequency words are deleted.
+In addition, to avoid loading a long document into memory, each document is split into sub-documents.
+
+```sql
+set hivevar:maxlength=1500;
+SET hivevar:seed=31;
+
+drop table train_docs;
+create table train_docs as
+ with docs_exploded as (
+ select
+ docid,
+ word,
+ pos % ${maxlength} as pos,
+ pos div ${maxlength} as splitid,
+ rand(${seed}) as rnd
+ from
+ docs_words LATERAL VIEW posexplode(words) t as pos, word
+ )
+select
+ l.docid,
+ -- to_ordered_list(l.word, l.pos) as words
+ to_ordered_list(r2.wordid, l.pos) as words
+from
+ docs_exploded l
+ LEFT SEMI join freq r on (l.word = r.word)
+ join subsampling_table r2 on (l.word = r2.word)
+where
+ r2.p > l.rnd
+group by
+ l.docid, l.splitid
+;
+```
+
+If you store words as strings in the `train_docs` table,
+please replace `to_ordered_list(r2.wordid, l.pos) as words` with `to_ordered_list(l.word, l.pos) as words`.
+
+# Create negative sampling table
+
+Negative sampling is an approximation of the [softmax function](https://en.wikipedia.org/wiki/Softmax_function).
+Here, `negative_table` stores the word sampling probabilities used for negative sampling.
+`z` is a hyperparameter of the noise distribution for negative sampling.
+During word2vec training,
+words sampled from this distribution are used as negative examples.
+The noise distribution is the unigram distribution raised to the 3/4 power.
+
+$$
+\begin{aligned}
+p(w_i) = \frac{freq(w_i)^{\mathrm{z}}}{\sum freq(w)^{\mathrm{z}}}
+\end{aligned}
+$$
+
+To avoid the large memory footprint that negative sampling has in the original implementation while still sampling quickly from this distribution,
+Hivemall uses the [alias method](https://en.wikipedia.org/wiki/Alias_method).
+
+This method was proposed in the papers below:
+
+- A. J. Walker, New Fast Method for Generating Discrete Random Numbers with Arbitrary Frequency Distributions, in Electronics Letters 10, no. 8, pp. 127-128, 1974.
+- A. J. Walker, An Efficient Method for Generating Discrete Random Variables with General Distributions. ACM Transactions on Mathematical Software 3, no. 3, pp. 253-256, 1977.
+
+```sql
+set hivevar:z=3/4;
+
+drop table negative_table;
+create table negative_table as
+select
+ collect_list(array(word, p, other)) as negative_table
+from (
+ select
+ alias_table(to_map(word, negative)) as (word, p, other)
+ from
+ (
+ select
+ word,
+ -- wordid as word,
+ pow(freq, ${z}) as negative
+ from
+ freq
+ ) t
+) t1
+;
+```
+
+The `alias_table` function returns records like the following.
+
+| word | p | other |
+|:----: | :----: |:----:|
+| leopold | 0.6556492 | 0000 |
+| slep | 0.09060383 | leopold |
+| valentinian | 0.76077825 | belarusian |
+| slew | 0.90569097 | colin |
+| lucien | 0.86329675 | overland |
+| equitable | 0.7270946 | farms |
+| insurers | 0.2367955 | israel |
+| lucier | 0.14855136 | supplements |
+| lieve | 0.12075222 | separatist |
+| skyhawks | 0.14079945 | steamed |
+| ... | ... | ... |
+
+To sample a negative word from this `negative_table`:
+
+1. Sample an integer record index `i` uniformly from $$[0, \mathrm{num\_alias\_table\_records})$$.
+2. Sample a float value `r` uniformly from $$[0.0, 1.0)$$.
+3. If `r` < `p` of the `i`-th record, return `word` of the `i`-th record; otherwise return `other` of the `i`-th record.
+
+To use it in the word2vec training function,
+the records returned by `alias_table` are collected into a single list in `negative_table`.
+
+# Train word2vec
+
+Hivemall provides the `train_word2vec` function to train word vectors with the word2vec algorithms.
+The default model is `"skipgram"`.
+
+> #### Note
+> You must set the `n` option to the number of words in the training documents: `select sum(size(words)) from train_docs;`.
+
+## Train Skip-Gram
+
+In the skip-gram model,
+word vectors are trained to predict nearby words.
+For example, given a sentence such as `"alice", "was", "beginning", "to"`,
+the vector of `"was"` is learned to predict `"alice"`, `"beginning"`, and `"to"`.
+
+```sql
+select sum(size(words)) from train_docs;
+set hivevar:n=418953; -- previous query return value
+
+drop table skipgram;
+create table skipgram as
+select
+ train_word2vec(
+ r.negative_table,
+ l.words,
+ "-n ${n} -win 5 -neg 15 -iter 5 -dim 100 -model skipgram"
+ )
+from
+ train_docs l
+ cross join negative_table r
+;
+```
+
+When words are treated as int instead of string,
+you may need to map each int wordid back to its string word with a `join` statement.
+
+```sql
+drop table skipgram;
+
+create table skipgram as
+select
+ r.word, t.i, t.wi
+from (
+ select
+ train_word2vec(
+ r.negative_table,
+ l.wordsint,
+ "-n 418953 -win 5 -neg 15 -iter 5"
--- End diff --
`-n 418953` should be `-n ${n}`
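
For reference, the three sampling steps listed in the quoted `word2vec.md` diff can be sketched as follows. This is a minimal illustration, not the PR's code: the class name `AliasSampler`, the parallel-array layout, and the seed are assumptions, and `java.util.Random` stands in for Hivemall's `PRNG`.

```java
import java.util.Random;

public class AliasSampler {
    private final String[] word;  // `word` column of the alias table
    private final float[] p;      // `p` column: probability of keeping `word`
    private final String[] other; // `other` column: the alias returned otherwise
    private final Random rnd;

    public AliasSampler(String[] word, float[] p, String[] other, long seed) {
        this.word = word;
        this.p = p;
        this.other = other;
        this.rnd = new Random(seed);
    }

    public String sample() {
        int i = rnd.nextInt(word.length);     // step 1: record index in [0, n)
        float r = rnd.nextFloat();            // step 2: r in [0.0, 1.0)
        return r < p[i] ? word[i] : other[i]; // step 3: keep the word or take its alias
    }

    public static void main(String[] args) {
        // First three rows of the example table only, so this does not
        // reproduce the real noise distribution.
        AliasSampler sampler = new AliasSampler(
            new String[] {"leopold", "slep", "valentinian"},
            new float[] {0.6556492f, 0.09060383f, 0.76077825f},
            new String[] {"0000", "leopold", "belarusian"},
            31L);
        System.out.println(sampler.sample());
    }
}
```

Each draw costs one integer and one float sample regardless of vocabulary size, which is why the alias table can replace the large pre-expanded sampling array of the original implementation.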
---
[GitHub] incubator-hivemall issue #116: [WIP][HIVEMALL-118] word2vec
Posted by coveralls <gi...@git.apache.org>.
Github user coveralls commented on the issue:
https://github.com/apache/incubator-hivemall/pull/116
[![Coverage Status](https://coveralls.io/builds/13429055/badge)](https://coveralls.io/builds/13429055)
Coverage decreased (-0.6%) to 40.372% when pulling **8a42adf3687f8c823f255fcbd0c7ff7962f81f0b on nzw0301:skipgram** into **c2b95783cf9d6fc1646a48ac928e96152eab98c6 on apache:master**.
---
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Posted by myui <gi...@git.apache.org>.
Github user myui commented on a diff in the pull request:
https://github.com/apache/incubator-hivemall/pull/116#discussion_r141543757
--- Diff: core/src/main/java/hivemall/embedding/CBoWModel.java ---
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.embedding;
+
+import hivemall.math.random.PRNG;
+import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
+
+import javax.annotation.Nonnull;
+import java.util.List;
+
+public final class CBoWModel extends AbstractWord2VecModel {
+ protected CBoWModel(final int dim, final int win, final int neg, final int iter,
+ final float startingLR, final long numTrainWords, final Int2FloatOpenHashTable S,
+ final int[] aliasWordId) {
+ super(dim, win, neg, iter, startingLR, numTrainWords, S, aliasWordId);
+ }
+
+ protected void trainOnDoc(@Nonnull final int[] doc) {
+ final int vecDim = dim;
+ final int numNegative = neg;
+ final PRNG _rnd = rnd;
+ final Int2FloatOpenHashTable _S = S;
+ final int[] _aliasWordId = aliasWordId;
+ float label, gradient;
+
+ // reuse instance
+ int windowSize, k, numContext, targetWord, inWord, positiveWord;
+
+ updateLearningRate();
+
+ int docLength = doc.length;
+ for (int t = 0; t < iter; t++) {
+ for (int positiveWordPosition = 0; positiveWordPosition < docLength; positiveWordPosition++) {
+ windowSize = _rnd.nextInt(win) + 1;
+
+ numContext = windowSize * 2 + Math.min(0, positiveWordPosition - windowSize)
+ + Math.min(0, docLength - positiveWordPosition - windowSize - 1);
+
+ float[] gradVec = new float[vecDim];
+ float[] averageVec = new float[vecDim];
+
+ // collect context words
+ for (int contextPosition = positiveWordPosition - windowSize; contextPosition < positiveWordPosition
+ + windowSize + 1; contextPosition++) {
+ if (contextPosition == positiveWordPosition || contextPosition < 0
+ || contextPosition >= docLength) {
+ continue;
+ }
+
+ inWord = doc[contextPosition];
+
+ // average vector of input word vectors
+ if (!inputWeights.containsKey(inWord * vecDim)) {
+ initWordWeights(inWord);
+ }
+
+ for (int i = 0; i < vecDim; i++) {
+ averageVec[i] += inputWeights.get(inWord * vecDim + i) / numContext;
+ }
+ }
+ positiveWord = doc[positiveWordPosition];
+ // negative sampling
+ for (int d = 0; d < numNegative + 1; d++) {
+ if (d == 0) {
+ targetWord = positiveWord;
+ label = 1.f;
+ } else {
+ do {
+ k = _rnd.nextInt(_S.size());
+ if (_S.get(k) > _rnd.nextDouble()) {
+ targetWord = k;
+ } else {
+ targetWord = _aliasWordId[k];
+ }
+ } while (targetWord == positiveWord);
+ label = 0.f;
+ }
+
+ gradient = grad(label, averageVec, targetWord) * lr;
+ for (int i = 0; i < vecDim; i++) {
+ gradVec[i] += gradient * contextWeights.get(targetWord * vecDim + i);
+ contextWeights.put(targetWord * vecDim + i,
+ contextWeights.get(targetWord * vecDim + i) + gradient * averageVec[i]);
+ }
+ }
+
+ // update inWord vector
+ for (int contextPosition = positiveWordPosition - windowSize; contextPosition < positiveWordPosition
+ + windowSize + 1; contextPosition++) {
+ if (contextPosition == positiveWordPosition || contextPosition < 0
+ || contextPosition >= docLength) {
+ continue;
+ }
+
+ inWord = doc[contextPosition];
+ for (int i = 0; i < vecDim; i++) {
+ inputWeights.put(inWord * vecDim + i, inputWeights.get(inWord * vecDim + i)
+ + gradVec[i]);
+ }
+ }
+ }
+ }
+
+ wordCount += docLength * iter;
+ }
+
+ private float grad(final float label, @Nonnull final float[] w, final int c) {
+ float dotValue = 0.f;
+ for (int i = 0; i < dim; i++) {
+ dotValue += w[i] * contextWeights.get(c * dim + i);
+ }
+
+ return (label - sigmoid(dotValue, sigmoidTable));
--- End diff --
Remove the redundant outermost parentheses.
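Without the extra parentheses, the method reduces to `label - sigmoid(w · c)`. A self-contained sketch, assuming plain arrays and an exact sigmoid in place of the PR's `Int2FloatOpenHashTable` weights and cached `sigmoidTable` (the class name is hypothetical):

```java
public class NegativeSamplingGrad {
    // Exact logistic function; the PR uses a precomputed lookup table instead.
    static float sigmoid(float x) {
        return (float) (1.0 / (1.0 + Math.exp(-x)));
    }

    // Gradient scale for one (input, target) pair:
    // label is 1 for the positive word and 0 for a negative sample.
    static float grad(float label, float[] w, float[] c) {
        float dotValue = 0.f;
        for (int i = 0; i < w.length; i++) {
            dotValue += w[i] * c[i];
        }
        return label - sigmoid(dotValue);
    }

    public static void main(String[] args) {
        float[] w = {0.f, 0.f};
        float[] c = {1.f, 1.f};
        System.out.println(grad(1.f, w, c)); // prints 0.5
        System.out.println(grad(0.f, w, c)); // prints -0.5
    }
}
```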
---
[GitHub] incubator-hivemall issue #116: [WIP][HIVEMALL-118] word2vec
Posted by coveralls <gi...@git.apache.org>.
Github user coveralls commented on the issue:
https://github.com/apache/incubator-hivemall/pull/116
[![Coverage Status](https://coveralls.io/builds/13472784/badge)](https://coveralls.io/builds/13472784)
Coverage decreased (-0.6%) to 40.508% when pulling **0b163fade6f2d26ce918211c94a78c9a3b648cbe on nzw0301:skipgram** into **1e42387576fabbb326d451f4a00ac22d57828711 on apache:master**.
---
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Posted by myui <gi...@git.apache.org>.
Github user myui commented on a diff in the pull request:
https://github.com/apache/incubator-hivemall/pull/116#discussion_r141550040
--- Diff: core/src/main/java/hivemall/embedding/AbstractWord2VecModel.java ---
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.embedding;
+
+import hivemall.math.random.PRNG;
+import hivemall.math.random.RandomNumberGeneratorFactory;
+import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+import java.util.List;
+
+public abstract class AbstractWord2VecModel {
+ // cached sigmoid function parameters
+ protected static final int MAX_SIGMOID = 6;
+ protected static final int SIGMOID_TABLE_SIZE = 1000;
+ protected float[] sigmoidTable;
+
+
+ @Nonnegative
+ protected int dim;
+ protected int win;
+ protected int neg;
+ protected int iter;
+
+ // learning rate parameters
+ @Nonnegative
+ protected float lr;
+ @Nonnegative
+ private float startingLR;
+ @Nonnegative
+ private long numTrainWords;
+ @Nonnegative
+ protected long wordCount;
+ @Nonnegative
+ private long lastWordCount;
+
+ protected PRNG rnd;
+
+ protected Int2FloatOpenHashTable contextWeights;
+ protected Int2FloatOpenHashTable inputWeights;
+ protected Int2FloatOpenHashTable S;
+ protected int[] aliasWordId;
+
+ protected AbstractWord2VecModel(final int dim, final int win, final int neg, final int iter,
+ final float startingLR, final long numTrainWords, final Int2FloatOpenHashTable S,
+ final int[] aliasWordId) {
+ this.win = win;
+ this.neg = neg;
+ this.iter = iter;
+ this.dim = dim;
+ this.startingLR = this.lr = startingLR;
+ this.numTrainWords = numTrainWords;
+
+ // alias sampler for negative sampling
+ this.S = S;
+ this.aliasWordId = aliasWordId;
+
+ this.wordCount = 0L;
+ this.lastWordCount = 0L;
+ this.rnd = RandomNumberGeneratorFactory.createPRNG(1001);
+
+ this.sigmoidTable = initSigmoidTable();
+
+ // TODO how to estimate size
+ this.inputWeights = new Int2FloatOpenHashTable(10578 * dim);
--- End diff --
What's `10578`?
---
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Posted by myui <gi...@git.apache.org>.
Github user myui commented on a diff in the pull request:
https://github.com/apache/incubator-hivemall/pull/116#discussion_r140413586
--- Diff: core/src/main/java/hivemall/unsupervised/AbstractWord2vecModel.java ---
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.unsupervised;
+
+import hivemall.math.random.PRNG;
+import hivemall.math.random.RandomNumberGeneratorFactory;
+import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+public abstract class AbstractWord2vecModel {
--- End diff --
Please rename `Word2vec` to `Word2Vec` as seen in [spark](https://github.com/apache/spark/blob/master/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala).
---
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Posted by myui <gi...@git.apache.org>.
Github user myui commented on a diff in the pull request:
https://github.com/apache/incubator-hivemall/pull/116#discussion_r141545257
--- Diff: docs/gitbook/embedding/word2vec.md ---
@@ -0,0 +1,399 @@
+<!--
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+-->
+
+Word embedding is a powerful tool for many tasks,
+e.g. finding similar words,
+building feature vectors for supervised machine learning tasks, and solving word analogies
+such as `king - man + woman =~ queen`.
+In word embedding,
+each word is represented by a low-dimensional, dense vector.
+**Skip-Gram** and **Continuous Bag-of-words** (CBoW) are the most popular algorithms for obtaining good word embeddings (a.k.a. word2vec).
+
+The following papers introduce the method:
+
+- T. Mikolov, et al., [Distributed Representations of Words and Phrases and Their Compositionality
+](http://papers.nips.cc/paper/5021-distributed-representations-of-words-and-phrases-and-their-compositionality.pdf). NIPS, 2013.
+- T. Mikolov, et al., [Efficient Estimation of Word Representations in Vector Space](https://arxiv.org/abs/1301.3781). ICLR, 2013.
+
+Hivemall provides two types of algorithms: Skip-gram and CBoW, both with negative sampling.
+Hivemall lets you train word2vec on sequence data such as,
+but not limited to, documents.
+This article explains how to use the feature.
+
+<!-- toc -->
+
+> #### Note
+> This feature is supported from Hivemall v0.5-rc.? or later.
+
+# Prepare document data
+
+Assume that you already have a `docs` table that contains many documents in string format, each with a unique index:
+
+```sql
+select * FROM docs;
+```
+
+| docId | doc |
+|:----: |:----|
+| 0 | "Alice was beginning to get very tired of sitting by her sister on the bank ..." |
+| ... | ... |
+
+First, each document is split into words by a tokenizer function such as [`tokenize`](../misc/tokenizer.html).
+
+```sql
+drop table docs_words;
+create table docs_words as
+ select
+ docid,
+ tokenize(doc, true) as words
+ FROM
+ docs
+;
+```
+
+This table shows the tokenized documents.
+
+| docId | words |
+|:----: |:----|
+| 0 | ["alice", "was", "beginning", "to", "get", "very", "tired", "of", "sitting", "by", "her", "sister", "on", "the", "bank", ...] |
+| ... | ... |
+
+Then, count the frequency of each word and remove low-frequency words from the vocabulary.
+Removing low-frequency words is an optional preprocessing step, but it speeds up training of the word vectors.
+
+```sql
+set hivevar:mincount=5;
+
+drop table freq;
+create table freq as
+select
+ row_number() over () - 1 as wordid,
+ word,
+ freq
+from (
+ select
+ word,
+ COUNT(*) as freq
+ from
+ docs_words
+ LATERAL VIEW explode(words) lTable as word
+ group by
+ word
+) t
+where freq >= ${mincount}
+;
+```
+
+Hivemall's word2vec supports two word types: string and int.
+The string type tends to use a large amount of memory during training,
+whereas the int type tends to use less.
+If you train on a small dataset, we recommend the string type,
+because memory usage is negligible and the HiveQL is simpler.
+If you train on a large dataset, we recommend the int type,
+because it saves memory during training.
+
+# Create sub-sampling table
+
+The sub-sampling table stores a sub-sampling probability for each word.
+
+The sub-sampling probability of word $$w_i$$ is computed by the following equation:
+
+$$
+\begin{aligned}
+f(w_i) = \sqrt{\frac{\mathrm{sample}}{freq(w_i)/\sum freq(w)}} + \frac{\mathrm{sample}}{freq(w_i)/\sum freq(w)}
+\end{aligned}
+$$
+
+During word2vec training,
+words that are not sub-sampled are ignored.
+This speeds up training and reduces the imbalance between rare and frequent words by dropping occurrences of frequent words.
+The smaller the `sample` value,
+the fewer words are used during training.
+
+```sql
+set hivevar:sample=1e-4;
+
+drop table subsampling_table;
+create table subsampling_table as
+with stats as (
+ select
+ sum(freq) as numTrainWords
+ FROM
+ freq
+)
+select
+ l.wordid,
+ l.word,
+ sqrt(${sample}/(l.freq/r.numTrainWords)) + ${sample}/(l.freq/r.numTrainWords) as p
+from
+ freq l
+cross join
+ stats r
+;
+```
+
+```sql
+select * FROM subsampling_table order by p;
+```
+
+| wordid | word | p |
+|:----: | :----: |:----:|
+| 48645 | the | 0.04013665|
+| 11245 | of | 0.052463654|
+| 16368 | and | 0.06555538|
+| 61938 | 00 | 0.068162076|
+| 19977 | in | 0.071441144|
+| 83599 | 0 | 0.07528994|
+| 95017 | a | 0.07559573|
+| 1225 | to | 0.07953133|
+| 37062 | 0000 | 0.08779001|
+| 58246 | is | 0.09049763|
+| ... | ... |... |
+
+The first row shows that about 4% of the occurrences of `the` in the documents are used during training.
+
+# Delete low frequency words and high frequency words from `docs_words`
+
+To remove uninformative words from the corpus,
+low-frequency and high-frequency words are deleted.
+In addition, to avoid loading a long document into memory, each document is split into sub-documents.
+
+```sql
+set hivevar:maxlength=1500;
+SET hivevar:seed=31;
+
+drop table train_docs;
+create table train_docs as
+ with docs_exploded as (
+ select
+ docid,
+ word,
+ pos % ${maxlength} as pos,
+ pos div ${maxlength} as splitid,
+ rand(${seed}) as rnd
+ from
+ docs_words LATERAL VIEW posexplode(words) t as pos, word
+ )
+select
+ l.docid,
+ -- to_ordered_list(l.word, l.pos) as words
+ to_ordered_list(r2.wordid, l.pos) as words
+from
+ docs_exploded l
+ LEFT SEMI join freq r on (l.word = r.word)
+ join subsampling_table r2 on (l.word = r2.word)
+where
+ r2.p > l.rnd
+group by
+ l.docid, l.splitid
+;
+```
+
+To store words as strings in the `train_docs` table,
+replace `to_ordered_list(r2.wordid, l.pos) as words` with `to_ordered_list(l.word, l.pos) as words`.
+
+# Create negative sampling table
+
+Negative sampling is an approximation of the [softmax function](https://en.wikipedia.org/wiki/Softmax_function).
+Here, `negative_table` stores the word sampling probabilities for negative sampling.
+`z` is a hyperparameter of the noise distribution for negative sampling.
+During word2vec training,
+words sampled from this distribution are used as negative examples.
+The noise distribution is the unigram distribution raised to the 3/4 power.
+
+$$
+\begin{aligned}
+p(w_i) = \frac{freq(w_i)^{\mathrm{z}}}{\sum freq(w)^{\mathrm{z}}}
+\end{aligned}
+$$
+
+To avoid the large memory footprint of the negative sampling table in the original implementation while still sampling quickly from this distribution,
+Hivemall uses the [alias method](https://en.wikipedia.org/wiki/Alias_method).
+
+This method was proposed in the following papers:
+
+- A. J. Walker, New Fast Method for Generating Discrete Random Numbers with Arbitrary Frequency Distributions, in Electronics Letters 10, no. 8, pp. 127-128, 1974.
+- A. J. Walker, An Efficient Method for Generating Discrete Random Variables with General Distributions. ACM Transactions on Mathematical Software 3, no. 3, pp. 253-256, 1977.
+
+```sql
+set hivevar:z=3/4;
+
+drop table negative_table;
+create table negative_table as
+select
+ collect_list(array(word, p, other)) as negative_table
+from (
+ select
+ alias_table(to_map(word, negative)) as (word, p, other)
+ from
+ (
+ select
+ word,
+ -- wordid as word,
+ pow(freq, ${z}) as negative
+ from
+ freq
+ ) t
+) t1
+;
+```
+
+The `alias_table` function returns records like the following.
+
+| word | p | other |
+|:----: | :----: |:----:|
+| leopold | 0.6556492 | 0000 |
+| slep | 0.09060383 | leopold |
+| valentinian | 0.76077825 | belarusian |
+| slew | 0.90569097 | colin |
+| lucien | 0.86329675 | overland |
+| equitable | 0.7270946 | farms |
+| insurers | 0.2367955 | israel |
+| lucier | 0.14855136 | supplements |
+| lieve | 0.12075222 | separatist |
+| skyhawks | 0.14079945 | steamed |
+| ... | ... | ... |
+
+To sample a negative word from this `negative_table`:
+
+1. Sample an int record index `i` from $$[0 \ldots \mathrm{num\_alias\_table\_records} - 1]$$.
+2. Sample a float value `r` from $$[0.0, 1.0)$$.
+3. If `r` < `p` of the `i`-th record, return `word` of the `i`-th record; otherwise return `other` of the `i`-th record.
+
+Here, so that they can be used in the word2vec training function,
+the records returned by `alias_table` are stored as a single list in `negative_table`.
+
+# Train word2vec
+
+Hivemall provides the `train_word2vec` function to train word vectors with the word2vec algorithms.
+The default model is `"skipgram"`.
+
+> #### Note
+> You must set the `n` argument to the number of words in the training documents: `select sum(size(words)) from train_docs;`.
+
+## Train Skip-Gram
+
+In the skip-gram model,
+word vectors are trained to predict nearby words.
+For example, given a sentence like `"alice", "was", "beginning", "to"`,
+the `"was"` vector is learned to predict `"alice"`, `"beginning"`, and `"to"`.
+
+```sql
+select sum(size(words)) from train_docs;
+set hivevar:n=418953; -- previous query return value
+
+drop table skipgram;
+create table skipgram as
+select
+ train_word2vec(
+ r.negative_table,
+ l.words,
+ "-n ${n} -win 5 -neg 15 -iter 5 -dim 100 -model skipgram"
+ )
+from
+ train_docs l
+ cross join negative_table r
+;
+```
+
+When words are treated as int instead of string,
+you may need to map each int wordid back to its string word with a `join` statement.
+
+```sql
+drop table skipgram;
+
+create table skipgram as
+select
+ r.word, t.i, t.wi
+from (
+ select
+ train_word2vec(
+ r.negative_table,
+ l.words,
+ "-n 418953 -win 5 -neg 15 -iter 5"
+ ) as (wordid, i, wi)
+ from
+ train_docs l
+ cross join
+ negative_table r
+) t
+join freq r on (t.wordid = r.wordid)
+;
+```
+
+## Train CBoW
+
+In CBoW model,
--- End diff --
remove line break after `,`
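
As a side note on the sampling steps described in the `word2vec.md` diff above, the following is a minimal Java sketch of drawing one negative word from an alias table. It is not the PR's implementation: the class `AliasSamplingSketch`, its `Record` type, and the `sample` method are hypothetical names used only for illustration, and the two records are copied from the example table in the document.

```java
import java.util.Random;

// Hypothetical illustration of the three sampling steps; not code from the PR.
final class AliasSamplingSketch {

    /** One alias-table record: (word, p, other). */
    static final class Record {
        final String word;
        final float p;
        final String other;

        Record(String word, float p, String other) {
            this.word = word;
            this.p = p;
            this.other = other;
        }
    }

    /** Draws one word from the alias table in O(1). */
    static String sample(Record[] table, Random rnd) {
        int i = rnd.nextInt(table.length); // step 1: uniform index in [0, n)
        float r = rnd.nextFloat();         // step 2: uniform value in [0.0, 1.0)
        Record rec = table[i];
        return (r < rec.p) ? rec.word : rec.other; // step 3
    }

    public static void main(String[] args) {
        Record[] table = {
            new Record("leopold", 0.6556492f, "0000"),
            new Record("slep", 0.09060383f, "leopold")
        };
        Random rnd = new Random(31L);
        for (int k = 0; k < 5; k++) {
            System.out.println(sample(table, rnd));
        }
    }
}
```

Each draw costs one integer and one float from the random generator regardless of vocabulary size, which is why the alias method can replace a large pre-expanded sampling table.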
---
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Posted by myui <gi...@git.apache.org>.
Github user myui commented on a diff in the pull request:
https://github.com/apache/incubator-hivemall/pull/116#discussion_r141546846
--- Diff: core/src/main/java/hivemall/embedding/CBoWModel.java ---
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.embedding;
+
+import hivemall.math.random.PRNG;
+import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
+
+import javax.annotation.Nonnull;
+import java.util.List;
+
+public final class CBoWModel extends AbstractWord2VecModel {
+ protected CBoWModel(final int dim, final int win, final int neg, final int iter,
--- End diff --
add a blank line before constructor.
---
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Posted by myui <gi...@git.apache.org>.
Github user myui commented on a diff in the pull request:
https://github.com/apache/incubator-hivemall/pull/116#discussion_r141543245
--- Diff: core/src/main/java/hivemall/embedding/AbstractWord2VecModel.java ---
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.embedding;
+
+import hivemall.math.random.PRNG;
+import hivemall.math.random.RandomNumberGeneratorFactory;
+import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+import java.util.List;
+
+public abstract class AbstractWord2VecModel {
+ // cached sigmoid function parameters
+ protected static final int MAX_SIGMOID = 6;
+ protected static final int SIGMOID_TABLE_SIZE = 1000;
+ protected float[] sigmoidTable;
+
+
+ @Nonnegative
+ protected int dim;
+ protected int win;
+ protected int neg;
+ protected int iter;
+
+ // learning rate parameters
+ @Nonnegative
+ protected float lr;
+ @Nonnegative
+ private float startingLR;
+ @Nonnegative
+ private long numTrainWords;
+ @Nonnegative
+ protected long wordCount;
+ @Nonnegative
+ private long lastWordCount;
+
+ protected PRNG rnd;
+
+ protected Int2FloatOpenHashTable contextWeights;
+ protected Int2FloatOpenHashTable inputWeights;
+ protected Int2FloatOpenHashTable S;
+ protected int[] aliasWordId;
+
+ protected AbstractWord2VecModel(final int dim, final int win, final int neg, final int iter,
+ final float startingLR, final long numTrainWords, final Int2FloatOpenHashTable S,
+ final int[] aliasWordId) {
+ this.win = win;
+ this.neg = neg;
+ this.iter = iter;
+ this.dim = dim;
+ this.startingLR = this.lr = startingLR;
+ this.numTrainWords = numTrainWords;
+
+ // alias sampler for negative sampling
+ this.S = S;
+ this.aliasWordId = aliasWordId;
+
+ this.wordCount = 0L;
+ this.lastWordCount = 0L;
+ this.rnd = RandomNumberGeneratorFactory.createPRNG(1001);
+
+ this.sigmoidTable = initSigmoidTable();
+
+ // TODO how to estimate size
+ this.inputWeights = new Int2FloatOpenHashTable(10578 * dim);
+ this.inputWeights.defaultReturnValue(0.f);
+ this.contextWeights = new Int2FloatOpenHashTable(10578 * dim);
+ this.contextWeights.defaultReturnValue(0.f);
+ }
+
+ private static float[] initSigmoidTable() {
+ float[] sigmoidTable = new float[SIGMOID_TABLE_SIZE];
+ for (int i = 0; i < SIGMOID_TABLE_SIZE; i++) {
+ float x = ((float) i / SIGMOID_TABLE_SIZE * 2 - 1) * (float) MAX_SIGMOID;
+ sigmoidTable[i] = 1.f / ((float) Math.exp(-x) + 1.f);
+ }
+ return sigmoidTable;
+ }
+
+ protected void initWordWeights(final int wordId) {
+ for (int i = 0; i < dim; i++) {
+ inputWeights.put(wordId * dim + i, ((float) rnd.nextDouble() - 0.5f) / dim);
+ }
+ }
+
+ protected static float sigmoid(final float v, final float[] sigmoidTable) {
+ if (v > MAX_SIGMOID) {
+ return 1.f;
+ } else if (v < -MAX_SIGMOID) {
+ return 0.f;
+ } else {
+ return sigmoidTable[(int) ((v + MAX_SIGMOID) * (SIGMOID_TABLE_SIZE / MAX_SIGMOID / 2))];
+ }
+ }
+
+ protected void updateLearningRate() {
+ // TODO: valid lr?
--- End diff --
remove this TODO comment and blank lines.
---
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Posted by myui <gi...@git.apache.org>.
Github user myui commented on a diff in the pull request:
https://github.com/apache/incubator-hivemall/pull/116#discussion_r141547708
--- Diff: core/src/main/java/hivemall/embedding/Word2VecUDTF.java ---
@@ -0,0 +1,364 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.embedding;
+
+import hivemall.UDTFWithOptions;
+import hivemall.utils.collections.IMapIterator;
+import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
+import hivemall.utils.collections.maps.OpenHashTable;
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.lang.Primitives;
+
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.Options;
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
+import org.apache.hadoop.io.FloatWritable;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+import java.util.List;
+import java.util.Arrays;
+import java.util.ArrayList;
+
+@Description(
+ name = "train_word2vec",
+ value = "_FUNC_(array<array<float | string>> negative_table, array<int | string> doc [, const string options]) - Returns a prediction model")
+public class Word2VecUDTF extends UDTFWithOptions {
+ protected transient AbstractWord2VecModel model;
+ @Nonnegative
+ private float startingLR;
+ @Nonnegative
+ private long numTrainWords;
+ private OpenHashTable<String, Integer> word2index;
+
+ @Nonnegative
+ private int dim;
+ @Nonnegative
+ private int win;
+ @Nonnegative
+ private int neg;
+ @Nonnegative
+ private int iter;
+ private boolean skipgram;
+ private boolean isStringInput;
+
+ private Int2FloatOpenHashTable S;
+ private int[] aliasWordIds;
+
+ private ListObjectInspector negativeTableOI;
+ private ListObjectInspector negativeTableElementListOI;
+ private PrimitiveObjectInspector negativeTableElementOI;
+
+ private ListObjectInspector docOI;
+ private PrimitiveObjectInspector wordOI;
+
+ @Override
+ public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
+ final int numArgs = argOIs.length;
+
+ if (numArgs != 3) {
+ throw new UDFArgumentException(getClass().getSimpleName()
+ + " takes 3 arguments: [, constant string options]: "
+ + Arrays.toString(argOIs));
+ }
+
+ processOptions(argOIs);
+
+ this.negativeTableOI = HiveUtils.asListOI(argOIs[0]);
+ this.negativeTableElementListOI = HiveUtils.asListOI(negativeTableOI.getListElementObjectInspector());
+ this.docOI = HiveUtils.asListOI(argOIs[1]);
+
+ this.isStringInput = HiveUtils.isStringListOI(argOIs[1]);
+
+ if (isStringInput) {
+ this.negativeTableElementOI = HiveUtils.asStringOI(negativeTableElementListOI.getListElementObjectInspector());
+ this.wordOI = HiveUtils.asStringOI(docOI.getListElementObjectInspector());
+ } else {
+ this.negativeTableElementOI = HiveUtils.asFloatingPointOI(negativeTableElementListOI.getListElementObjectInspector());
+ this.wordOI = HiveUtils.asIntCompatibleOI(docOI.getListElementObjectInspector());
+ }
+
+ List<String> fieldNames = new ArrayList<>();
+ List<ObjectInspector> fieldOIs = new ArrayList<>();
+
+ fieldNames.add("word");
+ fieldNames.add("i");
+ fieldNames.add("wi");
+
+ if (isStringInput) {
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
+ } else {
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+ }
+
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
+
+ this.model = null;
+ this.word2index = null;
+ this.S = null;
+ this.aliasWordIds = null;
+
+ return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
+ }
+
+ @Override
+ public void process(Object[] args) throws HiveException {
+ if (model == null) {
+ parseNegativeTable(args[0]);
+ this.model = createModel();
+ }
+
+ List<?> rawDoc = docOI.getList(args[1]);
+
+ // parse rawDoc
+ final int docLength = rawDoc.size();
+ final int[] doc = new int[docLength];
+ if (isStringInput) {
+ for (int i = 0; i < docLength; i++) {
+ doc[i] = getWordId(PrimitiveObjectInspectorUtils.getString(rawDoc.get(i), wordOI));
+ }
+ } else {
+ for (int i = 0; i < docLength; i++) {
+ doc[i] = PrimitiveObjectInspectorUtils.getInt(rawDoc.get(i), wordOI);
+ }
+ }
+
+ model.trainOnDoc(doc);
+ }
+
+ @Override
+ protected Options getOptions() {
+ Options opts = new Options();
+ opts.addOption("n", "numTrainWords", true,
+ "The number of words in the documents. It is used to update learning rate");
+ opts.addOption("dim", "dimension", true, "The number of vector dimension [default: 100]");
+ opts.addOption("win", "window", true, "Context window size [default: 5]");
+ opts.addOption("neg", "negative", true,
+ "The number of negative sampled words per word [default: 5]");
+ opts.addOption("iter", "iteration", true, "The number of iterations [default: 5]");
+ opts.addOption("model", "modelName", true,
+ "The model name of word2vec: skipgram or cbow [default: skipgram]");
+ opts.addOption(
+ "lr",
+ "learningRate",
+ true,
+ "Initial learning rate of SGD. The default value depends on model [default: 0.025 (skipgram), 0.05 (cbow)]");
+
+ return opts;
+ }
+
+ @Override
+ protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
+ CommandLine cl = null;
+ int win = 5;
+ int neg = 5;
+ int iter = 5;
+ int dim = 100;
+ long numTrainWords = 0L;
+ String modelName = "skipgram";
+ float lr = 0.025f;
+
+ if (argOIs.length >= 3) {
+ String rawArgs = HiveUtils.getConstString(argOIs[2]);
+ cl = parseOptions(rawArgs);
+
+ numTrainWords = Primitives.parseLong(cl.getOptionValue("n"), numTrainWords);
+ if (numTrainWords <= 0) {
+ throw new UDFArgumentException("Argument `long numTrainWords` must be positive: "
+ + numTrainWords);
+ }
+
+ dim = Primitives.parseInt(cl.getOptionValue("dim"), dim);
+ if (dim <= 0) {
+ throw new UDFArgumentException("Argument `int dim` must be positive: " + dim);
+ }
+
+ win = Primitives.parseInt(cl.getOptionValue("win"), win);
+ if (win <= 0) {
+ throw new UDFArgumentException("Argument `int win` must be positive: " + win);
+ }
+
+ neg = Primitives.parseInt(cl.getOptionValue("neg"), neg);
+ if (neg < 0) {
+ throw new UDFArgumentException("Argument `int neg` must be non-negative: " + neg);
+ }
+
+ iter = Primitives.parseInt(cl.getOptionValue("iter"), iter);
+ if (iter <= 0) {
+ throw new UDFArgumentException("Argument `int iter` must be positive: " + iter);
+ }
+
+ modelName = cl.getOptionValue("model", modelName);
+ if (!(modelName.equals("skipgram") || modelName.equals("cbow"))) {
--- End diff --
`"skipgram".equals(modelName)` is null safe.
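
To illustrate the point, here is a small, self-contained sketch of the constant-first comparison. The class `NullSafeEqualsSketch` and the method `isSupportedModel` are hypothetical names for this example only, not part of the PR.

```java
// Hypothetical illustration of constant-first equals(); not code from the PR.
final class NullSafeEqualsSketch {

    /** Returns true only for the two model names the option accepts. */
    static boolean isSupportedModel(String modelName) {
        // Calling equals() on the literal never dereferences modelName,
        // so a null argument yields false instead of a NullPointerException.
        return "skipgram".equals(modelName) || "cbow".equals(modelName);
    }

    public static void main(String[] args) {
        System.out.println(isSupportedModel("skipgram"));
        System.out.println(isSupportedModel("glove"));
        System.out.println(isSupportedModel(null));
    }
}
```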
---
[GitHub] incubator-hivemall issue #116: [WIP][HIVEMALL-118] word2vec
Posted by coveralls <gi...@git.apache.org>.
Github user coveralls commented on the issue:
https://github.com/apache/incubator-hivemall/pull/116
[![Coverage Status](https://coveralls.io/builds/13431616/badge)](https://coveralls.io/builds/13431616)
Coverage decreased (-0.9%) to 40.049% when pulling **f19d732122d29a00e421d7f1ad0d0ca93f242c1e on nzw0301:skipgram** into **c2b95783cf9d6fc1646a48ac928e96152eab98c6 on apache:master**.
---
[GitHub] incubator-hivemall issue #116: [WIP][HIVEMALL-118] word2vec
Posted by coveralls <gi...@git.apache.org>.
Github user coveralls commented on the issue:
https://github.com/apache/incubator-hivemall/pull/116
[![Coverage Status](https://coveralls.io/builds/13452116/badge)](https://coveralls.io/builds/13452116)
Coverage decreased (-0.9%) to 40.065% when pulling **2415589bb3a8eda3e23a765b10e01fde6c70a298 on nzw0301:skipgram** into **c2b95783cf9d6fc1646a48ac928e96152eab98c6 on apache:master**.
---
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Posted by myui <gi...@git.apache.org>.
Github user myui commented on a diff in the pull request:
https://github.com/apache/incubator-hivemall/pull/116#discussion_r141543209
--- Diff: core/src/main/java/hivemall/embedding/AbstractWord2VecModel.java ---
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.embedding;
+
+import hivemall.math.random.PRNG;
+import hivemall.math.random.RandomNumberGeneratorFactory;
+import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+import java.util.List;
+
+public abstract class AbstractWord2VecModel {
+ // cached sigmoid function parameters
+ protected static final int MAX_SIGMOID = 6;
+ protected static final int SIGMOID_TABLE_SIZE = 1000;
+ protected float[] sigmoidTable;
+
+
+ @Nonnegative
+ protected int dim;
+ protected int win;
+ protected int neg;
+ protected int iter;
+
+ // learning rate parameters
+ @Nonnegative
+ protected float lr;
+ @Nonnegative
+ private float startingLR;
+ @Nonnegative
+ private long numTrainWords;
+ @Nonnegative
+ protected long wordCount;
+ @Nonnegative
+ private long lastWordCount;
+
+ protected PRNG rnd;
+
+ protected Int2FloatOpenHashTable contextWeights;
+ protected Int2FloatOpenHashTable inputWeights;
+ protected Int2FloatOpenHashTable S;
+ protected int[] aliasWordId;
+
+ protected AbstractWord2VecModel(final int dim, final int win, final int neg, final int iter,
+ final float startingLR, final long numTrainWords, final Int2FloatOpenHashTable S,
+ final int[] aliasWordId) {
+ this.win = win;
+ this.neg = neg;
+ this.iter = iter;
+ this.dim = dim;
+ this.startingLR = this.lr = startingLR;
+ this.numTrainWords = numTrainWords;
+
+ // alias sampler for negative sampling
+ this.S = S;
+ this.aliasWordId = aliasWordId;
+
+ this.wordCount = 0L;
+ this.lastWordCount = 0L;
+ this.rnd = RandomNumberGeneratorFactory.createPRNG(1001);
+
+ this.sigmoidTable = initSigmoidTable();
+
+ // TODO how to estimate size
+ this.inputWeights = new Int2FloatOpenHashTable(10578 * dim);
+ this.inputWeights.defaultReturnValue(0.f);
+ this.contextWeights = new Int2FloatOpenHashTable(10578 * dim);
+ this.contextWeights.defaultReturnValue(0.f);
+ }
+
+ private static float[] initSigmoidTable() {
+ float[] sigmoidTable = new float[SIGMOID_TABLE_SIZE];
+ for (int i = 0; i < SIGMOID_TABLE_SIZE; i++) {
+ float x = ((float) i / SIGMOID_TABLE_SIZE * 2 - 1) * (float) MAX_SIGMOID;
+ sigmoidTable[i] = 1.f / ((float) Math.exp(-x) + 1.f);
+ }
+ return sigmoidTable;
+ }
+
+ protected void initWordWeights(final int wordId) {
+ for (int i = 0; i < dim; i++) {
+ inputWeights.put(wordId * dim + i, ((float) rnd.nextDouble() - 0.5f) / dim);
+ }
+ }
+
+ protected static float sigmoid(final float v, final float[] sigmoidTable) {
--- End diff --
`@Nonnull` for sigmoidTable
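
For context, a rough sketch of the table-based sigmoid with the non-null contract made explicit. This is an assumption-laden variant, not the PR's code: `SigmoidTableSketch` is a hypothetical class, `Objects.requireNonNull` stands in at runtime for the `@Nonnull` annotation (which only documents the contract), and the index is computed with a float scale and clamped, which differs slightly from the integer division in the diff.

```java
import java.util.Objects;

// Hypothetical variant of the cached sigmoid lookup; not code from the PR.
final class SigmoidTableSketch {
    static final int MAX_SIGMOID = 6;
    static final int SIGMOID_TABLE_SIZE = 1000;

    /** Precomputes sigmoid(x) for x evenly spaced over [-MAX_SIGMOID, MAX_SIGMOID). */
    static float[] initSigmoidTable() {
        float[] table = new float[SIGMOID_TABLE_SIZE];
        for (int i = 0; i < SIGMOID_TABLE_SIZE; i++) {
            float x = ((float) i / SIGMOID_TABLE_SIZE * 2 - 1) * MAX_SIGMOID;
            table[i] = 1.f / ((float) Math.exp(-x) + 1.f);
        }
        return table;
    }

    /** Looks up an approximate sigmoid(v), saturating outside the table range. */
    static float sigmoid(float v, float[] sigmoidTable) {
        Objects.requireNonNull(sigmoidTable, "sigmoidTable"); // runtime stand-in for @Nonnull
        if (v >= MAX_SIGMOID) {
            return 1.f;
        }
        if (v <= -MAX_SIGMOID) {
            return 0.f;
        }
        float scale = SIGMOID_TABLE_SIZE / (2.f * MAX_SIGMOID);
        // Clamp to guard against float rounding pushing the index to the table size.
        int index = Math.min((int) ((v + MAX_SIGMOID) * scale), SIGMOID_TABLE_SIZE - 1);
        return sigmoidTable[index];
    }

    public static void main(String[] args) {
        float[] table = initSigmoidTable();
        System.out.println(sigmoid(0.f, table));
        System.out.println(sigmoid(10.f, table));
        System.out.println(sigmoid(-10.f, table));
    }
}
```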
---
[GitHub] incubator-hivemall issue #116: [WIP][HIVEMALL-118] word2vec
Posted by coveralls <gi...@git.apache.org>.
Github user coveralls commented on the issue:
https://github.com/apache/incubator-hivemall/pull/116
[![Coverage Status](https://coveralls.io/builds/13372775/badge)](https://coveralls.io/builds/13372775)
Coverage decreased (-0.8%) to 40.149% when pulling **39d11236d100a92d54cc46d8ffce4bd89670217f on nzw0301:skipgram** into **c2b95783cf9d6fc1646a48ac928e96152eab98c6 on apache:master**.
---
[GitHub] incubator-hivemall issue #116: [WIP][HIVEMALL-118] word2vec
Posted by coveralls <gi...@git.apache.org>.
Github user coveralls commented on the issue:
https://github.com/apache/incubator-hivemall/pull/116
[![Coverage Status](https://coveralls.io/builds/13366601/badge)](https://coveralls.io/builds/13366601)
Coverage decreased (-0.7%) to 40.181% when pulling **7a5fd547caef5d1af512422dc75dd0efdf5b9466 on nzw0301:skipgram** into **c2b95783cf9d6fc1646a48ac928e96152eab98c6 on apache:master**.
---
[GitHub] incubator-hivemall issue #116: [WIP][HIVEMALL-118] word2vec
Posted by myui <gi...@git.apache.org>.
Github user myui commented on the issue:
https://github.com/apache/incubator-hivemall/pull/116
`What type of PR is it? => Improvement` should be `Feature`.
---
[GitHub] incubator-hivemall issue #116: [WIP][HIVEMALL-118] word2vec
Posted by coveralls <gi...@git.apache.org>.
Github user coveralls commented on the issue:
https://github.com/apache/incubator-hivemall/pull/116
[![Coverage Status](https://coveralls.io/builds/13450431/badge)](https://coveralls.io/builds/13450431)
Coverage decreased (-0.9%) to 40.065% when pulling **d12ba32469a35653ee6a499157ad236bfcaa21cf on nzw0301:skipgram** into **c2b95783cf9d6fc1646a48ac928e96152eab98c6 on apache:master**.
---
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Posted by myui <gi...@git.apache.org>.
Github user myui commented on a diff in the pull request:
https://github.com/apache/incubator-hivemall/pull/116#discussion_r141545448
--- Diff: core/src/main/java/hivemall/embedding/Word2VecUDTF.java ---
@@ -0,0 +1,364 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.embedding;
+
+import hivemall.UDTFWithOptions;
+import hivemall.utils.collections.IMapIterator;
+import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
+import hivemall.utils.collections.maps.OpenHashTable;
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.lang.Primitives;
+
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.Options;
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
+import org.apache.hadoop.io.FloatWritable;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+import java.util.List;
+import java.util.Arrays;
+import java.util.ArrayList;
+
+@Description(
+ name = "train_word2vec",
+ value = "_FUNC_(array<array<float | string>> negative_table, array<int | string> doc [, const string options]) - Returns a prediction model")
+public class Word2VecUDTF extends UDTFWithOptions {
+ protected transient AbstractWord2VecModel model;
+ @Nonnegative
+ private float startingLR;
+ @Nonnegative
+ private long numTrainWords;
+ private OpenHashTable<String, Integer> word2index;
+
+ @Nonnegative
+ private int dim;
+ @Nonnegative
+ private int win;
+ @Nonnegative
+ private int neg;
+ @Nonnegative
+ private int iter;
+ private boolean skipgram;
+ private boolean isStringInput;
+
+ private Int2FloatOpenHashTable S;
+ private int[] aliasWordIds;
+
+ private ListObjectInspector negativeTableOI;
+ private ListObjectInspector negativeTableElementListOI;
+ private PrimitiveObjectInspector negativeTableElementOI;
+
+ private ListObjectInspector docOI;
+ private PrimitiveObjectInspector wordOI;
+
+ @Override
+ public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
+ final int numArgs = argOIs.length;
+
+ if (numArgs != 3) {
+ throw new UDFArgumentException(getClass().getSimpleName()
+ + " takes 3 arguments: [, constant string options]: "
+ + Arrays.toString(argOIs));
+ }
+
+ processOptions(argOIs);
+
+ this.negativeTableOI = HiveUtils.asListOI(argOIs[0]);
+ this.negativeTableElementListOI = HiveUtils.asListOI(negativeTableOI.getListElementObjectInspector());
+ this.docOI = HiveUtils.asListOI(argOIs[1]);
+
+ this.isStringInput = HiveUtils.isStringListOI(argOIs[1]);
+
+ if (isStringInput) {
+ this.negativeTableElementOI = HiveUtils.asStringOI(negativeTableElementListOI.getListElementObjectInspector());
+ this.wordOI = HiveUtils.asStringOI(docOI.getListElementObjectInspector());
+ } else {
+ this.negativeTableElementOI = HiveUtils.asFloatingPointOI(negativeTableElementListOI.getListElementObjectInspector());
+ this.wordOI = HiveUtils.asIntCompatibleOI(docOI.getListElementObjectInspector());
+ }
+
+ List<String> fieldNames = new ArrayList<>();
+ List<ObjectInspector> fieldOIs = new ArrayList<>();
+
+ fieldNames.add("word");
+ fieldNames.add("i");
+ fieldNames.add("wi");
+
+ if (isStringInput) {
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
+ } else {
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+ }
+
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
+
+ this.model = null;
+ this.word2index = null;
+ this.S = null;
+ this.aliasWordIds = null;
+
+ return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
+ }
+
+ @Override
+ public void process(Object[] args) throws HiveException {
+ if (model == null) {
+ parseNegativeTable(args[0]);
+ this.model = createModel();
+ }
+
+ List<?> rawDoc = docOI.getList(args[1]);
+
+ // parse rawDoc
+ final int docLength = rawDoc.size();
+ final int[] doc = new int[docLength];
+ if (isStringInput) {
+ for (int i = 0; i < docLength; i++) {
+ doc[i] = getWordId(PrimitiveObjectInspectorUtils.getString(rawDoc.get(i), wordOI));
+ }
+ } else {
+ for (int i = 0; i < docLength; i++) {
+ doc[i] = PrimitiveObjectInspectorUtils.getInt(rawDoc.get(i), wordOI);
+ }
+ }
+
+ model.trainOnDoc(doc);
+ }
+
+ @Override
+ protected Options getOptions() {
+ Options opts = new Options();
+ opts.addOption("n", "numTrainWords", true,
+ "The number of words in the documents. It is used to update learning rate");
+ opts.addOption("dim", "dimension", true, "The number of vector dimension [default: 100]");
+ opts.addOption("win", "window", true, "Context window size [default: 5]");
+ opts.addOption("neg", "negative", true,
+ "The number of negative sampled words per word [default: 5]");
+ opts.addOption("iter", "iteration", true, "The number of iterations [default: 5]");
+ opts.addOption("model", "modelName", true,
+ "The model name of word2vec: skipgram or cbow [default: skipgram]");
+ opts.addOption(
+ "lr",
+ "learningRate",
+ true,
+ "Initial learning rate of SGD. The default value depends on model [default: 0.025 (skipgram), 0.05 (cbow)]");
+
+ return opts;
+ }
+
+ @Override
+ protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
+ CommandLine cl = null;
+ int win = 5;
+ int neg = 5;
+ int iter = 5;
+ int dim = 100;
+ long numTrainWords = 0L;
+ String modelName = "skipgram";
+ float lr = 0.025f;
+
+ if (argOIs.length >= 3) {
+ String rawArgs = HiveUtils.getConstString(argOIs[2]);
+ cl = parseOptions(rawArgs);
+
+ numTrainWords = Primitives.parseLong(cl.getOptionValue("n"), numTrainWords);
+ if (numTrainWords <= 0) {
+ throw new UDFArgumentException("Argument `int numTrainWords` must be positive: "
+ + numTrainWords);
+ }
+
+ dim = Primitives.parseInt(cl.getOptionValue("dim"), dim);
+ if (dim <= 0.d) {
+ throw new UDFArgumentException("Argument `int dim` must be positive: " + dim);
+ }
+
+ win = Primitives.parseInt(cl.getOptionValue("win"), win);
+ if (win <= 0) {
+ throw new UDFArgumentException("Argument `int win` must be positive: " + win);
+ }
+
+ neg = Primitives.parseInt(cl.getOptionValue("neg"), neg);
+ if (neg < 0) {
+ throw new UDFArgumentException("Argument `int neg` must be non-negative: " + neg);
+ }
+
+ iter = Primitives.parseInt(cl.getOptionValue("iter"), iter);
+ if (iter <= 0) {
+ throw new UDFArgumentException("Argument `int iter` must be non-negative: " + iter);
+ }
+
+ modelName = cl.getOptionValue("model", modelName);
+ if (!(modelName.equals("skipgram") || modelName.equals("cbow"))) {
+ throw new UDFArgumentException("Argument `string model` must be skipgram or cbow: "
+ + modelName);
+ }
+
+ if (modelName.equals("cbow")) {
+ lr = 0.05f;
+ }
+
+ lr = Primitives.parseFloat(cl.getOptionValue("lr"), lr);
+ if (lr <= 0.f) {
+ throw new UDFArgumentException("Argument `float lr` must be positive: " + lr);
+ }
+ }
+
+ this.numTrainWords = numTrainWords;
+ this.win = win;
+ this.neg = neg;
+ this.iter = iter;
+ this.dim = dim;
+ this.skipgram = modelName.equals("skipgram");
+ this.startingLR = lr;
+ return cl;
+ }
+
+ public void close() throws HiveException {
+ if (model != null) {
+ forwardModel();
+ this.model = null;
+ this.word2index = null;
+ this.S = null;
+ }
+ }
+
+ private void forwardModel() throws HiveException {
+ if (isStringInput) {
+ final Text word = new Text();
+ final IntWritable dimIndex = new IntWritable();
+ final FloatWritable value = new FloatWritable();
+
+ final Object[] result = new Object[3];
+ result[0] = word;
+ result[1] = dimIndex;
+ result[2] = value;
+
+ IMapIterator<String, Integer> iter = word2index.entries();
+ while (iter.next() != -1) {
+ int wordId = iter.getValue();
+ if (!model.inputWeights.containsKey(wordId * dim)){
+ continue;
+ }
+
+ word.set(iter.getKey());
+
+ for (int i = 0; i < dim; i++) {
+ dimIndex.set(i);
+ value.set(model.inputWeights.get(wordId * dim + i));
+ forward(result);
+ }
+ }
+ } else {
+ final IntWritable word = new IntWritable();
+ final IntWritable dimIndex = new IntWritable();
+ final FloatWritable value = new FloatWritable();
+
+ final Object[] result = new Object[3];
+ result[0] = word;
+ result[1] = dimIndex;
+ result[2] = value;
+
+ for (int wordId = 0; wordId < aliasWordIds.length; wordId++) {
+ if (!model.inputWeights.containsKey(wordId * dim)){
+ break;
+ }
+ word.set(wordId);
+ for (int i = 0; i < dim; i++) {
+ dimIndex.set(i);
+ value.set(model.inputWeights.get(wordId * dim + i));
+ forward(result);
+ }
+ }
+ }
+ }
+
+ private int getWordId(@Nonnull final String word) {
+ if (word2index.containsKey(word)) {
+ return word2index.get(word);
+ } else {
+ int w = word2index.size();
+ word2index.put(word, w);
+ return w;
+ }
+ }
+
+ private void parseNegativeTable(@Nonnull Object listObj) {
+ int aliasSize = negativeTableOI.getListLength(listObj);
+ Int2FloatOpenHashTable S = new Int2FloatOpenHashTable(aliasSize);
+ int[] aliasWordIds = new int[aliasSize];
+
--- End diff --
Add `final` to the 3 local variables above (^); they are not changed in the loop.
---
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Posted by myui <gi...@git.apache.org>.
Github user myui commented on a diff in the pull request:
https://github.com/apache/incubator-hivemall/pull/116#discussion_r141556886
--- Diff: core/src/main/java/hivemall/embedding/SkipGramModel.java ---
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.embedding;
+
+import hivemall.math.random.PRNG;
+import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
+
+import javax.annotation.Nonnull;
+import java.util.List;
+
+public final class SkipGramModel extends AbstractWord2VecModel {
+ protected SkipGramModel(final int dim, final int win, final int neg, final int iter,
--- End diff --
Lots of hyperparameters in the constructor.
Consider using a hyperparameter class, as seen in https://github.com/apache/incubator-hivemall/blob/master/core/src/main/java/hivemall/fm/FMHyperParameters.java
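A minimal sketch of what such a parameter object might look like. `Word2VecHyperParameters` and `SkipGramModelSketch` are hypothetical names used only for illustration, not classes from Hivemall or this PR; the defaults and range checks mirror those in `Word2VecUDTF#processOptions` above.

```java
// Hypothetical parameter object (illustrative names, not taken from Hivemall).
final class Word2VecHyperParameters {
    int dim = 100;
    int win = 5;
    int neg = 5;
    int iter = 5;
    float startingLR = 0.025f;
    long numTrainWords = 0L;

    /** Rejects the same values that the option parsing in this PR rejects. */
    void validate() {
        if (dim <= 0 || win <= 0 || iter <= 0) {
            throw new IllegalArgumentException("dim, win and iter must be positive");
        }
        if (neg < 0) {
            throw new IllegalArgumentException("neg must be non-negative: " + neg);
        }
        if (startingLR <= 0.f) {
            throw new IllegalArgumentException("startingLR must be positive: " + startingLR);
        }
        if (numTrainWords <= 0L) {
            throw new IllegalArgumentException("numTrainWords must be positive: " + numTrainWords);
        }
    }
}

// Hypothetical model stub: one argument instead of six scalar hyperparameters.
final class SkipGramModelSketch {
    final Word2VecHyperParameters params;

    SkipGramModelSketch(final Word2VecHyperParameters params) {
        params.validate();
        this.params = params;
    }
}
```

With this shape, adding a new option only touches the parameter class and the option parser, not every model constructor.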
---
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Posted by myui <gi...@git.apache.org>.
Github user myui commented on a diff in the pull request:
https://github.com/apache/incubator-hivemall/pull/116#discussion_r141542968
--- Diff: core/src/main/java/hivemall/embedding/AbstractWord2VecModel.java ---
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.embedding;
+
+import hivemall.math.random.PRNG;
+import hivemall.math.random.RandomNumberGeneratorFactory;
+import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+import java.util.List;
+
+public abstract class AbstractWord2VecModel {
+ // cached sigmoid function parameters
+ protected static final int MAX_SIGMOID = 6;
+ protected static final int SIGMOID_TABLE_SIZE = 1000;
+ protected float[] sigmoidTable;
+
+
+ @Nonnegative
+ protected int dim;
+ protected int win;
--- End diff --
Add `@Nonnegative` to each variable (win, neg, iter).
---
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Posted by myui <gi...@git.apache.org>.
Github user myui commented on a diff in the pull request:
https://github.com/apache/incubator-hivemall/pull/116#discussion_r141556621
--- Diff: core/src/main/java/hivemall/embedding/Word2VecUDTF.java ---
@@ -0,0 +1,364 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.embedding;
+
+import hivemall.UDTFWithOptions;
+import hivemall.utils.collections.IMapIterator;
+import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
+import hivemall.utils.collections.maps.OpenHashTable;
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.lang.Primitives;
+
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.Options;
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
+import org.apache.hadoop.io.FloatWritable;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+import java.util.List;
+import java.util.Arrays;
+import java.util.ArrayList;
+
+@Description(
+ name = "train_word2vec",
+ value = "_FUNC_(array<array<float | string>> negative_table, array<int | string> doc [, const string options]) - Returns a prediction model")
+public class Word2VecUDTF extends UDTFWithOptions {
+ protected transient AbstractWord2VecModel model;
+ @Nonnegative
+ private float startingLR;
+ @Nonnegative
+ private long numTrainWords;
+ private OpenHashTable<String, Integer> word2index;
+
+ @Nonnegative
+ private int dim;
+ @Nonnegative
+ private int win;
+ @Nonnegative
+ private int neg;
+ @Nonnegative
+ private int iter;
+ private boolean skipgram;
+ private boolean isStringInput;
+
+ private Int2FloatOpenHashTable S;
+ private int[] aliasWordIds;
+
+ private ListObjectInspector negativeTableOI;
+ private ListObjectInspector negativeTableElementListOI;
+ private PrimitiveObjectInspector negativeTableElementOI;
+
+ private ListObjectInspector docOI;
+ private PrimitiveObjectInspector wordOI;
+
+ @Override
+ public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
+ final int numArgs = argOIs.length;
+
+ if (numArgs != 3) {
+ throw new UDFArgumentException(getClass().getSimpleName()
+ + " takes 3 arguments: [, constant string options]: "
+ + Arrays.toString(argOIs));
+ }
+
+ processOptions(argOIs);
+
+ this.negativeTableOI = HiveUtils.asListOI(argOIs[0]);
+ this.negativeTableElementListOI = HiveUtils.asListOI(negativeTableOI.getListElementObjectInspector());
+ this.docOI = HiveUtils.asListOI(argOIs[1]);
+
+ this.isStringInput = HiveUtils.isStringListOI(argOIs[1]);
+
+ if (isStringInput) {
+ this.negativeTableElementOI = HiveUtils.asStringOI(negativeTableElementListOI.getListElementObjectInspector());
+ this.wordOI = HiveUtils.asStringOI(docOI.getListElementObjectInspector());
+ } else {
+ this.negativeTableElementOI = HiveUtils.asFloatingPointOI(negativeTableElementListOI.getListElementObjectInspector());
+ this.wordOI = HiveUtils.asIntCompatibleOI(docOI.getListElementObjectInspector());
+ }
+
+ List<String> fieldNames = new ArrayList<>();
+ List<ObjectInspector> fieldOIs = new ArrayList<>();
+
+ fieldNames.add("word");
+ fieldNames.add("i");
+ fieldNames.add("wi");
+
+ if (isStringInput) {
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
+ } else {
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+ }
+
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
+
+ this.model = null;
+ this.word2index = null;
+ this.S = null;
+ this.aliasWordIds = null;
+
+ return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
+ }
+
+ @Override
+ public void process(Object[] args) throws HiveException {
+ if (model == null) {
+ parseNegativeTable(args[0]);
+ this.model = createModel();
+ }
+
+ List<?> rawDoc = docOI.getList(args[1]);
+
+ // parse rawDoc
+ final int docLength = rawDoc.size();
+ final int[] doc = new int[docLength];
+ if (isStringInput) {
+ for (int i = 0; i < docLength; i++) {
+ doc[i] = getWordId(PrimitiveObjectInspectorUtils.getString(rawDoc.get(i), wordOI));
+ }
+ } else {
+ for (int i = 0; i < docLength; i++) {
+ doc[i] = PrimitiveObjectInspectorUtils.getInt(rawDoc.get(i), wordOI);
+ }
+ }
+
+ model.trainOnDoc(doc);
+ }
+
+ @Override
+ protected Options getOptions() {
+ Options opts = new Options();
+ opts.addOption("n", "numTrainWords", true,
+ "The number of words in the documents. It is used to update learning rate");
+ opts.addOption("dim", "dimension", true, "The number of vector dimension [default: 100]");
+ opts.addOption("win", "window", true, "Context window size [default: 5]");
+ opts.addOption("neg", "negative", true,
+ "The number of negative sampled words per word [default: 5]");
+ opts.addOption("iter", "iteration", true, "The number of iterations [default: 5]");
+ opts.addOption("model", "modelName", true,
+ "The model name of word2vec: skipgram or cbow [default: skipgram]");
+ opts.addOption(
+ "lr",
--- End diff --
Keep `learningRate` as the longOpt and use `eta0` (instead of `lr`) for the initial learning rate.
---
[GitHub] incubator-hivemall issue #116: [WIP][HIVEMALL-118] word2vec
Posted by nzw0301 <gi...@git.apache.org>.
Github user nzw0301 commented on the issue:
https://github.com/apache/incubator-hivemall/pull/116
@myui I resolved conflicts.
---
[GitHub] incubator-hivemall issue #116: [WIP][HIVEMALL-118] word2vec
Posted by coveralls <gi...@git.apache.org>.
Github user coveralls commented on the issue:
https://github.com/apache/incubator-hivemall/pull/116
[![Coverage Status](https://coveralls.io/builds/13449529/badge)](https://coveralls.io/builds/13449529)
Coverage decreased (-0.9%) to 40.061% when pulling **af5b5becb55e58fd9db355be218035518973844a on nzw0301:skipgram** into **c2b95783cf9d6fc1646a48ac928e96152eab98c6 on apache:master**.
---
[GitHub] incubator-hivemall issue #116: [WIP][HIVEMALL-118] word2vec
Posted by coveralls <gi...@git.apache.org>.
Github user coveralls commented on the issue:
https://github.com/apache/incubator-hivemall/pull/116
[![Coverage Status](https://coveralls.io/builds/13387498/badge)](https://coveralls.io/builds/13387498)
Coverage decreased (-0.8%) to 40.077% when pulling **bbdb561cc1bf3194128034fe194555bb6c167144 on nzw0301:skipgram** into **c2b95783cf9d6fc1646a48ac928e96152eab98c6 on apache:master**.
---
[GitHub] incubator-hivemall issue #116: [WIP][HIVEMALL-118] word2vec
Posted by coveralls <gi...@git.apache.org>.
Github user coveralls commented on the issue:
https://github.com/apache/incubator-hivemall/pull/116
[![Coverage Status](https://coveralls.io/builds/13429656/badge)](https://coveralls.io/builds/13429656)
Coverage decreased (-0.6%) to 40.372% when pulling **d1b4270861c277109f35c8675e8297ef081f1dee on nzw0301:skipgram** into **c2b95783cf9d6fc1646a48ac928e96152eab98c6 on apache:master**.
---
[GitHub] incubator-hivemall issue #116: [WIP][HIVEMALL-118] word2vec
Posted by coveralls <gi...@git.apache.org>.
Github user coveralls commented on the issue:
https://github.com/apache/incubator-hivemall/pull/116
[![Coverage Status](https://coveralls.io/builds/13454878/badge)](https://coveralls.io/builds/13454878)
Coverage decreased (-0.5%) to 40.383% when pulling **da564b8cea0bd028d3f822ed750513cd28ff45c7 on nzw0301:skipgram** into **c2b95783cf9d6fc1646a48ac928e96152eab98c6 on apache:master**.
---
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Posted by myui <gi...@git.apache.org>.
Github user myui commented on a diff in the pull request:
https://github.com/apache/incubator-hivemall/pull/116#discussion_r141556391
--- Diff: core/src/main/java/hivemall/embedding/AbstractWord2VecModel.java ---
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.embedding;
+
+import hivemall.math.random.PRNG;
+import hivemall.math.random.RandomNumberGeneratorFactory;
+import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+import java.util.List;
+
+public abstract class AbstractWord2VecModel {
+ // cached sigmoid function parameters
+ protected static final int MAX_SIGMOID = 6;
+ protected static final int SIGMOID_TABLE_SIZE = 1000;
+ protected float[] sigmoidTable;
+
+
+ @Nonnegative
+ protected int dim;
+ protected int win;
+ protected int neg;
+ protected int iter;
+
+ // learning rate parameters
+ @Nonnegative
+ protected float lr;
+ @Nonnegative
+ private float startingLR;
+ @Nonnegative
+ private long numTrainWords;
+ @Nonnegative
+ protected long wordCount;
+ @Nonnegative
+ private long lastWordCount;
+
+ protected PRNG rnd;
+
+ protected Int2FloatOpenHashTable contextWeights;
+ protected Int2FloatOpenHashTable inputWeights;
+ protected Int2FloatOpenHashTable S;
+ protected int[] aliasWordId;
+
+ protected AbstractWord2VecModel(final int dim, final int win, final int neg, final int iter,
+ final float startingLR, final long numTrainWords, final Int2FloatOpenHashTable S,
+ final int[] aliasWordId) {
+ this.win = win;
+ this.neg = neg;
+ this.iter = iter;
+ this.dim = dim;
+ this.startingLR = this.lr = startingLR;
+ this.numTrainWords = numTrainWords;
+
+ // alias sampler for negative sampling
+ this.S = S;
+ this.aliasWordId = aliasWordId;
+
+ this.wordCount = 0L;
+ this.lastWordCount = 0L;
+ this.rnd = RandomNumberGeneratorFactory.createPRNG(1001);
+
+ this.sigmoidTable = initSigmoidTable();
+
+ // TODO how to estimate size
+ this.inputWeights = new Int2FloatOpenHashTable(10578 * dim);
--- End diff --
A power of two (2^n) or `1024 * 10` would be more understandable than `10578`.
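One way this could look, as a sketch only: `TableSizing`, `DEFAULT_VOCAB_SIZE` and both helper methods are hypothetical names, and the vocabulary estimate of `1024 * 10` is just the value suggested above, not a measured one.

```java
// Hypothetical sizing helper (names are illustrative, not from Hivemall).
final class TableSizing {
    // Assumed default vocabulary estimate, written so the intent is visible.
    static final int DEFAULT_VOCAB_SIZE = 1024 * 10;

    /** Initial number of slots for a table holding dim weights per word. */
    static int initialCapacity(final int dim) {
        // multiplyExact throws ArithmeticException instead of silently overflowing
        return Math.multiplyExact(DEFAULT_VOCAB_SIZE, dim);
    }

    /** Smallest power of two that is >= n, for 1 <= n <= 2^30. */
    static int nextPowerOfTwo(final int n) {
        if (n <= 1) {
            return 1;
        }
        return Integer.highestOneBit(n - 1) << 1;
    }
}
```

For example, `TableSizing.nextPowerOfTwo(10578)` rounds the current magic number up to 16384.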
---
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Posted by myui <gi...@git.apache.org>.
Github user myui commented on a diff in the pull request:
https://github.com/apache/incubator-hivemall/pull/116#discussion_r140725317
--- Diff: core/src/main/java/hivemall/embedding/AbstractWord2VecModel.java ---
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.embedding;
+
+import hivemall.math.random.PRNG;
+import hivemall.math.random.RandomNumberGeneratorFactory;
+import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+public abstract class AbstractWord2VecModel {
+ // cached sigmoid function parameters
+ protected final int MAX_SIGMOID = 6;
+ protected final int SIGMOID_TABLE_SIZE = 1000;
+ protected float[] sigmoidTable;
+
+ // learning rate parameters
+ @Nonnegative
+ protected float lr;
+ @Nonnegative
+ private float startingLR;
+ @Nonnegative
+ private long numTrainWords;
+ @Nonnegative
+ protected long wordCount;
+ @Nonnegative
+ private long lastWordCount;
+ @Nonnegative
+ private long wordCountActual;
+
+ @Nonnegative
+ protected int dim;
+ private PRNG _rnd;
+
+ protected Int2FloatOpenHashTable contextWeights;
+ protected Int2FloatOpenHashTable inputWeights;
+
+ protected AbstractWord2VecModel(final int dim, final float startingLR, final long numTrainWords) {
+ this.dim = dim;
+ this.startingLR = this.lr = startingLR;
+ this.numTrainWords = numTrainWords;
+
+ this.wordCount = 0L;
+ this.lastWordCount = 0L;
+ this.wordCountActual = 0L;
+ this._rnd = RandomNumberGeneratorFactory.createPRNG(1001);
+
+ this.sigmoidTable = initSigmoidTable(MAX_SIGMOID, SIGMOID_TABLE_SIZE);
+
+ // TODO how to estimate size
+ this.inputWeights = new Int2FloatOpenHashTable(10578 * dim);
+ this.inputWeights.defaultReturnValue(-0.f);
+ this.contextWeights = new Int2FloatOpenHashTable(10578 * dim);
+ this.contextWeights.defaultReturnValue(0.f);
+ }
+
+ private static float[] initSigmoidTable(final int maxSigmoid, final int sigmoidTableSize) {
+ float[] sigmoidTable = new float[sigmoidTableSize];
+ for (int i = 0; i < sigmoidTableSize; i++) {
+ float x = ((float) i / sigmoidTableSize * 2 - 1) * (float) maxSigmoid;
+ sigmoidTable[i] = 1.f / ((float) Math.exp(-x) + 1.f);
+ }
+ return sigmoidTable;
+ }
+
+ protected void initWordWeights(final int wordId) {
+ for (int i = 0; i < dim; i++) {
+ inputWeights.put(wordId * dim + i, ((float) _rnd.nextDouble() - 0.5f) / dim);
+ }
+ }
+
+ protected static float sigmoid(final float v, final int MAX_SIGMOID,
--- End diff --
No need to pass constants as arguments: `final int MAX_SIGMOID, final int SIGMOID_TABLE_SIZE`
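A sketch of the lookup with the constants referenced directly rather than passed in. The body of `sigmoid` is not visible in the diff context above, so the index computation and the clamping here are assumptions about a typical cached-sigmoid lookup, and `SigmoidTableSketch` is a hypothetical class name.

```java
// Hypothetical stand-alone version of the cached sigmoid (illustrative only).
final class SigmoidTableSketch {
    static final int MAX_SIGMOID = 6;
    static final int SIGMOID_TABLE_SIZE = 1000;
    private static final float[] TABLE = initSigmoidTable();

    // Same table construction as in the diff, minus the constant arguments.
    private static float[] initSigmoidTable() {
        final float[] table = new float[SIGMOID_TABLE_SIZE];
        for (int i = 0; i < SIGMOID_TABLE_SIZE; i++) {
            final float x = ((float) i / SIGMOID_TABLE_SIZE * 2 - 1) * MAX_SIGMOID;
            table[i] = 1.f / ((float) Math.exp(-x) + 1.f);
        }
        return table;
    }

    static float sigmoid(final float v) {
        if (v >= MAX_SIGMOID) {
            return 1.f;
        }
        if (v <= -MAX_SIGMOID) {
            return 0.f;
        }
        // Map (-MAX_SIGMOID, MAX_SIGMOID) onto table slots (assumed mapping).
        final int index = (int) ((v + MAX_SIGMOID) * (SIGMOID_TABLE_SIZE / (2.f * MAX_SIGMOID)));
        // Clamp in case float rounding pushes the index to SIGMOID_TABLE_SIZE.
        return TABLE[Math.min(index, SIGMOID_TABLE_SIZE - 1)];
    }
}
```

Callers then write `sigmoid(v)` and the table bounds stay in one place.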
---
[GitHub] incubator-hivemall issue #116: [WIP][HIVEMALL-118] word2vec
Posted by coveralls <gi...@git.apache.org>.
Github user coveralls commented on the issue:
https://github.com/apache/incubator-hivemall/pull/116
[![Coverage Status](https://coveralls.io/builds/13411817/badge)](https://coveralls.io/builds/13411817)
Coverage decreased (-0.9%) to 40.029% when pulling **e0945527c68da9e2c9ab6eb86a7eb7d66bf42aa7 on nzw0301:skipgram** into **c2b95783cf9d6fc1646a48ac928e96152eab98c6 on apache:master**.
---
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Posted by myui <gi...@git.apache.org>.
Github user myui commented on a diff in the pull request:
https://github.com/apache/incubator-hivemall/pull/116#discussion_r141543643
--- Diff: core/src/main/java/hivemall/embedding/CBoWModel.java ---
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.embedding;
+
+import hivemall.math.random.PRNG;
+import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
+
+import javax.annotation.Nonnull;
+import java.util.List;
+
+public final class CBoWModel extends AbstractWord2VecModel {
+ protected CBoWModel(final int dim, final int win, final int neg, final int iter,
+ final float startingLR, final long numTrainWords, final Int2FloatOpenHashTable S,
+ final int[] aliasWordId) {
+ super(dim, win, neg, iter, startingLR, numTrainWords, S, aliasWordId);
+ }
+
+ protected void trainOnDoc(@Nonnull final int[] doc) {
+ final int vecDim = dim;
+ final int numNegative = neg;
+ final PRNG _rnd = rnd;
+ final Int2FloatOpenHashTable _S = S;
+ final int[] _aliasWordId = aliasWordId;
+ float label, gradient;
+
+ // reuse instance
+ int windowSize, k, numContext, targetWord, inWord, positiveWord;
+
+ updateLearningRate();
+
+ int docLength = doc.length;
+ for (int t = 0; t < iter; t++) {
+ for (int positiveWordPosition = 0; positiveWordPosition < docLength; positiveWordPosition++) {
+ windowSize = _rnd.nextInt(win) + 1;
+
+ numContext = windowSize * 2 + Math.min(0, positiveWordPosition - windowSize)
+ + Math.min(0, docLength - positiveWordPosition - windowSize - 1);
+
+ float[] gradVec = new float[vecDim];
--- End diff --
Add `final` to `gradVec` and `averageVec`.
---
[GitHub] incubator-hivemall issue #116: [WIP][HIVEMALL-118] word2vec
Posted by coveralls <gi...@git.apache.org>.
Github user coveralls commented on the issue:
https://github.com/apache/incubator-hivemall/pull/116
[![Coverage Status](https://coveralls.io/builds/13472440/badge)](https://coveralls.io/builds/13472440)
Coverage decreased (-0.6%) to 40.505% when pulling **8696f5ff668adf758d3545bab5885e51ce7d053e on nzw0301:skipgram** into **1e42387576fabbb326d451f4a00ac22d57828711 on apache:master**.
---
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Posted by myui <gi...@git.apache.org>.
Github user myui commented on a diff in the pull request:
https://github.com/apache/incubator-hivemall/pull/116#discussion_r141543514
--- Diff: core/src/main/java/hivemall/embedding/CBoWModel.java ---
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.embedding;
+
+import hivemall.math.random.PRNG;
+import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
+
+import javax.annotation.Nonnull;
+import java.util.List;
+
+public final class CBoWModel extends AbstractWord2VecModel {
+ protected CBoWModel(final int dim, final int win, final int neg, final int iter,
+ final float startingLR, final long numTrainWords, final Int2FloatOpenHashTable S,
+ final int[] aliasWordId) {
+ super(dim, win, neg, iter, startingLR, numTrainWords, S, aliasWordId);
+ }
+
+ protected void trainOnDoc(@Nonnull final int[] doc) {
+ final int vecDim = dim;
+ final int numNegative = neg;
+ final PRNG _rnd = rnd;
+ final Int2FloatOpenHashTable _S = S;
--- End diff --
Member variables should be named `_S` and local variables `S`.
The same applies to `_rnd` and `_aliasWordId`.
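To illustrate the naming convention only: in the sketch below the fields carry the underscore and the cached locals do not. `NamingSketch` is a hypothetical class, `java.util.Random` stands in for Hivemall's `PRNG`, and the uniform draw is a placeholder, not the alias sampling used in this PR.

```java
import java.util.Random;

// Hypothetical class showing underscore-prefixed members and plain locals.
final class NamingSketch {
    private final Random _rnd;          // stand-in for hivemall.math.random.PRNG
    private final int[] _aliasWordId;

    NamingSketch(final int[] aliasWordId, final long seed) {
        this._aliasWordId = aliasWordId;
        this._rnd = new Random(seed);
    }

    int[] sample(final int n) {
        // Cache members in final locals; the locals get the unprefixed names.
        final Random rnd = _rnd;
        final int[] aliasWordId = _aliasWordId;

        final int[] out = new int[n];
        for (int i = 0; i < n; i++) {
            // Placeholder uniform draw, not the alias method.
            out[i] = aliasWordId[rnd.nextInt(aliasWordId.length)];
        }
        return out;
    }
}
```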
---
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Posted by myui <gi...@git.apache.org>.
Github user myui commented on a diff in the pull request:
https://github.com/apache/incubator-hivemall/pull/116#discussion_r141546656
--- Diff: core/src/main/java/hivemall/embedding/AbstractWord2VecModel.java ---
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.embedding;
+
+import hivemall.math.random.PRNG;
+import hivemall.math.random.RandomNumberGeneratorFactory;
+import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+import java.util.List;
+
+public abstract class AbstractWord2VecModel {
+ // cached sigmoid function parameters
+ protected static final int MAX_SIGMOID = 6;
+ protected static final int SIGMOID_TABLE_SIZE = 1000;
+ protected float[] sigmoidTable;
+
+
+ @Nonnegative
+ protected int dim;
+ protected int win;
+ protected int neg;
+ protected int iter;
+
+ // learning rate parameters
+ @Nonnegative
+ protected float lr;
+ @Nonnegative
+ private float startingLR;
+ @Nonnegative
+ private long numTrainWords;
+ @Nonnegative
+ protected long wordCount;
+ @Nonnegative
+ private long lastWordCount;
+
+ protected PRNG rnd;
+
+ protected Int2FloatOpenHashTable contextWeights;
+ protected Int2FloatOpenHashTable inputWeights;
+ protected Int2FloatOpenHashTable S;
+ protected int[] aliasWordId;
+
+ protected AbstractWord2VecModel(final int dim, final int win, final int neg, final int iter,
+ final float startingLR, final long numTrainWords, final Int2FloatOpenHashTable S,
+ final int[] aliasWordId) {
+ this.win = win;
+ this.neg = neg;
+ this.iter = iter;
+ this.dim = dim;
+ this.startingLR = this.lr = startingLR;
+ this.numTrainWords = numTrainWords;
+
+ // alias sampler for negative sampling
+ this.S = S;
+ this.aliasWordId = aliasWordId;
+
+ this.wordCount = 0L;
+ this.lastWordCount = 0L;
+ this.rnd = RandomNumberGeneratorFactory.createPRNG(1001);
+
+ this.sigmoidTable = initSigmoidTable();
+
+ // TODO how to estimate size
+ this.inputWeights = new Int2FloatOpenHashTable(10578 * dim);
+ this.inputWeights.defaultReturnValue(0.f);
+ this.contextWeights = new Int2FloatOpenHashTable(10578 * dim);
+ this.contextWeights.defaultReturnValue(0.f);
+ }
+
+ private static float[] initSigmoidTable() {
--- End diff --
Add `@Nonnull` to the return value.
---
[GitHub] incubator-hivemall issue #116: [WIP][HIVEMALL-118] word2vec
Posted by coveralls <gi...@git.apache.org>.
Github user coveralls commented on the issue:
https://github.com/apache/incubator-hivemall/pull/116
[![Coverage Status](https://coveralls.io/builds/13415230/badge)](https://coveralls.io/builds/13415230)
Coverage decreased (-0.9%) to 40.028% when pulling **4abdb8f3d2632baa9b3ead928bc8bb0283027e20 on nzw0301:skipgram** into **c2b95783cf9d6fc1646a48ac928e96152eab98c6 on apache:master**.
---
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Posted by myui <gi...@git.apache.org>.
Github user myui commented on a diff in the pull request:
https://github.com/apache/incubator-hivemall/pull/116#discussion_r141547506
--- Diff: core/src/main/java/hivemall/embedding/Word2VecUDTF.java ---
@@ -0,0 +1,364 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.embedding;
+
+import hivemall.UDTFWithOptions;
+import hivemall.utils.collections.IMapIterator;
+import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
+import hivemall.utils.collections.maps.OpenHashTable;
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.lang.Primitives;
+
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.Options;
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
+import org.apache.hadoop.io.FloatWritable;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+import java.util.List;
+import java.util.Arrays;
+import java.util.ArrayList;
+
+@Description(
+ name = "train_word2vec",
+ value = "_FUNC_(array<array<float | string>> negative_table, array<int | string> doc [, const string options]) - Returns a prediction model")
+public class Word2VecUDTF extends UDTFWithOptions {
+ protected transient AbstractWord2VecModel model;
+ @Nonnegative
+ private float startingLR;
+ @Nonnegative
+ private long numTrainWords;
+ private OpenHashTable<String, Integer> word2index;
+
+ @Nonnegative
+ private int dim;
+ @Nonnegative
+ private int win;
+ @Nonnegative
+ private int neg;
+ @Nonnegative
+ private int iter;
+ private boolean skipgram;
+ private boolean isStringInput;
+
+ private Int2FloatOpenHashTable S;
+ private int[] aliasWordIds;
+
+ private ListObjectInspector negativeTableOI;
+ private ListObjectInspector negativeTableElementListOI;
+ private PrimitiveObjectInspector negativeTableElementOI;
+
+ private ListObjectInspector docOI;
+ private PrimitiveObjectInspector wordOI;
+
+ @Override
+ public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
+ final int numArgs = argOIs.length;
+
+ if (numArgs != 3) {
+ throw new UDFArgumentException(getClass().getSimpleName()
+ + " takes 3 arguments: [, constant string options]: "
+ + Arrays.toString(argOIs));
+ }
+
+ processOptions(argOIs);
+
+ this.negativeTableOI = HiveUtils.asListOI(argOIs[0]);
+ this.negativeTableElementListOI = HiveUtils.asListOI(negativeTableOI.getListElementObjectInspector());
+ this.docOI = HiveUtils.asListOI(argOIs[1]);
+
+ this.isStringInput = HiveUtils.isStringListOI(argOIs[1]);
+
+ if (isStringInput) {
+ this.negativeTableElementOI = HiveUtils.asStringOI(negativeTableElementListOI.getListElementObjectInspector());
+ this.wordOI = HiveUtils.asStringOI(docOI.getListElementObjectInspector());
+ } else {
+ this.negativeTableElementOI = HiveUtils.asFloatingPointOI(negativeTableElementListOI.getListElementObjectInspector());
+ this.wordOI = HiveUtils.asIntCompatibleOI(docOI.getListElementObjectInspector());
+ }
+
+ List<String> fieldNames = new ArrayList<>();
+ List<ObjectInspector> fieldOIs = new ArrayList<>();
+
+ fieldNames.add("word");
+ fieldNames.add("i");
+ fieldNames.add("wi");
+
+ if (isStringInput) {
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
+ } else {
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+ }
+
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
+
+ this.model = null;
+ this.word2index = null;
+ this.S = null;
+ this.aliasWordIds = null;
+
+ return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
+ }
+
+ @Override
+ public void process(Object[] args) throws HiveException {
+ if (model == null) {
+ parseNegativeTable(args[0]);
+ this.model = createModel();
+ }
+
+ List<?> rawDoc = docOI.getList(args[1]);
+
+ // parse rawDoc
+ final int docLength = rawDoc.size();
+ final int[] doc = new int[docLength];
+ if (isStringInput) {
+ for (int i = 0; i < docLength; i++) {
+ doc[i] = getWordId(PrimitiveObjectInspectorUtils.getString(rawDoc.get(i), wordOI));
+ }
+ } else {
+ for (int i = 0; i < docLength; i++) {
+ doc[i] = PrimitiveObjectInspectorUtils.getInt(rawDoc.get(i), wordOI);
+ }
+ }
+
+ model.trainOnDoc(doc);
+ }
+
+ @Override
+ protected Options getOptions() {
+ Options opts = new Options();
+ opts.addOption("n", "numTrainWords", true,
+ "The number of words in the documents. It is used to update learning rate");
+ opts.addOption("dim", "dimension", true, "The number of vector dimension [default: 100]");
+ opts.addOption("win", "window", true, "Context window size [default: 5]");
+ opts.addOption("neg", "negative", true,
+ "The number of negative sampled words per word [default: 5]");
+ opts.addOption("iter", "iteration", true, "The number of iterations [default: 5]");
+ opts.addOption("model", "modelName", true,
+ "The model name of word2vec: skipgram or cbow [default: skipgram]");
+ opts.addOption(
+ "lr",
--- End diff --
Use consistent naming: `eta0`, `learningRate` for the initial learning rate.
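For context on what the initial learning rate feeds into: the option help text says `numTrainWords` "is used to update learning rate", but the update itself is not part of this diff. The sketch below assumes the linear decay used by the reference word2vec C implementation (decay with the fraction of words processed, floored at `eta0 * 1e-4`). The class and method names are hypothetical.

```java
final class LearningRateDecaySketch {
    /**
     * Linearly decays eta0 with training progress (assumed schedule, not taken from the diff).
     *
     * @param eta0 initial learning rate
     * @param wordsSeen number of words processed so far, over all iterations
     * @param numTrainWords number of words in the training documents
     * @param iters number of iterations over the data
     */
    static float decayedLearningRate(final float eta0, final long wordsSeen,
            final long numTrainWords, final int iters) {
        // Fraction of the total work done; the +1 avoids division by zero.
        final double progress = wordsSeen / ((double) iters * numTrainWords + 1.d);
        final float lr = (float) (eta0 * (1.d - progress));
        // Never let the rate fall below a small fraction of the initial value.
        final float floor = eta0 * 1e-4f;
        return Math.max(lr, floor);
    }
}
```

With this schedule, `eta0` is the only rate the user sets, which is one reason a single, consistently named option helps.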
---
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Posted by myui <gi...@git.apache.org>.
Github user myui commented on a diff in the pull request:
https://github.com/apache/incubator-hivemall/pull/116#discussion_r141545805
--- Diff: core/src/main/java/hivemall/embedding/AliasTableBuilderUDTF.java ---
@@ -0,0 +1,203 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.embedding;
+
+import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
+import hivemall.utils.collections.maps.Int2IntOpenHashTable;
+import hivemall.utils.hadoop.HiveUtils;
+
+import java.util.List;
+import java.util.ArrayList;
+import java.util.Map;
+import java.util.Queue;
+import java.util.ArrayDeque;
+
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDTF;
+import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
+import org.apache.hadoop.io.FloatWritable;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
+
+import javax.annotation.Nonnull;
+
+public final class AliasTableBuilderUDTF extends GenericUDTF {
+ private MapObjectInspector negativeTableOI;
+ private PrimitiveObjectInspector negativeTableKeyOI;
+ private PrimitiveObjectInspector negativeTableValueOI;
+
+ private int numVocab;
+ private List<String> index2word;
+ private Int2IntOpenHashTable A;
+ private Int2FloatOpenHashTable S;
+ private boolean isIntElement;
+
+ @Override
+ public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
+ if (!(argOIs.length >= 1)) {
+ throw new UDFArgumentException(
+ "_FUNC_(map<string, double>) takes at least one argument");
+ }
+
+ this.negativeTableOI = HiveUtils.asMapOI(argOIs[0]);
+ this.negativeTableValueOI = HiveUtils.asFloatingPointOI(negativeTableOI.getMapValueObjectInspector());
+
+ boolean isIntElementOI = HiveUtils.isIntOI(negativeTableOI.getMapKeyObjectInspector());
+
+ if (isIntElementOI) {
+ this.negativeTableKeyOI = HiveUtils.asIntCompatibleOI(negativeTableOI.getMapKeyObjectInspector());
+ } else {
+ this.negativeTableKeyOI = HiveUtils.asStringOI(negativeTableOI.getMapKeyObjectInspector());
+ }
+
+ List<String> fieldNames = new ArrayList<>();
+ List<ObjectInspector> fieldOIs = new ArrayList<>();
+ fieldNames.add("word");
+
+ if (isIntElementOI) {
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+ } else {
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
+ }
+
+ fieldNames.add("p");
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
+
+ fieldNames.add("other");
+ if (isIntElementOI) {
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+ } else {
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
+ }
+
+ this.isIntElement = isIntElementOI;
+ return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
+ }
+
+ @Override
+ public void process(Object[] args) throws HiveException {
+ if (!isIntElement) {
+ index2word = new ArrayList<>();
+ }
+
+ final List<Float> unnormalizedProb = new ArrayList<>();
+ int numVocab = 0;
+ float denom = 0.f;
+ for (Map.Entry<?, ?> entry : negativeTableOI.getMap(args[0]).entrySet()) {
+ if (!isIntElement) {
+ String word = PrimitiveObjectInspectorUtils.getString(entry.getKey(),
+ negativeTableKeyOI);
+ index2word.add(word);
+ }
+
+ float v = PrimitiveObjectInspectorUtils.getFloat(entry.getValue(), negativeTableValueOI);
+ unnormalizedProb.add(v);
+ denom += v;
+ numVocab++;
+ }
+
+ this.numVocab = numVocab;
+ createAliasTable(numVocab, denom, unnormalizedProb);
+ }
+
+ private void createAliasTable(final int V, final float denom,
+ final @Nonnull List<Float> unnormalizedProb) {
+ Int2FloatOpenHashTable S = new Int2FloatOpenHashTable(V);
+ Int2IntOpenHashTable A = new Int2IntOpenHashTable(V);
+
+ final Queue<Integer> higherBin = new ArrayDeque<>();
+ final Queue<Integer> lowerBin = new ArrayDeque<>();
+
+ for (int i = 0; i < V; i++) {
+ float v = V * unnormalizedProb.get(i) / denom;
+ S.put(i, v);
+ if (v > 1.f) {
+ higherBin.add(i);
+ } else {
+ lowerBin.add(i);
+ }
+ }
+
+ while (lowerBin.size() > 0 && higherBin.size() > 0) {
+ int low = lowerBin.remove();
+ int high = higherBin.remove();
+ A.put(low, high);
+ S.put(high, S.get(high) - 1.f + S.get(low));
+ if (S.get(high) < 1.f) {
+ lowerBin.add(high);
+ } else {
+ higherBin.add(high);
+ }
+ }
+ this.A = A;
+ this.S = S;
+ }
+
+ @Override
+ public void close() throws HiveException {
+ if (isIntElement) {
+ IntWritable word = new IntWritable();
+ FloatWritable pro = new FloatWritable();
+ IntWritable otherWord = new IntWritable();
+
+ Object[] res = new Object[3];
+ res[0] = word;
+ res[1] = pro;
+ res[2] = otherWord;
+
+ for (int i = 0; i < numVocab; i++) {
+ word.set(i);
+ pro.set(S.get(i));
+ if (A.get(i) == -1) {
+ otherWord.set(0);
+ } else {
+ otherWord.set(A.get(i));
+ }
+ forward(res);
+ }
+ } else {
+ Text word = new Text();
+ FloatWritable pro = new FloatWritable();
+ Text otherWord = new Text();
+
+ Object[] res = new Object[3];
--- End diff --
Add `final` to local variables that are not reassigned in the for loop.
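To make the role of `S` and `A` in `createAliasTable` easier to follow, here is a standalone sketch of the alias method using only `java.util`. It mirrors the queue-based construction in the diff, but the class name, the plain arrays, the use of `double`, and the `sample` method are illustrative assumptions rather than the PR's code. Locals that are never reassigned are marked `final`, in line with the comment above.

```java
import java.util.ArrayDeque;
import java.util.Queue;
import java.util.Random;

final class AliasTableSketch {
    final double[] prob; // acceptance threshold per bin (plays the role of S)
    final int[] alias;   // fallback index per bin (plays the role of A)

    AliasTableSketch(final float[] weights) {
        final int n = weights.length;
        double denom = 0.d;
        for (final float w : weights) {
            denom += w;
        }
        prob = new double[n];
        alias = new int[n];
        final Queue<Integer> higherBin = new ArrayDeque<>();
        final Queue<Integer> lowerBin = new ArrayDeque<>();
        for (int i = 0; i < n; i++) {
            prob[i] = n * weights[i] / denom; // scaled so that the mean is 1
            alias[i] = i;                     // default: a bin falls back to itself
            if (prob[i] > 1.d) {
                higherBin.add(i);
            } else {
                lowerBin.add(i);
            }
        }
        while (!lowerBin.isEmpty() && !higherBin.isEmpty()) {
            final int low = lowerBin.remove();
            final int high = higherBin.remove();
            alias[low] = high;                // the unused part of bin `low` goes to `high`
            prob[high] = prob[high] - 1.d + prob[low];
            if (prob[high] < 1.d) {
                lowerBin.add(high);
            } else {
                higherBin.add(high);
            }
        }
    }

    // Draws one index: pick a bin uniformly, then keep it or jump to its alias.
    int sample(final Random rnd) {
        final int bin = rnd.nextInt(prob.length);
        return rnd.nextDouble() < prob[bin] ? bin : alias[bin];
    }
}
```

Each draw costs one uniform integer and one uniform real regardless of vocabulary size, which is why the alias method suits negative sampling.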
---
[GitHub] incubator-hivemall issue #116: [WIP][HIVEMALL-118] word2vec
Posted by coveralls <gi...@git.apache.org>.
Github user coveralls commented on the issue:
https://github.com/apache/incubator-hivemall/pull/116
[![Coverage Status](https://coveralls.io/builds/13365085/badge)](https://coveralls.io/builds/13365085)
Coverage decreased (-0.7%) to 40.18% when pulling **7014e8552f81d59ab35c95d8fcf54c56c24ba2c9 on nzw0301:skipgram** into **c2b95783cf9d6fc1646a48ac928e96152eab98c6 on apache:master**.
---
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Posted by myui <gi...@git.apache.org>.
Github user myui commented on a diff in the pull request:
https://github.com/apache/incubator-hivemall/pull/116#discussion_r141547369
--- Diff: core/src/main/java/hivemall/embedding/Word2VecUDTF.java ---
@@ -0,0 +1,364 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.embedding;
+
+import hivemall.UDTFWithOptions;
+import hivemall.utils.collections.IMapIterator;
+import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
+import hivemall.utils.collections.maps.OpenHashTable;
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.lang.Primitives;
+
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.Options;
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
+import org.apache.hadoop.io.FloatWritable;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+import java.util.List;
+import java.util.Arrays;
+import java.util.ArrayList;
+
+@Description(
+ name = "train_word2vec",
+ value = "_FUNC_(array<array<float | string>> negative_table, array<int | string> doc [, const string options]) - Returns a prediction model")
+public class Word2VecUDTF extends UDTFWithOptions {
+ protected transient AbstractWord2VecModel model;
+ @Nonnegative
+ private float startingLR;
+ @Nonnegative
+ private long numTrainWords;
+ private OpenHashTable<String, Integer> word2index;
+
+ @Nonnegative
+ private int dim;
+ @Nonnegative
+ private int win;
+ @Nonnegative
+ private int neg;
+ @Nonnegative
+ private int iter;
+ private boolean skipgram;
+ private boolean isStringInput;
+
+ private Int2FloatOpenHashTable S;
+ private int[] aliasWordIds;
+
+ private ListObjectInspector negativeTableOI;
+ private ListObjectInspector negativeTableElementListOI;
+ private PrimitiveObjectInspector negativeTableElementOI;
+
+ private ListObjectInspector docOI;
+ private PrimitiveObjectInspector wordOI;
+
+ @Override
+ public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
+ final int numArgs = argOIs.length;
+
+ if (numArgs != 3) {
+ throw new UDFArgumentException(getClass().getSimpleName()
+ + " takes 3 arguments: [, constant string options]: "
+ + Arrays.toString(argOIs));
+ }
+
+ processOptions(argOIs);
+
+ this.negativeTableOI = HiveUtils.asListOI(argOIs[0]);
+ this.negativeTableElementListOI = HiveUtils.asListOI(negativeTableOI.getListElementObjectInspector());
+ this.docOI = HiveUtils.asListOI(argOIs[1]);
+
+ this.isStringInput = HiveUtils.isStringListOI(argOIs[1]);
+
+ if (isStringInput) {
+ this.negativeTableElementOI = HiveUtils.asStringOI(negativeTableElementListOI.getListElementObjectInspector());
+ this.wordOI = HiveUtils.asStringOI(docOI.getListElementObjectInspector());
+ } else {
+ this.negativeTableElementOI = HiveUtils.asFloatingPointOI(negativeTableElementListOI.getListElementObjectInspector());
+ this.wordOI = HiveUtils.asIntCompatibleOI(docOI.getListElementObjectInspector());
+ }
+
+ List<String> fieldNames = new ArrayList<>();
+ List<ObjectInspector> fieldOIs = new ArrayList<>();
+
+ fieldNames.add("word");
+ fieldNames.add("i");
+ fieldNames.add("wi");
+
+ if (isStringInput) {
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
+ } else {
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+ }
+
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
+
+ this.model = null;
+ this.word2index = null;
+ this.S = null;
+ this.aliasWordIds = null;
+
+ return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
+ }
+
+ @Override
+ public void process(Object[] args) throws HiveException {
+ if (model == null) {
+ parseNegativeTable(args[0]);
+ this.model = createModel();
+ }
+
+ List<?> rawDoc = docOI.getList(args[1]);
+
+ // parse rawDoc
+ final int docLength = rawDoc.size();
+ final int[] doc = new int[docLength];
+ if (isStringInput) {
+ for (int i = 0; i < docLength; i++) {
+ doc[i] = getWordId(PrimitiveObjectInspectorUtils.getString(rawDoc.get(i), wordOI));
+ }
+ } else {
+ for (int i = 0; i < docLength; i++) {
+ doc[i] = PrimitiveObjectInspectorUtils.getInt(rawDoc.get(i), wordOI);
+ }
+ }
+
+ model.trainOnDoc(doc);
+ }
+
+ @Override
+ protected Options getOptions() {
+ Options opts = new Options();
+ opts.addOption("n", "numTrainWords", true,
+ "The number of words in the documents. It is used to update learning rate");
+ opts.addOption("dim", "dimension", true, "The number of vector dimension [default: 100]");
+ opts.addOption("win", "window", true, "Context window size [default: 5]");
+ opts.addOption("neg", "negative", true,
+ "The number of negative sampled words per word [default: 5]");
+ opts.addOption("iter", "iteration", true, "The number of iterations [default: 5]");
--- End diff --
Use consistent naming, `"iters", "iterations"`, as seen in SLIM.
---
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Posted by myui <gi...@git.apache.org>.
Github user myui commented on a diff in the pull request:
https://github.com/apache/incubator-hivemall/pull/116#discussion_r141544782
--- Diff: core/src/main/java/hivemall/embedding/Word2VecUDTF.java ---
@@ -0,0 +1,364 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.embedding;
+
+import hivemall.UDTFWithOptions;
+import hivemall.utils.collections.IMapIterator;
+import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
+import hivemall.utils.collections.maps.OpenHashTable;
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.lang.Primitives;
+
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.Options;
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
+import org.apache.hadoop.io.FloatWritable;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+import java.util.List;
+import java.util.Arrays;
+import java.util.ArrayList;
+
+@Description(
+ name = "train_word2vec",
+ value = "_FUNC_(array<array<float | string>> negative_table, array<int | string> doc [, const string options]) - Returns a prediction model")
+public class Word2VecUDTF extends UDTFWithOptions {
+ protected transient AbstractWord2VecModel model;
+ @Nonnegative
+ private float startingLR;
+ @Nonnegative
+ private long numTrainWords;
+ private OpenHashTable<String, Integer> word2index;
+
+ @Nonnegative
+ private int dim;
+ @Nonnegative
+ private int win;
+ @Nonnegative
+ private int neg;
+ @Nonnegative
+ private int iter;
+ private boolean skipgram;
+ private boolean isStringInput;
+
+ private Int2FloatOpenHashTable S;
+ private int[] aliasWordIds;
+
+ private ListObjectInspector negativeTableOI;
+ private ListObjectInspector negativeTableElementListOI;
+ private PrimitiveObjectInspector negativeTableElementOI;
+
+ private ListObjectInspector docOI;
+ private PrimitiveObjectInspector wordOI;
+
+ @Override
+ public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
+ final int numArgs = argOIs.length;
+
+ if (numArgs != 3) {
+ throw new UDFArgumentException(getClass().getSimpleName()
+ + " takes 3 arguments: [, constant string options]: "
+ + Arrays.toString(argOIs));
+ }
+
+ processOptions(argOIs);
+
+ this.negativeTableOI = HiveUtils.asListOI(argOIs[0]);
+ this.negativeTableElementListOI = HiveUtils.asListOI(negativeTableOI.getListElementObjectInspector());
+ this.docOI = HiveUtils.asListOI(argOIs[1]);
+
+ this.isStringInput = HiveUtils.isStringListOI(argOIs[1]);
+
+ if (isStringInput) {
+ this.negativeTableElementOI = HiveUtils.asStringOI(negativeTableElementListOI.getListElementObjectInspector());
+ this.wordOI = HiveUtils.asStringOI(docOI.getListElementObjectInspector());
+ } else {
+ this.negativeTableElementOI = HiveUtils.asFloatingPointOI(negativeTableElementListOI.getListElementObjectInspector());
+ this.wordOI = HiveUtils.asIntCompatibleOI(docOI.getListElementObjectInspector());
+ }
+
+ List<String> fieldNames = new ArrayList<>();
+ List<ObjectInspector> fieldOIs = new ArrayList<>();
+
+ fieldNames.add("word");
+ fieldNames.add("i");
+ fieldNames.add("wi");
+
+ if (isStringInput) {
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
+ } else {
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+ }
+
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
+
+ this.model = null;
+ this.word2index = null;
+ this.S = null;
+ this.aliasWordIds = null;
+
+ return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
+ }
+
+ @Override
+ public void process(Object[] args) throws HiveException {
+ if (model == null) {
+ parseNegativeTable(args[0]);
+ this.model = createModel();
+ }
+
+ List<?> rawDoc = docOI.getList(args[1]);
+
+ // parse rawDoc
+ final int docLength = rawDoc.size();
+ final int[] doc = new int[docLength];
+ if (isStringInput) {
+ for (int i = 0; i < docLength; i++) {
+ doc[i] = getWordId(PrimitiveObjectInspectorUtils.getString(rawDoc.get(i), wordOI));
+ }
+ } else {
+ for (int i = 0; i < docLength; i++) {
+ doc[i] = PrimitiveObjectInspectorUtils.getInt(rawDoc.get(i), wordOI);
+ }
+ }
+
+ model.trainOnDoc(doc);
+ }
+
+ @Override
+ protected Options getOptions() {
+ Options opts = new Options();
+ opts.addOption("n", "numTrainWords", true,
+ "The number of words in the documents. It is used to update learning rate");
+ opts.addOption("dim", "dimension", true, "The number of vector dimension [default: 100]");
+ opts.addOption("win", "window", true, "Context window size [default: 5]");
+ opts.addOption("neg", "negative", true,
+ "The number of negative sampled words per word [default: 5]");
+ opts.addOption("iter", "iteration", true, "The number of iterations [default: 5]");
+ opts.addOption("model", "modelName", true,
+ "The model name of word2vec: skipgram or cbow [default: skipgram]");
+ opts.addOption(
+ "lr",
+ "learningRate",
+ true,
+ "Initial learning rate of SGD. The default value depends on model [default: 0.025 (skipgram), 0.05 (cbow)]");
+
+ return opts;
+ }
+
+ @Override
+ protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
+ CommandLine cl = null;
+ int win = 5;
+ int neg = 5;
+ int iter = 5;
+ int dim = 100;
+ long numTrainWords = 0L;
+ String modelName = "skipgram";
+ float lr = 0.025f;
+
+ if (argOIs.length >= 3) {
+ String rawArgs = HiveUtils.getConstString(argOIs[2]);
+ cl = parseOptions(rawArgs);
+
+ numTrainWords = Primitives.parseLong(cl.getOptionValue("n"), numTrainWords);
+ if (numTrainWords <= 0) {
+ throw new UDFArgumentException("Argument `long numTrainWords` must be positive: "
+ + numTrainWords);
+ }
+
+ dim = Primitives.parseInt(cl.getOptionValue("dim"), dim);
+ if (dim <= 0) {
+ throw new UDFArgumentException("Argument `int dim` must be positive: " + dim);
+ }
+
+ win = Primitives.parseInt(cl.getOptionValue("win"), win);
+ if (win <= 0) {
+ throw new UDFArgumentException("Argument `int win` must be positive: " + win);
+ }
+
+ neg = Primitives.parseInt(cl.getOptionValue("neg"), neg);
+ if (neg < 0) {
+ throw new UDFArgumentException("Argument `int neg` must be non-negative: " + neg);
+ }
+
+ iter = Primitives.parseInt(cl.getOptionValue("iter"), iter);
+ if (iter <= 0) {
+ throw new UDFArgumentException("Argument `int iter` must be positive: " + iter);
+ }
+
+ modelName = cl.getOptionValue("model", modelName);
+ if (!(modelName.equals("skipgram") || modelName.equals("cbow"))) {
+ throw new UDFArgumentException("Argument `string model` must be skipgram or cbow: "
+ + modelName);
+ }
+
+ if (modelName.equals("cbow")) {
+ lr = 0.05f;
+ }
+
+ lr = Primitives.parseFloat(cl.getOptionValue("lr"), lr);
+ if (lr <= 0.f) {
+ throw new UDFArgumentException("Argument `float lr` must be positive: " + lr);
+ }
+ }
+
+ this.numTrainWords = numTrainWords;
+ this.win = win;
+ this.neg = neg;
+ this.iter = iter;
+ this.dim = dim;
+ this.skipgram = modelName.equals("skipgram");
+ this.startingLR = lr;
+ return cl;
+ }
+
+ public void close() throws HiveException {
+ if (model != null) {
+ forwardModel();
+ this.model = null;
+ this.word2index = null;
+ this.S = null;
+ }
+ }
+
+ private void forwardModel() throws HiveException {
+ if (isStringInput) {
+ final Text word = new Text();
+ final IntWritable dimIndex = new IntWritable();
+ final FloatWritable value = new FloatWritable();
+
+ final Object[] result = new Object[3];
+ result[0] = word;
+ result[1] = dimIndex;
+ result[2] = value;
+
+ IMapIterator<String, Integer> iter = word2index.entries();
+ while (iter.next() != -1) {
+ int wordId = iter.getValue();
+ if (!model.inputWeights.containsKey(wordId * dim)){
+ continue;
+ }
+
+ word.set(iter.getKey());
+
+ for (int i = 0; i < dim; i++) {
+ dimIndex.set(i);
+ value.set(model.inputWeights.get(wordId * dim + i));
+ forward(result);
+ }
+ }
+ } else {
+ final IntWritable word = new IntWritable();
+ final IntWritable dimIndex = new IntWritable();
+ final FloatWritable value = new FloatWritable();
+
+ final Object[] result = new Object[3];
+ result[0] = word;
+ result[1] = dimIndex;
+ result[2] = value;
+
+ for (int wordId = 0; wordId < aliasWordIds.length; wordId++) {
+ if (!model.inputWeights.containsKey(wordId * dim)){
+ break;
+ }
+ word.set(wordId);
+ for (int i = 0; i < dim; i++) {
+ dimIndex.set(i);
+ value.set(model.inputWeights.get(wordId * dim + i));
+ forward(result);
+ }
+ }
+ }
+ }
+
+ private int getWordId(@Nonnull final String word) {
+ if (word2index.containsKey(word)) {
--- End diff --
`word2index` is not guaranteed to be non-null here. Pass it in explicitly and check it:
```java
private static int getWordId(@Nonnull final String word, @CheckNotNull OpenHashTable<String, Integer> word2index) {
Precondition.checkNotNull(word2index);
```
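A runnable variant of the same idea, using only the standard library: `java.util.Map` stands in for Hivemall's `OpenHashTable`, and `Objects.requireNonNull` stands in for the precondition helper. The diff is cut off before the body of `getWordId`, so the "assign the next free id on first sight" behaviour below is a hypothetical fill-in, not the PR's logic.

```java
import java.util.Map;
import java.util.Objects;

final class WordIdLookupSketch {
    // Static and side-effect free apart from the map passed in, so the null check is explicit.
    static int getWordId(final String word, final Map<String, Integer> word2index) {
        Objects.requireNonNull(word, "word");
        Objects.requireNonNull(word2index, "word2index");
        final Integer id = word2index.get(word);
        if (id != null) {
            return id.intValue();
        }
        // Hypothetical behaviour: unseen words get the next free id.
        final int newId = word2index.size();
        word2index.put(word, newId);
        return newId;
    }
}
```

Passing the table as a parameter makes the dependency visible at the call site, so a `null` table fails fast with a clear message instead of surfacing later as an unexplained `NullPointerException`.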
---
[GitHub] incubator-hivemall issue #116: [WIP][HIVEMALL-118] word2vec
Posted by coveralls <gi...@git.apache.org>.
Github user coveralls commented on the issue:
https://github.com/apache/incubator-hivemall/pull/116
[![Coverage Status](https://coveralls.io/builds/13411195/badge)](https://coveralls.io/builds/13411195)
Coverage decreased (-0.5%) to 40.383% when pulling **7a2f4dbfeb89eee78e6e0526027b5b0ba9162c29 on nzw0301:skipgram** into **c2b95783cf9d6fc1646a48ac928e96152eab98c6 on apache:master**.
---
[GitHub] incubator-hivemall issue #116: [WIP][HIVEMALL-118] word2vec
Posted by coveralls <gi...@git.apache.org>.
Github user coveralls commented on the issue:
https://github.com/apache/incubator-hivemall/pull/116
[![Coverage Status](https://coveralls.io/builds/13438117/badge)](https://coveralls.io/builds/13438117)
Coverage decreased (-0.8%) to 40.107% when pulling **f0abd4fd99dcace050d155533ebc2dc8768cfc79 on nzw0301:skipgram** into **c2b95783cf9d6fc1646a48ac928e96152eab98c6 on apache:master**.
---
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Posted by myui <gi...@git.apache.org>.
Github user myui commented on a diff in the pull request:
https://github.com/apache/incubator-hivemall/pull/116#discussion_r141545219
--- Diff: docs/gitbook/embedding/word2vec.md ---
@@ -0,0 +1,399 @@
+<!--
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+-->
+
+Word embedding is a powerful tool for many tasks,
+e.g. finding similar words,
+building feature vectors for supervised machine learning tasks, and solving word analogies
+such as `king - man + woman =~ queen`.
+In word embedding,
+each word is represented by a low-dimensional, dense vector.
+**Skip-Gram** and **Continuous Bag-of-words** (CBoW) are the most popular algorithms for obtaining good word embeddings (a.k.a. word2vec).
+
+The papers that introduce the method are as follows:
+
+- T. Mikolov, et al., [Distributed Representations of Words and Phrases and Their Compositionality
+](http://papers.nips.cc/paper/5021-distributed-representations-of-words-and-phrases-and-their-compositionality.pdf). NIPS, 2013.
+- T. Mikolov, et al., [Efficient Estimation of Word Representations in Vector Space](https://arxiv.org/abs/1301.3781). ICLR, 2013.
+
+Hivemall provides two algorithms: Skip-gram and CBoW, both with negative sampling.
+Hivemall enables you to train word2vec on sequence data such as,
+but not limited to, documents.
+This article explains how to use the feature.
+
+<!-- toc -->
+
+> #### Note
+> This feature is supported from Hivemall v0.5-rc.? or later.
+
+# Prepare document data
+
+Assume that you already have a `docs` table that contains many documents as strings, each with a unique index:
+
+```sql
+select * FROM docs;
+```
+
+| docId | doc |
+|:----: |:----|
+| 0 | "Alice was beginning to get very tired of sitting by her sister on the bank ..." |
+| ... | ... |
+
+First, each document is split into words by a tokenizer function such as [`tokenize`](../misc/tokenizer.html).
+
+```sql
+drop table docs_words;
+create table docs_words as
+ select
+ docid,
+ tokenize(doc, true) as words
+ FROM
+ docs
+;
+```
+
+The resulting table holds the tokenized documents.
+
+| docId | words |
+|:----: |:----|
+| 0 | ["alice", "was", "beginning", "to", "get", "very", "tired", "of", "sitting", "by", "her", "sister", "on", "the", "bank", ...] |
+| ... | ... |
+
+Then, count the frequency of each word and remove low-frequency words from the vocabulary.
+Removing low-frequency words is optional, but it speeds up training of the word vectors.
+
+```sql
+set hivevar:mincount=5;
+
+drop table freq;
+create table freq as
+select
+ row_number() over () - 1 as wordid,
+ word,
+ freq
+from (
+ select
+ word,
+ COUNT(*) as freq
+ from
+ docs_words
+ LATERAL VIEW explode(words) lTable as word
+ group by
+ word
+) t
+where freq >= ${mincount}
+;
+```
+
+Hivemall's word2vec supports two word types: string and int.
+The string type tends to use a large amount of memory during training.
+The int type, on the other hand, tends to use less memory.
+If you train on a small dataset, we recommend using the string type,
+because memory usage is negligible and the HiveQL is simpler.
+If you train on a large dataset, we recommend using the int type,
+because it saves memory during training.
+
+# Create sub-sampling table
+
+The sub-sampling table stores a sub-sampling probability per word.
+
+The sub-sampling probability of word $$w_i$$ is computed by the following equation:
+
+$$
+\begin{aligned}
+f(w_i) = \sqrt{\frac{\mathrm{sample}}{freq(w_i)/\sum freq(w)}} + \frac{\mathrm{sample}}{freq(w_i)/\sum freq(w)}
+\end{aligned}
+$$
+
+During word2vec training,
+word occurrences that are not selected by sub-sampling are ignored.
+This speeds up training and reduces the imbalance between rare and frequent words by discarding some occurrences of frequent words.
+The smaller the `sample` value,
+the fewer words are used during training.
+
+```sql
+set hivevar:sample=1e-4;
+
+drop table subsampling_table;
+create table subsampling_table as
+with stats as (
+ select
+ sum(freq) as numTrainWords
+ FROM
+ freq
+)
+select
+ l.wordid,
+ l.word,
+ sqrt(${sample}/(l.freq/r.numTrainWords)) + ${sample}/(l.freq/r.numTrainWords) as p
+from
+ freq l
+cross join
+ stats r
+;
+```
+
+```sql
+select * FROM subsampling_table order by p;
+```
+
+| wordid | word | p |
+|:----: | :----: |:----:|
+| 48645 | the | 0.04013665|
+| 11245 | of | 0.052463654|
+| 16368 | and | 0.06555538|
+| 61938 | 00 | 0.068162076|
+| 19977 | in | 0.071441144|
+| 83599 | 0 | 0.07528994|
+| 95017 | a | 0.07559573|
+| 1225 | to | 0.07953133|
+| 37062 | 0000 | 0.08779001|
+| 58246 | is | 0.09049763|
+| ... | ... |... |
+
+The first row shows that about 4% of the occurrences of `the` in the documents are used during training.
+
+# Delete low frequency words and high frequency words from `docs_words`
+
+To remove uninformative words from the corpus,
+low-frequency words and high-frequency words are deleted.
+In addition, to avoid loading a long document into memory, each document is split into several sub-documents.
+
+```sql
+set hivevar:maxlength=1500;
+SET hivevar:seed=31;
+
+drop table train_docs;
+create table train_docs as
+ with docs_exploded as (
+ select
+ docid,
+ word,
+ pos % ${maxlength} as pos,
+ pos div ${maxlength} as splitid,
+ rand(${seed}) as rnd
+ from
+ docs_words LATERAL VIEW posexplode(words) t as pos, word
+ )
+select
+ l.docid,
+ -- to_ordered_list(l.word, l.pos) as words
+ to_ordered_list(r2.wordid, l.pos) as words
+from
+ docs_exploded l
+ LEFT SEMI join freq r on (l.word = r.word)
+ join subsampling_table r2 on (l.word = r2.word)
+where
+ r2.p > l.rnd
+group by
+ l.docid, l.splitid
+;
+```
+
+If you store words as strings in the `train_docs` table,
+please replace `to_ordered_list(r2.wordid, l.pos) as words` with `to_ordered_list(l.word, l.pos) as words`.
+
+# Create negative sampling table
+
+Negative sampling is an approximation to the [softmax function](https://en.wikipedia.org/wiki/Softmax_function).
+Here, `negative_table` is used to store the word sampling probabilities for negative sampling.
+`z` is a hyperparameter of the noise distribution for negative sampling.
+During word2vec training,
+words sampled from this distribution are used as negative examples.
+The noise distribution is the unigram distribution raised to the power `z`, which is 3/4 here.
+
+$$
+\begin{aligned}
+p(w_i) = \frac{freq(w_i)^{\mathrm{z}}}{\sum freq(w)^{\mathrm{z}}}
+\end{aligned}
+$$
+
+To avoid the large memory footprint of the negative sampling table in the original implementation while still sampling quickly from this distribution,
+Hivemall uses the [alias method](https://en.wikipedia.org/wiki/Alias_method).
+
+This method was proposed in the papers below:
+
+- A. J. Walker, New Fast Method for Generating Discrete Random Numbers with Arbitrary Frequency Distributions, in Electronics Letters 10, no. 8, pp. 127-128, 1974.
+- A. J. Walker, An Efficient Method for Generating Discrete Random Variables with General Distributions. ACM Transactions on Mathematical Software 3, no. 3, pp. 253-256, 1977.
+
+```sql
+set hivevar:z=3/4;
+
+drop table negative_table;
+create table negative_table as
+select
+ collect_list(array(word, p, other)) as negative_table
+from (
+ select
+ alias_table(to_map(word, negative)) as (word, p, other)
+ from
+ (
+ select
+ word,
+ -- wordid as word,
+ pow(freq, ${z}) as negative
+ from
+ freq
+ ) t
+) t1
+;
+```
+
+The `alias_table` function returns records like the following.
+
+| word | p | other |
+|:----: | :----: |:----:|
+| leopold | 0.6556492 | 0000 |
+| slep | 0.09060383 | leopold |
+| valentinian | 0.76077825 | belarusian |
+| slew | 0.90569097 | colin |
+| lucien | 0.86329675 | overland |
+| equitable | 0.7270946 | farms |
+| insurers | 0.2367955 | israel |
+| lucier | 0.14855136 | supplements |
+| lieve | 0.12075222 | separatist |
+| skyhawks | 0.14079945 | steamed |
+| ... | ... | ... |
+
+To sample a negative word from this `negative_table`:
+
+1. Sample an integer record index `i` uniformly from $$[0 \ldots \mathrm{num\_alias\_table\_records} - 1]$$.
+2. Sample a float value `r` uniformly from $$[0.0 \ldots 1.0)$$.
+3. If `r` < `p` of the `i`-th record, return `word` of the `i`-th record; otherwise return `other` of the `i`-th record.
+
+Here, to use it in the training function of word2vec,
+the records returned by `alias_table` are stored as one list in the `negative_table`.
+
+# Train word2vec
+
+Hivemall provides the `train_word2vec` function to train word vectors with the word2vec algorithms.
+The default model is `"skipgram"`.
+
+> #### Note
+> You must set the `-n` option to the number of words in the training documents, which you can obtain with `select sum(size(words)) from train_docs;`.
+
+## Train Skip-Gram
+
+In the skip-gram model,
+word vectors are trained to predict the nearby words.
+For example, given a sentence like `"alice", "was", "beginning", "to"`,
+the vector of `"was"` is trained to predict `"alice"`, `"beginning"` and `"to"`.
+
+```sql
+select sum(size(words)) from train_docs;
+set hivevar:n=418953; -- previous query return value
+
+drop table skipgram;
+create table skipgram as
+select
+ train_word2vec(
+ r.negative_table,
+ l.words,
+ "-n ${n} -win 5 -neg 15 -iter 5 -dim 100 -model skipgram"
+ )
+from
+ train_docs l
+ cross join negative_table r
+;
+```
+
+When words are treated as int instead of string,
+you may need to map each int `wordid` back to its string word with a `join` statement.
+
+```sql
+drop table skipgram;
+
+create table skipgram as
+select
+ r.word, t.i, t.wi
+from (
+ select
+ train_word2vec(
+ r.negative_table,
+ l.words,
+ "-n 418953 -win 5 -neg 15 -iter 5"
+ ) as (wordid, i, wi)
+ from
+ train_docs l
+ cross join
+ negative_table r
+) t
+join freq r on (t.wordid = r.wordid)
+;
+```
+
+## Train CBoW
+
+In the CBoW model,
+the vectors of the nearby words are trained to predict the center word.
+For example, given a sentence like `"alice", "was", "beginning", "to"`,
+the vectors of `"alice"`, `"beginning"` and `"to"` are trained to predict `"was"`.
+
+```sql
+drop table cbow;
+
+create table cbow as
+select
+ train_word2vec(
+ r.negative_table,
+ l.words,
+ "-n 418953 -win 5 -neg 15 -iter 5 -model cbow"
--- End diff --
`-n 418953` should be `-n ${n}`.
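
For readers of the quoted `word2vec.md`, the three alias-table sampling steps listed in its "Create negative sampling table" section can be sketched as follows. This is only an illustration under stated assumptions: the `AliasSampler` class and its parallel-array layout are hypothetical and are not taken from the pull request, and `java.util.Random` stands in for whatever generator the UDTF actually uses.

```java
import java.util.Random;

final class AliasSampler {
    // One entry per alias-table record: (word, p, other), as in the docs table.
    private final String[] word;
    private final float[] p;
    private final String[] other;

    AliasSampler(String[] word, float[] p, String[] other) {
        if (word.length == 0 || word.length != p.length || word.length != other.length) {
            throw new IllegalArgumentException("columns must be non-empty and of equal length");
        }
        this.word = word;
        this.p = p;
        this.other = other;
    }

    /** Draws one negative word in O(1) time. */
    String sample(Random rnd) {
        final int i = rnd.nextInt(word.length); // step 1: record index in [0, n)
        final float r = rnd.nextFloat();        // step 2: r in [0.0, 1.0)
        return (r < p[i]) ? word[i] : other[i]; // step 3: keep word or take its alias
    }

    public static void main(String[] args) {
        AliasSampler sampler = new AliasSampler(
            new String[] {"leopold", "slep"},
            new float[] {0.6556492f, 0.09060383f},
            new String[] {"0000", "leopold"});
        System.out.println(sampler.sample(new Random(31)));
    }
}
```

Each draw costs one integer and one float from the generator regardless of vocabulary size, which is why the table can replace the large pre-expanded sampling array of the original implementation.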
---
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Posted by myui <gi...@git.apache.org>.
Github user myui commented on a diff in the pull request:
https://github.com/apache/incubator-hivemall/pull/116#discussion_r141545337
--- Diff: docs/gitbook/embedding/word2vec.md ---
@@ -0,0 +1,399 @@
+During word2vec training,
--- End diff --
remove line break after `,`.
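
As a side note on the paragraph under review, the sub-sampling formula from the quoted `word2vec.md` can be checked numerically with a small sketch. The class and method names (`SubSampling`, `keepProbability`) are hypothetical; only the formula itself comes from the document.

```java
final class SubSampling {
    private SubSampling() {}

    /**
     * f(w) = sqrt(sample / (freq / total)) + sample / (freq / total),
     * the same expression as column p of subsampling_table.
     */
    static double keepProbability(long freq, long totalWords, double sample) {
        final double ratio = sample / ((double) freq / totalWords);
        return Math.sqrt(ratio) + ratio;
    }

    public static void main(String[] args) {
        // A word that makes up 1% of the corpus, with sample = 1e-4.
        System.out.println(keepProbability(100L, 10_000L, 1e-4));
        // A rare word: the value exceeds 1, so every occurrence is kept.
        System.out.println(keepProbability(1L, 10_000L, 1e-4));
    }
}
```

With `sample = 1e-4`, a word covering 1% of the corpus gets roughly 0.1 + 0.01 = 0.11, while a word seen once in 10,000 gets about 2. Since the document filters with `r2.p > l.rnd` and `rand()` is below 1, any value of 1 or more means the word is always kept.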
---
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Posted by myui <gi...@git.apache.org>.
Github user myui commented on a diff in the pull request:
https://github.com/apache/incubator-hivemall/pull/116#discussion_r141543986
--- Diff: core/src/main/java/hivemall/embedding/Word2VecUDTF.java ---
@@ -0,0 +1,364 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.embedding;
+
+import hivemall.UDTFWithOptions;
+import hivemall.utils.collections.IMapIterator;
+import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
+import hivemall.utils.collections.maps.OpenHashTable;
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.lang.Primitives;
+
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.Options;
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
+import org.apache.hadoop.io.FloatWritable;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+import java.util.List;
+import java.util.Arrays;
+import java.util.ArrayList;
+
+@Description(
+ name = "train_word2vec",
+ value = "_FUNC_(array<array<float | string>> negative_table, array<int | string> doc [, const string options]) - Returns a prediction model")
+public class Word2VecUDTF extends UDTFWithOptions {
+ protected transient AbstractWord2VecModel model;
+ @Nonnegative
+ private float startingLR;
+ @Nonnegative
+ private long numTrainWords;
+ private OpenHashTable<String, Integer> word2index;
+
+ @Nonnegative
+ private int dim;
+ @Nonnegative
+ private int win;
+ @Nonnegative
+ private int neg;
+ @Nonnegative
+ private int iter;
+ private boolean skipgram;
+ private boolean isStringInput;
+
+ private Int2FloatOpenHashTable S;
+ private int[] aliasWordIds;
+
+ private ListObjectInspector negativeTableOI;
+ private ListObjectInspector negativeTableElementListOI;
+ private PrimitiveObjectInspector negativeTableElementOI;
+
+ private ListObjectInspector docOI;
+ private PrimitiveObjectInspector wordOI;
+
+ @Override
+ public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
+ final int numArgs = argOIs.length;
+
+ if (numArgs != 3) {
+ throw new UDFArgumentException(getClass().getSimpleName()
+ + " takes 3 arguments: [, constant string options]: "
+ + Arrays.toString(argOIs));
+ }
+
+ processOptions(argOIs);
+
+ this.negativeTableOI = HiveUtils.asListOI(argOIs[0]);
+ this.negativeTableElementListOI = HiveUtils.asListOI(negativeTableOI.getListElementObjectInspector());
+ this.docOI = HiveUtils.asListOI(argOIs[1]);
+
+ this.isStringInput = HiveUtils.isStringListOI(argOIs[1]);
+
+ if (isStringInput) {
+ this.negativeTableElementOI = HiveUtils.asStringOI(negativeTableElementListOI.getListElementObjectInspector());
+ this.wordOI = HiveUtils.asStringOI(docOI.getListElementObjectInspector());
+ } else {
+ this.negativeTableElementOI = HiveUtils.asFloatingPointOI(negativeTableElementListOI.getListElementObjectInspector());
+ this.wordOI = HiveUtils.asIntCompatibleOI(docOI.getListElementObjectInspector());
+ }
+
+ List<String> fieldNames = new ArrayList<>();
+ List<ObjectInspector> fieldOIs = new ArrayList<>();
+
+ fieldNames.add("word");
+ fieldNames.add("i");
+ fieldNames.add("wi");
+
+ if (isStringInput) {
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
+ } else {
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+ }
+
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+ fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
+
+ this.model = null;
+ this.word2index = null;
+ this.S = null;
+ this.aliasWordIds = null;
+
+ return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
+ }
+
+ @Override
+ public void process(Object[] args) throws HiveException {
+ if (model == null) {
+ parseNegativeTable(args[0]);
+ this.model = createModel();
+ }
+
+ List<?> rawDoc = docOI.getList(args[1]);
+
+ // parse rawDoc
+ final int docLength = rawDoc.size();
+ final int[] doc = new int[docLength];
+ if (isStringInput) {
+ for (int i = 0; i < docLength; i++) {
+ doc[i] = getWordId(PrimitiveObjectInspectorUtils.getString(rawDoc.get(i), wordOI));
+ }
+ } else {
+ for (int i = 0; i < docLength; i++) {
+ doc[i] = PrimitiveObjectInspectorUtils.getInt(rawDoc.get(i), wordOI);
--- End diff --
`rawDoc.get(i)` may return null.
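
One possible way to handle this, sketched outside of Hive: skip null elements while copying the list into an `int[]`. The `NullSafeDocParser` class is hypothetical, elements are assumed to be `Number` instances rather than values read through an ObjectInspector, and whether to skip nulls or to reject the row is a design choice this sketch does not settle.

```java
import java.util.Arrays;
import java.util.List;

final class NullSafeDocParser {
    private NullSafeDocParser() {}

    /** Converts a raw document to word ids, dropping null elements. */
    static int[] toWordIds(List<?> rawDoc) {
        final int[] buf = new int[rawDoc.size()];
        int n = 0;
        for (Object o : rawDoc) {
            if (o == null) {
                continue; // skip instead of failing on unboxing
            }
            buf[n++] = ((Number) o).intValue();
        }
        // Trim the buffer to the number of non-null elements.
        return Arrays.copyOf(buf, n);
    }

    public static void main(String[] args) {
        List<Integer> raw = Arrays.asList(3, null, 7);
        System.out.println(Arrays.toString(toWordIds(raw)));
    }
}
```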
---
[GitHub] incubator-hivemall issue #116: [WIP][HIVEMALL-118] word2vec
Posted by coveralls <gi...@git.apache.org>.
Github user coveralls commented on the issue:
https://github.com/apache/incubator-hivemall/pull/116
[![Coverage Status](https://coveralls.io/builds/13383603/badge)](https://coveralls.io/builds/13383603)
Coverage decreased (-0.8%) to 40.138% when pulling **e50756111378e6a173aba7574e73435364ff42d0 on nzw0301:skipgram** into **c2b95783cf9d6fc1646a48ac928e96152eab98c6 on apache:master**.
---
[GitHub] incubator-hivemall issue #116: [WIP][HIVEMALL-118] word2vec
Posted by coveralls <gi...@git.apache.org>.
Github user coveralls commented on the issue:
https://github.com/apache/incubator-hivemall/pull/116
[![Coverage Status](https://coveralls.io/builds/13434811/badge)](https://coveralls.io/builds/13434811)
Coverage decreased (-0.8%) to 40.149% when pulling **2b66e5e3dbf1408c0719735dcbdf05555938f3bb on nzw0301:skipgram** into **c2b95783cf9d6fc1646a48ac928e96152eab98c6 on apache:master**.
---
[GitHub] incubator-hivemall issue #116: [WIP][HIVEMALL-118] word2vec
Posted by coveralls <gi...@git.apache.org>.
Github user coveralls commented on the issue:
https://github.com/apache/incubator-hivemall/pull/116
[![Coverage Status](https://coveralls.io/builds/13371457/badge)](https://coveralls.io/builds/13371457)
Coverage decreased (-0.8%) to 40.156% when pulling **ad2b2911b5a6ebb7b43b4981bf5ff4424425a292 on nzw0301:skipgram** into **c2b95783cf9d6fc1646a48ac928e96152eab98c6 on apache:master**.
---
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Posted by myui <gi...@git.apache.org>.
Github user myui commented on a diff in the pull request:
https://github.com/apache/incubator-hivemall/pull/116#discussion_r140413234
--- Diff: core/src/main/java/hivemall/unsupervised/AbstractWord2vecModel.java ---
@@ -0,0 +1,125 @@
+package hivemall.unsupervised;
--- End diff --
Please move package from `hivemall.unsupervised` to `hivemall.embedding`.
---
[GitHub] incubator-hivemall issue #116: [WIP][HIVEMALL-118] word2vec
Posted by myui <gi...@git.apache.org>.
Github user myui commented on the issue:
https://github.com/apache/incubator-hivemall/pull/116
@nzw0301 Please rebase to master resolving ^ conflicts.
---
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Posted by myui <gi...@git.apache.org>.
Github user myui commented on a diff in the pull request:
https://github.com/apache/incubator-hivemall/pull/116#discussion_r140725446
--- Diff: core/src/main/java/hivemall/embedding/AbstractWord2VecModel.java ---
@@ -0,0 +1,117 @@
+package hivemall.embedding;
+
+import hivemall.math.random.PRNG;
+import hivemall.math.random.RandomNumberGeneratorFactory;
+import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+public abstract class AbstractWord2VecModel {
+ // cached sigmoid function parameters
+ protected final int MAX_SIGMOID = 6;
--- End diff --
Constants should be `static final`.
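
A minimal sketch of what that could look like for a cached sigmoid, assuming a lookup table in the style of the original word2vec code. `MAX_SIGMOID = 6` comes from the diff above; the `SigmoidTable` class and the `TABLE_SIZE` of 1000 are assumptions made for illustration.

```java
final class SigmoidTable {
    // Shared by all instances, so declared static final.
    static final int MAX_SIGMOID = 6;
    static final int TABLE_SIZE = 1000; // assumed resolution of the cache

    private static final float[] TABLE = new float[TABLE_SIZE];

    static {
        for (int i = 0; i < TABLE_SIZE; i++) {
            // Map index i to x in [-MAX_SIGMOID, MAX_SIGMOID).
            final double x = (i / (double) TABLE_SIZE * 2.0 - 1.0) * MAX_SIGMOID;
            TABLE[i] = (float) (1.0 / (1.0 + Math.exp(-x)));
        }
    }

    private SigmoidTable() {}

    /** Approximates 1 / (1 + exp(-x)) from the precomputed table. */
    static float sigmoid(float x) {
        if (x >= MAX_SIGMOID) {
            return 1f;
        }
        if (x <= -MAX_SIGMOID) {
            return 0f;
        }
        // Compute the index in double precision to avoid float rounding at the edge.
        final int idx = (int) (((double) x + MAX_SIGMOID) * TABLE_SIZE / (2.0 * MAX_SIGMOID));
        return TABLE[Math.min(idx, TABLE_SIZE - 1)];
    }

    public static void main(String[] args) {
        System.out.println(sigmoid(0f));
        System.out.println(sigmoid(10f));
    }
}
```

Declaring the constants `static final` lets the table be built once per class load instead of once per model instance, and allows the compiler to inline the constant values.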
---