You are viewing a plain text version of this content. The canonical link for it is here.
Posted to issues@hivemall.apache.org by nzw0301 <gi...@git.apache.org> on 2017/09/21 02:48:50 UTC

[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec

GitHub user nzw0301 opened a pull request:

    https://github.com/apache/incubator-hivemall/pull/116

    [WIP][HIVEMALL-118] word2vec

    ## What changes were proposed in this pull request?
    
    Add new algorithm: skip-gram with negative sampling (a.k.a word2vec)
    
    ## What type of PR is it?
    
    Improvement
    
    ## What is the Jira issue?
    
    https://issues.apache.org/jira/browse/HIVEMALL-118
    
    ## How was this patch tested?
    
    manual tests on EMR
    
    ## How to use this feature?
    
    please see `word2vec.md`
    
    ## Checklist
    
    
    
    - [x] Did you apply source code formatter, i.e., `mvn formatter:format`, for your commit?


You can merge this pull request into a Git repository by running:

    $ git pull https://github.com/nzw0301/incubator-hivemall skipgram

Alternatively you can review and apply these changes as the patch at:

    https://github.com/apache/incubator-hivemall/pull/116.patch

To close this pull request, make a commit to your master/trunk branch
with (at least) the following in the commit message:

    This closes #116
    
----
commit f19186fe8eff3de0400cc318c8c876fc69dbc766
Author: Kento NOZAWA <k_...@klis.tsukuba.ac.jp>
Date:   2017-09-13T12:50:35Z

    Init docs for word2vec

commit e9a76093efd1803dd63a11697abc68e7ae78cbb8
Author: Kento NOZAWA <k_...@klis.tsukuba.ac.jp>
Date:   2017-09-13T12:54:33Z

    Init Alias Table builder

commit b6883b00e8f543130e16caa2ad9010cfcd0e6cd9
Author: Kento NOZAWA <k_...@klis.tsukuba.ac.jp>
Date:   2017-09-13T13:02:10Z

    Fix typo

commit 3ca761956638c95264a37edfb49133f35c37cdd5
Author: Kento NOZAWA <k_...@klis.tsukuba.ac.jp>
Date:   2017-09-14T02:15:51Z

    Separate calias table function

commit 2588e732b1194d800005c6ac28453adf79d89278
Author: Kento NOZAWA <k_...@klis.tsukuba.ac.jp>
Date:   2017-09-14T12:26:49Z

    Create skip-gram UDTF

commit 33900380a337821f2da280d0c7c9c10f2fe14565
Author: Kento NOZAWA <k_...@klis.tsukuba.ac.jp>
Date:   2017-09-15T03:35:36Z

    Use float to save memory

commit a7394a04bc699e105febf86a6b9712f8daa8ff92
Author: Kento NOZAWA <k_...@klis.tsukuba.ac.jp>
Date:   2017-09-15T07:04:44Z

    Update query example in docs

commit 50d5ffcffff503f1522577b7789e7745650a1ece
Author: Kento NOZAWA <k_...@klis.tsukuba.ac.jp>
Date:   2017-09-15T07:05:21Z

    Update forwarding

commit b701919825b3ed8909659b77236d6522f6d51be6
Author: Kento NOZAWA <k_...@klis.tsukuba.ac.jp>
Date:   2017-09-18T12:03:55Z

    Init Word2vecFeatureUDTF

commit 7f69ef66ef2849e57d06601fb9098b6bf2c2bb21
Author: Kento NOZAWA <k_...@klis.tsukuba.ac.jp>
Date:   2017-09-19T08:53:26Z

    Implement skip-gramfeature UDTF

commit 48c1929b6345d144059eec44636c0194eebc5281
Author: Kento NOZAWA <k_...@klis.tsukuba.ac.jp>
Date:   2017-09-19T14:01:46Z

    Update skipgram

commit 7f4abde3760cf17455d99eb6c5a7cbe8c3dc5e3c
Author: Kento NOZAWA <k_...@klis.tsukuba.ac.jp>
Date:   2017-09-19T16:05:15Z

    Refactor Skipgram

commit 2e429d60ff2ae473a64cb890e67c9d8b9a7b2830
Author: Kento NOZAWA <k_...@klis.tsukuba.ac.jp>
Date:   2017-09-20T05:26:25Z

    Update for query change

commit 7014e8552f81d59ab35c95d8fcf54c56c24ba2c9
Author: Kento NOZAWA <k_...@klis.tsukuba.ac.jp>
Date:   2017-09-20T08:54:51Z

    Remove discard table

----


---

[GitHub] incubator-hivemall issue #116: [WIP][HIVEMALL-118] word2vec

Posted by coveralls <gi...@git.apache.org>.
Github user coveralls commented on the issue:

    https://github.com/apache/incubator-hivemall/pull/116
  
    
    [![Coverage Status](https://coveralls.io/builds/13366956/badge)](https://coveralls.io/builds/13366956)
    
    Coverage decreased (-0.4%) to 40.499% when pulling **c224912606f0dfd4d7d53acb0eac3fecc016b335 on nzw0301:skipgram** into **c2b95783cf9d6fc1646a48ac928e96152eab98c6 on apache:master**.



---

[GitHub] incubator-hivemall issue #116: [WIP][HIVEMALL-118] word2vec

Posted by coveralls <gi...@git.apache.org>.
Github user coveralls commented on the issue:

    https://github.com/apache/incubator-hivemall/pull/116
  
    
    [![Coverage Status](https://coveralls.io/builds/13385853/badge)](https://coveralls.io/builds/13385853)
    
    Coverage decreased (-0.5%) to 40.468% when pulling **a3ccaa8a38ecff49d23e94a8cc7db7f895181e06 on nzw0301:skipgram** into **c2b95783cf9d6fc1646a48ac928e96152eab98c6 on apache:master**.



---

[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec

Posted by myui <gi...@git.apache.org>.
Github user myui commented on a diff in the pull request:

    https://github.com/apache/incubator-hivemall/pull/116#discussion_r141544983
  
    --- Diff: docs/gitbook/embedding/word2vec.md ---
    @@ -0,0 +1,399 @@
    +<!--
    +  Licensed to the Apache Software Foundation (ASF) under one
    +  or more contributor license agreements.  See the NOTICE file
    +  distributed with this work for additional information
    +  regarding copyright ownership.  The ASF licenses this file
    +  to you under the Apache License, Version 2.0 (the
    +  "License"); you may not use this file except in compliance
    +  with the License.  You may obtain a copy of the License at
    +
    +    http://www.apache.org/licenses/LICENSE-2.0
    +
    +  Unless required by applicable law or agreed to in writing,
    +  software distributed under the License is distributed on an
    +  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    +  KIND, either express or implied.  See the License for the
    +  specific language governing permissions and limitations
    +  under the License.
    +-->
    +
    +Word Embedding is a powerful tool for many tasks,
    +e.g. finding similar words,
    +feature vectors for supervised machine learning task and word analogy,
    +such as `king - man + woman =~ queen`.
    +In word embedding,
    +each word represents a low dimension and dense vector.
    +**Skip-Gram** and **Continuous Bag-of-words** (CBoW) are the most popular algorithms to obtain good word embeddings (a.k.a word2vec).
    +
    +The papers introduce the method are as follows:
    +
    +- T. Mikolov, et al., [Distributed Representations of Words and Phrases and Their Compositionality
    +](http://papers.nips.cc/paper/5021-distributed-representations-of-words-and-phrases-and-their-compositionality.pdf). NIPS, 2013.
    +- T. Mikolov, et al., [Efficient Estimation of Word Representations in Vector Space](https://arxiv.org/abs/1301.3781). ICLR, 2013.
    +
    +Hivemall provides two type algorithms: Skip-gram and CBoW with negative sampling.
    +Hivemall enables you to train your sequence data such as,
    +but not limited to, documents based on word2vec.
    +This article gives usage instructions of the feature.
    +
    +<!-- toc -->
    +
    +> #### Note
    +> This feature is supported from Hivemall v0.5-rc.? or later.
    +
    +# Prepare document data
    +
    +Assume that you already have `docs` table which contains many documents as string format with unique index:
    +
    +```sql
    +select * FROM docs;
    +```
    +
    +| docId | doc |
    +|:----: |:----|
    +|   0   | "Alice was beginning to get very tired of sitting by her sister on the bank ..." |
    +|  ...  | ... |
    +
    +First, each document is split into words by tokenize function like a [`tokenize`](../misc/tokenizer.html).
    +
    +```sql
    +drop table docs_words;
    +create table docs_words as
    +  select
    +    docid,
    +    tokenize(doc, true) as words
    +  FROM
    +    docs
    +;
    +```
    +
    +This table shows tokenized document.
    +
    +| docId | doc |
    +|:----: |:----|
    +|   0   | ["alice", "was", "beginning", "to", "get", "very", "tired", "of", "sitting", "by", "her", "sister", "on", "the", "bank", ...] |
    +|  ...  | ... |
    +
    +Then, you count frequency up per word and remove low frequency words from the vocabulary.
    +To remove low frequency words is optional preprocessing, but this process is effective to train word vector fastly.
    +
    +```sql
    +set hivevar:mincount=5;
    +
    +drop table freq;
    +create table freq as
    +select
    +  row_number() over () - 1 as wordid,
    +  word,
    +  freq
    +from (
    +  select
    +    word,
    +    COUNT(*) as freq
    +  from
    +    docs_words
    +  LATERAL VIEW explode(words) lTable as word
    +  group by
    +    word
    +) t
    +where freq >= ${mincount}
    +;
    +```
    +
    +Hivemall's word2vec supports two type words; string and int.
    +String type tends to use huge memory during training.
    +On the other hand, int type tends to use less memory.
    +If you train on small dataset, we recommend using string type,
    +because memory usage can be ignored and HiveQL is more simple.
    +If you train on large dataset, we recommend using int type,
    +because it saves memory during training.
    +
    +# Create sub-sampling table
    +
    +Sub-sampling table is stored a sub-sampling probability per word.
    +
    +The sub-sampling probability of word $$w_i$$ is computed by the following equation:
    +
    +$$
    +\begin{aligned}
    +f(w_i) = \sqrt{\frac{\mathrm{sample}}{freq(w_i)/\sum freq(w)}} + \frac{\mathrm{sample}}{freq(w_i)/\sum freq(w)}
    +\end{aligned}
    +$$
    +
    +During word2vec training,
    +not sub-sampled words are ignored.
    +It works to train fastly and to consider the imbalance the rare words and frequent words by reducing frequent words.
    +The smaller `sample` value set,
    +the fewer words are used during training.
    +
    +```sql
    +set hivevar:sample=1e-4;
    +
    +drop table subsampling_table;
    +create table subsampling_table as
    +with stats as (
    +  select
    +    sum(freq) as numTrainWords
    +  FROM
    +    freq
    +)
    +select
    +  l.wordid,
    +  l.word,
    +  sqrt(${sample}/(l.freq/r.numTrainWords)) + ${sample}/(l.freq/r.numTrainWords) as p
    +from
    +  freq l
    +cross join
    +  stats r
    +;
    +```
    +
    +```sql
    +select * FROM subsampling_table order by p;
    +```
    +
    +| wordid | word | p |
    +|:----: | :----: |:----:|
    +| 48645 | the  | 0.04013665|
    +| 11245 | of   | 0.052463654|
    +| 16368 | and  | 0.06555538|
    +| 61938 | 00   | 0.068162076|
    +| 19977 | in   | 0.071441144|
    +| 83599 | 0    | 0.07528994|
    +| 95017 | a    | 0.07559573|
    +| 1225  | to   | 0.07953133|
    +| 37062 | 0000 | 0.08779001|
    +| 58246 | is   | 0.09049763|
    +|  ...  | ...  |... |
    +
    +The first row shows that 4% of `the` are used in the documents during training.
    +
    +# Delete low frequency words and high frequency words from `docs_words`
    +
    +To reduce useless words from corpus,
    +low frequency words and high frequency words are deleted.
    +And, to avoid loading long document on memory, a  document is split into some sub-documents.
    +
    +```sql
    +set hivevar:maxlength=1500;
    +SET hivevar:seed=31;
    +
    +drop table train_docs;
    +create table train_docs as
    +  with docs_exploded as (
    +    select
    +      docid,
    +      word,
    +      pos % ${maxlength} as pos,
    +      pos div ${maxlength} as splitid,
    +      rand(${seed}) as rnd
    +    from
    +      docs_words LATERAL VIEW posexplode(words) t as pos, word
    +  )
    +select
    +  l.docid,
    +  -- to_ordered_list(l.word, l.pos) as words
    +  to_ordered_list(r2.wordid, l.pos) as words,
    +from
    +  docs_exploded l
    +  LEFT SEMI join freq r on (l.word = r.word)
    +  join subsampling_table r2 on (l.word = r2.word)
    +where
    +  r2.p > l.rnd
    +group by
    +  l.docid, l.splitid
    +;
    +```
    +
    +If you store string word in `train_docs` table,
    +please replace `to_ordered_list(r2.wordid, l.pos) as words` with  `to_ordered_list(l.word, l.pos) as words`.
    +
    +# Create negative sampling table
    +
    +Negative sampling is an approximate function of [softmax function](https://en.wikipedia.org/wiki/Softmax_function).
    +Here, `negative_table` is used to store word sampling probability for negative sampling.
    +`z` is a hyperparameter of noise distribution for negative sampling.
    +During word2vec training,
    --- End diff --
    
    Line break is not needed. Line break after `,` is unreasonable (elsewhere as well).


---

[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec

Posted by myui <gi...@git.apache.org>.
Github user myui commented on a diff in the pull request:

    https://github.com/apache/incubator-hivemall/pull/116#discussion_r141542877
  
    --- Diff: core/src/main/java/hivemall/embedding/AbstractWord2VecModel.java ---
    @@ -0,0 +1,125 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one
    + * or more contributor license agreements.  See the NOTICE file
    + * distributed with this work for additional information
    + * regarding copyright ownership.  The ASF licenses this file
    + * to you under the Apache License, Version 2.0 (the
    + * "License"); you may not use this file except in compliance
    + * with the License.  You may obtain a copy of the License at
    + *
    + *   http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing,
    + * software distributed under the License is distributed on an
    + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    + * KIND, either express or implied.  See the License for the
    + * specific language governing permissions and limitations
    + * under the License.
    + */
    +package hivemall.embedding;
    +
    +import hivemall.math.random.PRNG;
    +import hivemall.math.random.RandomNumberGeneratorFactory;
    +import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
    +
    +import javax.annotation.Nonnegative;
    +import javax.annotation.Nonnull;
    +import java.util.List;
    +
    +public abstract class AbstractWord2VecModel {
    +    // cached sigmoid function parameters
    +    protected static final int MAX_SIGMOID = 6;
    +    protected static final int SIGMOID_TABLE_SIZE = 1000;
    +    protected float[] sigmoidTable;
    +
    +
    --- End diff --
    
    remove unnecessary blank line


---

[GitHub] incubator-hivemall issue #116: [WIP][HIVEMALL-118] word2vec

Posted by coveralls <gi...@git.apache.org>.
Github user coveralls commented on the issue:

    https://github.com/apache/incubator-hivemall/pull/116
  
    
    [![Coverage Status](https://coveralls.io/builds/13440173/badge)](https://coveralls.io/builds/13440173)
    
    Coverage decreased (-0.5%) to 40.392% when pulling **c34003804b5ef33bca6abe9113de6eb01b5e94c6 on nzw0301:skipgram** into **c2b95783cf9d6fc1646a48ac928e96152eab98c6 on apache:master**.



---

[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec

Posted by myui <gi...@git.apache.org>.
Github user myui commented on a diff in the pull request:

    https://github.com/apache/incubator-hivemall/pull/116#discussion_r141546893
  
    --- Diff: core/src/main/java/hivemall/embedding/CBoWModel.java ---
    @@ -0,0 +1,131 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one
    + * or more contributor license agreements.  See the NOTICE file
    + * distributed with this work for additional information
    + * regarding copyright ownership.  The ASF licenses this file
    + * to you under the Apache License, Version 2.0 (the
    + * "License"); you may not use this file except in compliance
    + * with the License.  You may obtain a copy of the License at
    + *
    + *   http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing,
    + * software distributed under the License is distributed on an
    + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    + * KIND, either express or implied.  See the License for the
    + * specific language governing permissions and limitations
    + * under the License.
    + */
    +package hivemall.embedding;
    +
    +import hivemall.math.random.PRNG;
    +import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
    +
    +import javax.annotation.Nonnull;
    +import java.util.List;
    +
    +public final class CBoWModel extends AbstractWord2VecModel {
    +    protected CBoWModel(final int dim, final int win, final int neg, final int iter,
    +            final float startingLR, final long numTrainWords, final Int2FloatOpenHashTable S,
    +            final int[] aliasWordId) {
    +        super(dim, win, neg, iter, startingLR, numTrainWords, S, aliasWordId);
    +    }
    +
    +    protected void trainOnDoc(@Nonnull final int[] doc) {
    +        final int vecDim = dim;
    +        final int numNegative = neg;
    +        final PRNG _rnd = rnd;
    +        final Int2FloatOpenHashTable _S = S;
    +        final int[] _aliasWordId = aliasWordId;
    +        float label, gradient;
    +
    +        // reuse instance
    +        int windowSize, k, numContext, targetWord, inWord, positiveWord;
    +
    +        updateLearningRate();
    +
    +        int docLength = doc.length;
    --- End diff --
    
    `final int docLength`


---

[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec

Posted by myui <gi...@git.apache.org>.
Github user myui commented on a diff in the pull request:

    https://github.com/apache/incubator-hivemall/pull/116#discussion_r141546522
  
    --- Diff: core/src/main/java/hivemall/embedding/AliasTableBuilderUDTF.java ---
    @@ -0,0 +1,203 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one
    + * or more contributor license agreements.  See the NOTICE file
    + * distributed with this work for additional information
    + * regarding copyright ownership.  The ASF licenses this file
    + * to you under the Apache License, Version 2.0 (the
    + * "License"); you may not use this file except in compliance
    + * with the License.  You may obtain a copy of the License at
    + *
    + *   http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing,
    + * software distributed under the License is distributed on an
    + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    + * KIND, either express or implied.  See the License for the
    + * specific language governing permissions and limitations
    + * under the License.
    + */
    +package hivemall.embedding;
    +
    +import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
    +import hivemall.utils.collections.maps.Int2IntOpenHashTable;
    +import hivemall.utils.hadoop.HiveUtils;
    +
    +import java.util.List;
    +import java.util.ArrayList;
    +import java.util.Map;
    +import java.util.Queue;
    +import java.util.ArrayDeque;
    +
    +import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
    +import org.apache.hadoop.hive.ql.metadata.HiveException;
    +import org.apache.hadoop.hive.ql.udf.generic.GenericUDTF;
    +import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
    +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
    +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
    +import org.apache.hadoop.io.FloatWritable;
    +import org.apache.hadoop.io.IntWritable;
    +import org.apache.hadoop.io.Text;
    +
    +import javax.annotation.Nonnull;
    +
    +public final class AliasTableBuilderUDTF extends GenericUDTF {
    --- End diff --
    
    Add Javadoc comment for the class referring papers.
    
    ```
    <pre>
    - A. J. Walker, New Fast Method for Generating Discrete Random Numbers with Arbitrary Frequency Distributions, in Electronics Letters 10, no. 8, pp. 127-128, 1974.
    - A. J. Walker, An Efficient Method for Generating Discrete Random Variables with General Distributions. ACM Transactions on Mathematical Software 3, no. 3, pp. 253-256, 1977.
    </pre>
    ```


---

[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec

Posted by nzw0301 <gi...@git.apache.org>.
Github user nzw0301 commented on a diff in the pull request:

    https://github.com/apache/incubator-hivemall/pull/116#discussion_r141553131
  
    --- Diff: core/src/main/java/hivemall/embedding/Word2VecUDTF.java ---
    @@ -0,0 +1,364 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one
    + * or more contributor license agreements.  See the NOTICE file
    + * distributed with this work for additional information
    + * regarding copyright ownership.  The ASF licenses this file
    + * to you under the Apache License, Version 2.0 (the
    + * "License"); you may not use this file except in compliance
    + * with the License.  You may obtain a copy of the License at
    + *
    + *   http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing,
    + * software distributed under the License is distributed on an
    + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    + * KIND, either express or implied.  See the License for the
    + * specific language governing permissions and limitations
    + * under the License.
    + */
    +package hivemall.embedding;
    +
    +import hivemall.UDTFWithOptions;
    +import hivemall.utils.collections.IMapIterator;
    +import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
    +import hivemall.utils.collections.maps.OpenHashTable;
    +import hivemall.utils.hadoop.HiveUtils;
    +import hivemall.utils.lang.Primitives;
    +
    +import org.apache.commons.cli.CommandLine;
    +import org.apache.commons.cli.Options;
    +import org.apache.hadoop.hive.ql.exec.Description;
    +import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
    +import org.apache.hadoop.hive.ql.metadata.HiveException;
    +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
    +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
    +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
    +import org.apache.hadoop.io.FloatWritable;
    +import org.apache.hadoop.io.IntWritable;
    +import org.apache.hadoop.io.Text;
    +
    +import javax.annotation.Nonnegative;
    +import javax.annotation.Nonnull;
    +import java.util.List;
    +import java.util.Arrays;
    +import java.util.ArrayList;
    +
    +@Description(
    +        name = "train_word2vec",
    +        value = "_FUNC_(array<array<float | string>> negative_table, array<int | string> doc [, const string options]) - Returns a prediction model")
    +public class Word2VecUDTF extends UDTFWithOptions {
    +    protected transient AbstractWord2VecModel model;
    +    @Nonnegative
    +    private float startingLR;
    +    @Nonnegative
    +    private long numTrainWords;
    +    private OpenHashTable<String, Integer> word2index;
    +
    +    @Nonnegative
    +    private int dim;
    +    @Nonnegative
    +    private int win;
    +    @Nonnegative
    +    private int neg;
    +    @Nonnegative
    +    private int iter;
    +    private boolean skipgram;
    +    private boolean isStringInput;
    +
    +    private Int2FloatOpenHashTable S;
    +    private int[] aliasWordIds;
    +
    +    private ListObjectInspector negativeTableOI;
    +    private ListObjectInspector negativeTableElementListOI;
    +    private PrimitiveObjectInspector negativeTableElementOI;
    +
    +    private ListObjectInspector docOI;
    +    private PrimitiveObjectInspector wordOI;
    +
    +    @Override
    +    public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
    +        final int numArgs = argOIs.length;
    +
    +        if (numArgs != 3) {
    +            throw new UDFArgumentException(getClass().getSimpleName()
    +                    + " takes 3 arguments:  [, constant string options]: "
    +                    + Arrays.toString(argOIs));
    +        }
    +
    +        processOptions(argOIs);
    +
    +        this.negativeTableOI = HiveUtils.asListOI(argOIs[0]);
    +        this.negativeTableElementListOI = HiveUtils.asListOI(negativeTableOI.getListElementObjectInspector());
    +        this.docOI = HiveUtils.asListOI(argOIs[1]);
    +
    +        this.isStringInput = HiveUtils.isStringListOI(argOIs[1]);
    +
    +        if (isStringInput) {
    +            this.negativeTableElementOI = HiveUtils.asStringOI(negativeTableElementListOI.getListElementObjectInspector());
    +            this.wordOI = HiveUtils.asStringOI(docOI.getListElementObjectInspector());
    +        } else {
    +            this.negativeTableElementOI = HiveUtils.asFloatingPointOI(negativeTableElementListOI.getListElementObjectInspector());
    +            this.wordOI = HiveUtils.asIntCompatibleOI(docOI.getListElementObjectInspector());
    +        }
    +
    +        List<String> fieldNames = new ArrayList<>();
    +        List<ObjectInspector> fieldOIs = new ArrayList<>();
    +
    +        fieldNames.add("word");
    +        fieldNames.add("i");
    +        fieldNames.add("wi");
    +
    +        if (isStringInput) {
    +            fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
    +        } else {
    +            fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
    +        }
    +
    +        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
    +        fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
    +
    +        this.model = null;
    +        this.word2index = null;
    +        this.S = null;
    +        this.aliasWordIds = null;
    +
    +        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    +    }
    +
    +    @Override
    +    public void process(Object[] args) throws HiveException {
    +        if (model == null) {
    +            parseNegativeTable(args[0]);
    +            this.model = createModel();
    +        }
    +
    +        List<?> rawDoc = docOI.getList(args[1]);
    +
    +        // parse rawDoc
    +        final int docLength = rawDoc.size();
    +        final int[] doc = new int[docLength];
    +        if (isStringInput) {
    +            for (int i = 0; i < docLength; i++) {
    +                doc[i] = getWordId(PrimitiveObjectInspectorUtils.getString(rawDoc.get(i), wordOI));
    +            }
    +        } else {
    +            for (int i = 0; i < docLength; i++) {
    +                doc[i] = PrimitiveObjectInspectorUtils.getInt(rawDoc.get(i), wordOI);
    +            }
    +        }
    +
    +        model.trainOnDoc(doc);
    +    }
    +
    +    @Override
    +    protected Options getOptions() {
    +        Options opts = new Options();
    +        opts.addOption("n", "numTrainWords", true,
    +            "The number of words in the documents. It is used to update learning rate");
    +        opts.addOption("dim", "dimension", true, "The number of vector dimension [default: 100]");
    +        opts.addOption("win", "window", true, "Context window size [default: 5]");
    +        opts.addOption("neg", "negative", true,
    +            "The number of negative sampled words per word [default: 5]");
    +        opts.addOption("iter", "iteration", true, "The number of iterations [default: 5]");
    +        opts.addOption("model", "modelName", true,
    +            "The model name of word2vec: skipgram or cbow [default: skipgram]");
    +        opts.addOption(
    +            "lr",
    --- End diff --
    
    I see.
    Does `longOpt` remain `learningRate` or remove this field?


---

[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec

Posted by nzw0301 <gi...@git.apache.org>.
Github user nzw0301 commented on a diff in the pull request:

    https://github.com/apache/incubator-hivemall/pull/116#discussion_r141551510
  
    --- Diff: core/src/main/java/hivemall/embedding/AbstractWord2VecModel.java ---
    @@ -0,0 +1,125 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one
    + * or more contributor license agreements.  See the NOTICE file
    + * distributed with this work for additional information
    + * regarding copyright ownership.  The ASF licenses this file
    + * to you under the Apache License, Version 2.0 (the
    + * "License"); you may not use this file except in compliance
    + * with the License.  You may obtain a copy of the License at
    + *
    + *   http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing,
    + * software distributed under the License is distributed on an
    + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    + * KIND, either express or implied.  See the License for the
    + * specific language governing permissions and limitations
    + * under the License.
    + */
    +package hivemall.embedding;
    +
    +import hivemall.math.random.PRNG;
    +import hivemall.math.random.RandomNumberGeneratorFactory;
    +import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
    +
    +import javax.annotation.Nonnegative;
    +import javax.annotation.Nonnull;
    +import java.util.List;
    +
    +public abstract class AbstractWord2VecModel {
    +    // cached sigmoid function parameters
    +    protected static final int MAX_SIGMOID = 6;
    +    protected static final int SIGMOID_TABLE_SIZE = 1000;
    +    protected float[] sigmoidTable;
    +
    +
    +    @Nonnegative
    +    protected int dim;
    +    protected int win;
    +    protected int neg;
    +    protected int iter;
    +
    +    // learning rate parameters
    +    @Nonnegative
    +    protected float lr;
    +    @Nonnegative
    +    private float startingLR;
    +    @Nonnegative
    +    private long numTrainWords;
    +    @Nonnegative
    +    protected long wordCount;
    +    @Nonnegative
    +    private long lastWordCount;
    +
    +    protected PRNG rnd;
    +
    +    protected Int2FloatOpenHashTable contextWeights;
    +    protected Int2FloatOpenHashTable inputWeights;
    +    protected Int2FloatOpenHashTable S;
    +    protected int[] aliasWordId;
    +
    +    protected AbstractWord2VecModel(final int dim, final int win, final int neg, final int iter,
    +            final float startingLR, final long numTrainWords, final Int2FloatOpenHashTable S,
    +            final int[] aliasWordId) {
    +        this.win = win;
    +        this.neg = neg;
    +        this.iter = iter;
    +        this.dim = dim;
    +        this.startingLR = this.lr = startingLR;
    +        this.numTrainWords = numTrainWords;
    +
    +        // alias sampler for negative sampling
    +        this.S = S;
    +        this.aliasWordId = aliasWordId;
    +
    +        this.wordCount = 0L;
    +        this.lastWordCount = 0L;
    +        this.rnd = RandomNumberGeneratorFactory.createPRNG(1001);
    +
    +        this.sigmoidTable = initSigmoidTable();
    +
    +        // TODO how to estimate size
    +        this.inputWeights = new Int2FloatOpenHashTable(10578 * dim);
    --- End diff --
    
    There is no reason.


---

[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec

Posted by myui <gi...@git.apache.org>.
Github user myui commented on a diff in the pull request:

    https://github.com/apache/incubator-hivemall/pull/116#discussion_r141543945
  
    --- Diff: core/src/main/java/hivemall/embedding/Word2VecUDTF.java ---
    @@ -0,0 +1,364 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one
    + * or more contributor license agreements.  See the NOTICE file
    + * distributed with this work for additional information
    + * regarding copyright ownership.  The ASF licenses this file
    + * to you under the Apache License, Version 2.0 (the
    + * "License"); you may not use this file except in compliance
    + * with the License.  You may obtain a copy of the License at
    + *
    + *   http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing,
    + * software distributed under the License is distributed on an
    + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    + * KIND, either express or implied.  See the License for the
    + * specific language governing permissions and limitations
    + * under the License.
    + */
    +package hivemall.embedding;
    +
    +import hivemall.UDTFWithOptions;
    +import hivemall.utils.collections.IMapIterator;
    +import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
    +import hivemall.utils.collections.maps.OpenHashTable;
    +import hivemall.utils.hadoop.HiveUtils;
    +import hivemall.utils.lang.Primitives;
    +
    +import org.apache.commons.cli.CommandLine;
    +import org.apache.commons.cli.Options;
    +import org.apache.hadoop.hive.ql.exec.Description;
    +import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
    +import org.apache.hadoop.hive.ql.metadata.HiveException;
    +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
    +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
    +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
    +import org.apache.hadoop.io.FloatWritable;
    +import org.apache.hadoop.io.IntWritable;
    +import org.apache.hadoop.io.Text;
    +
    +import javax.annotation.Nonnegative;
    +import javax.annotation.Nonnull;
    +import java.util.List;
    +import java.util.Arrays;
    +import java.util.ArrayList;
    +
    +@Description(
    +        name = "train_word2vec",
    +        value = "_FUNC_(array<array<float | string>> negative_table, array<int | string> doc [, const string options]) - Returns a prediction model")
    +public class Word2VecUDTF extends UDTFWithOptions {
    +    protected transient AbstractWord2VecModel model;
    +    @Nonnegative
    +    private float startingLR;
    +    @Nonnegative
    +    private long numTrainWords;
    +    private OpenHashTable<String, Integer> word2index;
    +
    +    @Nonnegative
    +    private int dim;
    +    @Nonnegative
    +    private int win;
    +    @Nonnegative
    +    private int neg;
    +    @Nonnegative
    +    private int iter;
    +    private boolean skipgram;
    +    private boolean isStringInput;
    +
    +    private Int2FloatOpenHashTable S;
    +    private int[] aliasWordIds;
    +
    +    private ListObjectInspector negativeTableOI;
    +    private ListObjectInspector negativeTableElementListOI;
    +    private PrimitiveObjectInspector negativeTableElementOI;
    +
    +    private ListObjectInspector docOI;
    +    private PrimitiveObjectInspector wordOI;
    +
    +    @Override
    +    public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
    +        final int numArgs = argOIs.length;
    +
    +        if (numArgs != 3) {
    +            throw new UDFArgumentException(getClass().getSimpleName()
    +                    + " takes 3 arguments:  [, constant string options]: "
    +                    + Arrays.toString(argOIs));
    +        }
    +
    +        processOptions(argOIs);
    +
    +        this.negativeTableOI = HiveUtils.asListOI(argOIs[0]);
    +        this.negativeTableElementListOI = HiveUtils.asListOI(negativeTableOI.getListElementObjectInspector());
    +        this.docOI = HiveUtils.asListOI(argOIs[1]);
    +
    +        this.isStringInput = HiveUtils.isStringListOI(argOIs[1]);
    +
    +        if (isStringInput) {
    +            this.negativeTableElementOI = HiveUtils.asStringOI(negativeTableElementListOI.getListElementObjectInspector());
    +            this.wordOI = HiveUtils.asStringOI(docOI.getListElementObjectInspector());
    +        } else {
    +            this.negativeTableElementOI = HiveUtils.asFloatingPointOI(negativeTableElementListOI.getListElementObjectInspector());
    +            this.wordOI = HiveUtils.asIntCompatibleOI(docOI.getListElementObjectInspector());
    +        }
    +
    +        List<String> fieldNames = new ArrayList<>();
    +        List<ObjectInspector> fieldOIs = new ArrayList<>();
    +
    +        fieldNames.add("word");
    +        fieldNames.add("i");
    +        fieldNames.add("wi");
    +
    +        if (isStringInput) {
    +            fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
    +        } else {
    +            fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
    +        }
    +
    +        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
    +        fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
    +
    +        this.model = null;
    +        this.word2index = null;
    +        this.S = null;
    +        this.aliasWordIds = null;
    +
    +        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    +    }
    +
    +    @Override
    +    public void process(Object[] args) throws HiveException {
    +        if (model == null) {
    +            parseNegativeTable(args[0]);
    +            this.model = createModel();
    +        }
    +
    +        List<?> rawDoc = docOI.getList(args[1]);
    +
    +        // parse rawDoc
    +        final int docLength = rawDoc.size();
    +        final int[] doc = new int[docLength];
    +        if (isStringInput) {
    +            for (int i = 0; i < docLength; i++) {
    +                doc[i] = getWordId(PrimitiveObjectInspectorUtils.getString(rawDoc.get(i), wordOI));
    --- End diff --
    
    `PrimitiveObjectInspectorUtils.getString(rawDoc.get(i), wordOI)` may return null.


---

[GitHub] incubator-hivemall issue #116: [WIP][HIVEMALL-118] word2vec

Posted by coveralls <gi...@git.apache.org>.
Github user coveralls commented on the issue:

    https://github.com/apache/incubator-hivemall/pull/116
  
    
    [![Coverage Status](https://coveralls.io/builds/13370459/badge)](https://coveralls.io/builds/13370459)
    
    Coverage decreased (-0.8%) to 40.165% when pulling **83198617bcf82634d39d715e33499f99945f2ebb on nzw0301:skipgram** into **c2b95783cf9d6fc1646a48ac928e96152eab98c6 on apache:master**.



---

[GitHub] incubator-hivemall issue #116: [WIP][HIVEMALL-118] word2vec

Posted by coveralls <gi...@git.apache.org>.
Github user coveralls commented on the issue:

    https://github.com/apache/incubator-hivemall/pull/116
  
    
    [![Coverage Status](https://coveralls.io/builds/13386706/badge)](https://coveralls.io/builds/13386706)
    
    Coverage decreased (-0.8%) to 40.14% when pulling **c7cba82a2eef2b5b32e67221a20c6cdb4570643a on nzw0301:skipgram** into **c2b95783cf9d6fc1646a48ac928e96152eab98c6 on apache:master**.



---

[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec

Posted by myui <gi...@git.apache.org>.
Github user myui commented on a diff in the pull request:

    https://github.com/apache/incubator-hivemall/pull/116#discussion_r141543095
  
    --- Diff: core/src/main/java/hivemall/embedding/AbstractWord2VecModel.java ---
    @@ -0,0 +1,125 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one
    + * or more contributor license agreements.  See the NOTICE file
    + * distributed with this work for additional information
    + * regarding copyright ownership.  The ASF licenses this file
    + * to you under the Apache License, Version 2.0 (the
    + * "License"); you may not use this file except in compliance
    + * with the License.  You may obtain a copy of the License at
    + *
    + *   http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing,
    + * software distributed under the License is distributed on an
    + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    + * KIND, either express or implied.  See the License for the
    + * specific language governing permissions and limitations
    + * under the License.
    + */
    +package hivemall.embedding;
    +
    +import hivemall.math.random.PRNG;
    +import hivemall.math.random.RandomNumberGeneratorFactory;
    +import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
    +
    +import javax.annotation.Nonnegative;
    +import javax.annotation.Nonnull;
    +import java.util.List;
    +
    +public abstract class AbstractWord2VecModel {
    +    // cached sigmoid function parameters
    +    protected static final int MAX_SIGMOID = 6;
    +    protected static final int SIGMOID_TABLE_SIZE = 1000;
    +    protected float[] sigmoidTable;
    +
    +
    +    @Nonnegative
    +    protected int dim;
    +    protected int win;
    +    protected int neg;
    +    protected int iter;
    +
    +    // learning rate parameters
    +    @Nonnegative
    +    protected float lr;
    +    @Nonnegative
    +    private float startingLR;
    +    @Nonnegative
    +    private long numTrainWords;
    +    @Nonnegative
    +    protected long wordCount;
    +    @Nonnegative
    +    private long lastWordCount;
    +
    +    protected PRNG rnd;
    +
    +    protected Int2FloatOpenHashTable contextWeights;
    +    protected Int2FloatOpenHashTable inputWeights;
    +    protected Int2FloatOpenHashTable S;
    +    protected int[] aliasWordId;
    +
    +    protected AbstractWord2VecModel(final int dim, final int win, final int neg, final int iter,
    --- End diff --
    
    add `@Nonnegative` for each constructor argument and caller methods.


---

[GitHub] incubator-hivemall issue #116: [WIP][HIVEMALL-118] word2vec

Posted by coveralls <gi...@git.apache.org>.
Github user coveralls commented on the issue:

    https://github.com/apache/incubator-hivemall/pull/116
  
    
    [![Coverage Status](https://coveralls.io/builds/13412359/badge)](https://coveralls.io/builds/13412359)
    
    Coverage decreased (-0.6%) to 40.343% when pulling **aede5ec0cf4d01780034c0d46c456486fecc1cb3 on nzw0301:skipgram** into **c2b95783cf9d6fc1646a48ac928e96152eab98c6 on apache:master**.



---

[GitHub] incubator-hivemall issue #116: [WIP][HIVEMALL-118] word2vec

Posted by coveralls <gi...@git.apache.org>.
Github user coveralls commented on the issue:

    https://github.com/apache/incubator-hivemall/pull/116
  
    
    [![Coverage Status](https://coveralls.io/builds/13415337/badge)](https://coveralls.io/builds/13415337)
    
    Coverage decreased (-0.6%) to 40.31% when pulling **4abdb8f3d2632baa9b3ead928bc8bb0283027e20 on nzw0301:skipgram** into **c2b95783cf9d6fc1646a48ac928e96152eab98c6 on apache:master**.



---

[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec

Posted by myui <gi...@git.apache.org>.
Github user myui commented on a diff in the pull request:

    https://github.com/apache/incubator-hivemall/pull/116#discussion_r141545135
  
    --- Diff: docs/gitbook/embedding/word2vec.md ---
    @@ -0,0 +1,399 @@
    +<!--
    +  Licensed to the Apache Software Foundation (ASF) under one
    +  or more contributor license agreements.  See the NOTICE file
    +  distributed with this work for additional information
    +  regarding copyright ownership.  The ASF licenses this file
    +  to you under the Apache License, Version 2.0 (the
    +  "License"); you may not use this file except in compliance
    +  with the License.  You may obtain a copy of the License at
    +
    +    http://www.apache.org/licenses/LICENSE-2.0
    +
    +  Unless required by applicable law or agreed to in writing,
    +  software distributed under the License is distributed on an
    +  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    +  KIND, either express or implied.  See the License for the
    +  specific language governing permissions and limitations
    +  under the License.
    +-->
    +
    +Word Embedding is a powerful tool for many tasks,
    +e.g. finding similar words,
    +feature vectors for supervised machine learning task and word analogy,
    +such as `king - man + woman =~ queen`.
    +In word embedding,
    +each word represents a low dimension and dense vector.
    +**Skip-Gram** and **Continuous Bag-of-words** (CBoW) are the most popular algorithms to obtain good word embeddings (a.k.a word2vec).
    +
    +The papers introduce the method are as follows:
    +
    +- T. Mikolov, et al., [Distributed Representations of Words and Phrases and Their Compositionality
    +](http://papers.nips.cc/paper/5021-distributed-representations-of-words-and-phrases-and-their-compositionality.pdf). NIPS, 2013.
    +- T. Mikolov, et al., [Efficient Estimation of Word Representations in Vector Space](https://arxiv.org/abs/1301.3781). ICLR, 2013.
    +
    +Hivemall provides two type algorithms: Skip-gram and CBoW with negative sampling.
    +Hivemall enables you to train your sequence data such as,
    +but not limited to, documents based on word2vec.
    +This article gives usage instructions of the feature.
    +
    +<!-- toc -->
    +
    +> #### Note
    +> This feature is supported from Hivemall v0.5-rc.? or later.
    +
    +# Prepare document data
    +
    +Assume that you already have `docs` table which contains many documents as string format with unique index:
    +
    +```sql
    +select * FROM docs;
    +```
    +
    +| docId | doc |
    +|:----: |:----|
    +|   0   | "Alice was beginning to get very tired of sitting by her sister on the bank ..." |
    +|  ...  | ... |
    +
    +First, each document is split into words by tokenize function like a [`tokenize`](../misc/tokenizer.html).
    +
    +```sql
    +drop table docs_words;
    +create table docs_words as
    +  select
    +    docid,
    +    tokenize(doc, true) as words
    +  FROM
    +    docs
    +;
    +```
    +
    +This table shows tokenized document.
    +
    +| docId | doc |
    +|:----: |:----|
    +|   0   | ["alice", "was", "beginning", "to", "get", "very", "tired", "of", "sitting", "by", "her", "sister", "on", "the", "bank", ...] |
    +|  ...  | ... |
    +
    +Then, you count frequency up per word and remove low frequency words from the vocabulary.
    +To remove low frequency words is optional preprocessing, but this process is effective to train word vector fastly.
    +
    +```sql
    +set hivevar:mincount=5;
    +
    +drop table freq;
    +create table freq as
    +select
    +  row_number() over () - 1 as wordid,
    +  word,
    +  freq
    +from (
    +  select
    +    word,
    +    COUNT(*) as freq
    +  from
    +    docs_words
    +  LATERAL VIEW explode(words) lTable as word
    +  group by
    +    word
    +) t
    +where freq >= ${mincount}
    +;
    +```
    +
    +Hivemall's word2vec supports two type words; string and int.
    +String type tends to use huge memory during training.
    +On the other hand, int type tends to use less memory.
    +If you train on small dataset, we recommend using string type,
    +because memory usage can be ignored and HiveQL is more simple.
    +If you train on large dataset, we recommend using int type,
    +because it saves memory during training.
    +
    +# Create sub-sampling table
    +
    +Sub-sampling table is stored a sub-sampling probability per word.
    +
    +The sub-sampling probability of word $$w_i$$ is computed by the following equation:
    +
    +$$
    +\begin{aligned}
    +f(w_i) = \sqrt{\frac{\mathrm{sample}}{freq(w_i)/\sum freq(w)}} + \frac{\mathrm{sample}}{freq(w_i)/\sum freq(w)}
    +\end{aligned}
    +$$
    +
    +During word2vec training,
    +not sub-sampled words are ignored.
    +It works to train fastly and to consider the imbalance the rare words and frequent words by reducing frequent words.
    +The smaller `sample` value set,
    +the fewer words are used during training.
    +
    +```sql
    +set hivevar:sample=1e-4;
    +
    +drop table subsampling_table;
    +create table subsampling_table as
    +with stats as (
    +  select
    +    sum(freq) as numTrainWords
    +  FROM
    +    freq
    +)
    +select
    +  l.wordid,
    +  l.word,
    +  sqrt(${sample}/(l.freq/r.numTrainWords)) + ${sample}/(l.freq/r.numTrainWords) as p
    +from
    +  freq l
    +cross join
    +  stats r
    +;
    +```
    +
    +```sql
    +select * FROM subsampling_table order by p;
    +```
    +
    +| wordid | word | p |
    +|:----: | :----: |:----:|
    +| 48645 | the  | 0.04013665|
    +| 11245 | of   | 0.052463654|
    +| 16368 | and  | 0.06555538|
    +| 61938 | 00   | 0.068162076|
    +| 19977 | in   | 0.071441144|
    +| 83599 | 0    | 0.07528994|
    +| 95017 | a    | 0.07559573|
    +| 1225  | to   | 0.07953133|
    +| 37062 | 0000 | 0.08779001|
    +| 58246 | is   | 0.09049763|
    +|  ...  | ...  |... |
    +
    +The first row shows that 4% of `the` are used in the documents during training.
    +
    +# Delete low frequency words and high frequency words from `docs_words`
    +
    +To reduce useless words from corpus,
    +low frequency words and high frequency words are deleted.
    +And, to avoid loading long document on memory, a  document is split into some sub-documents.
    +
    +```sql
    +set hivevar:maxlength=1500;
    +SET hivevar:seed=31;
    +
    +drop table train_docs;
    +create table train_docs as
    +  with docs_exploded as (
    +    select
    +      docid,
    +      word,
    +      pos % ${maxlength} as pos,
    +      pos div ${maxlength} as splitid,
    +      rand(${seed}) as rnd
    +    from
    +      docs_words LATERAL VIEW posexplode(words) t as pos, word
    +  )
    +select
    +  l.docid,
    +  -- to_ordered_list(l.word, l.pos) as words
    +  to_ordered_list(r2.wordid, l.pos) as words,
    +from
    +  docs_exploded l
    +  LEFT SEMI join freq r on (l.word = r.word)
    +  join subsampling_table r2 on (l.word = r2.word)
    +where
    +  r2.p > l.rnd
    +group by
    +  l.docid, l.splitid
    +;
    +```
    +
    +If you store string word in `train_docs` table,
    +please replace `to_ordered_list(r2.wordid, l.pos) as words` with  `to_ordered_list(l.word, l.pos) as words`.
    +
    +# Create negative sampling table
    +
    +Negative sampling is an approximate function of [softmax function](https://en.wikipedia.org/wiki/Softmax_function).
    +Here, `negative_table` is used to store word sampling probability for negative sampling.
    +`z` is a hyperparameter of noise distribution for negative sampling.
    +During word2vec training,
    +words sampled from this distribution are used for negative examples.
    +Noise distribution is the unigram distribution raised to the 3/4rd power.
    +
    +$$
    +\begin{aligned}
    +p(w_i) = \frac{freq(w_i)^{\mathrm{z}}}{\sum freq(w)^{\mathrm{z}}}
    +\end{aligned}
    +$$
    +
    +To avoid using huge memory space for negative sampling like original implementation and remain to sample fastly from this distribution,
    +Hivemall uses [Alias method](https://en.wikipedia.org/wiki/Alias_method).
    +
    +This method has proposed in papers below:
    +
    +- A. J. Walker, New Fast Method for Generating Discrete Random Numbers with Arbitrary Frequency Distributions, in Electronics Letters 10, no. 8, pp. 127-128, 1974.
    +- A. J. Walker, An Efficient Method for Generating Discrete Random Variables with General Distributions. ACM Transactions on Mathematical Software 3, no. 3, pp. 253-256, 1977.
    +
    +```sql
    +set hivevar:z=3/4;
    +
    +drop table negative_table;
    +create table negative_table as
    +select
    +  collect_list(array(word, p, other)) as negative_table
    +from (
    +  select
    +    alias_table(to_map(word, negative)) as (word, p, other)
    +  from
    +    (
    +      select
    +        word,
    +        -- wordid as word,
    +        pow(freq, ${z}) as negative
    +      from
    +        freq
    +    ) t
    +) t1
    +;
    +```
    +
    +`alias_table` function returns the records like following.
    +
    +| word | p | other |
    +|:----: | :----: |:----:|
    +| leopold | 0.6556492 | 0000 |
    +| slep | 0.09060383 | leopold |
    +| valentinian | 0.76077825 | belarusian |
    +| slew | 0.90569097 | colin |
    +| lucien | 0.86329675 | overland |
    +| equitable | 0.7270946 | farms |
    +| insurers | 0.2367955 | israel |
    +| lucier | 0.14855136 | supplements |
    +| lieve | 0.12075222 | separatist |
    +| skyhawks | 0.14079945 | steamed |
    +| ... | ... | ... |
    +
    +To sample negative word from this `negative_table`,
    +
    +1. Sample record int index `i` from $$[0 \ldots \mathrm{num\_alias\_table\_records}]$$.
    +2. Sample float value `r` from $$[0.0 \ldots 1.0]$$ .
    +3. If `r` < `p` of `i` th record, return `word` `i` th record, else return `other` of `i` th record.
    +
    +Here, to use it in training function of word2vec, 
    +`alias_table`'s return records are stored into one list in the `negative_table`.
    +
    +# Train word2vec
    +
    +Hivemall provides `train_word2vec` function to train word vector by word2vec algorithms.
    +The default model is `"skipgram"`.
    +
    +> #### Note
    +> You must pass `n` argumet to the number of words in training documents: `select sum(size(words)) from train_docs;`.
    +
    +## Train Skip-Gram
    +
    +In skip-gram model,
    +word vectors are trained to predict the nearby words.
    +For example, given a sentence like a `"alice", "was", "beginning", "to"`,
    +`"was"` vector is learnt to predict `"alice"` ,`"beginning"` and `"to"`.
    +
    +```sql
    +select sum(size(words)) from train_docs;
    +set hivevar:n=418953; -- previous query return value
    +
    +drop table skipgram;
    +create table skipgram as
    +select
    +  train_word2vec(
    +    r.negative_table,
    +    l.words,
    +    "-n ${n} -win 5 -neg 15 -iter 5 -dim 100 -model skipgram"
    +  )
    +from
    +  train_docs l
    +  cross join negative_table r
    +;
    +```
    +
    +When word is treated as int istead of string,
    +you may need to transform wordid of int to word of string by `join` statement.
    +
    +```sql
    +drop table skipgram;
    +
    +create table skipgram as
    +select
    +  r.word, t.i, t.wi
    +from (
    +  select
    +    train_word2vec(
    +      r.negative_table,
    +      l.wordsint,
    +      "-n 418953 -win 5 -neg 15 -iter 5"
    --- End diff --
    
    `-n 418953` should be `-n ${n}`


---

[GitHub] incubator-hivemall issue #116: [WIP][HIVEMALL-118] word2vec

Posted by coveralls <gi...@git.apache.org>.
Github user coveralls commented on the issue:

    https://github.com/apache/incubator-hivemall/pull/116
  
    
    [![Coverage Status](https://coveralls.io/builds/13429055/badge)](https://coveralls.io/builds/13429055)
    
    Coverage decreased (-0.6%) to 40.372% when pulling **8a42adf3687f8c823f255fcbd0c7ff7962f81f0b on nzw0301:skipgram** into **c2b95783cf9d6fc1646a48ac928e96152eab98c6 on apache:master**.



---

[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec

Posted by myui <gi...@git.apache.org>.
Github user myui commented on a diff in the pull request:

    https://github.com/apache/incubator-hivemall/pull/116#discussion_r141543757
  
    --- Diff: core/src/main/java/hivemall/embedding/CBoWModel.java ---
    @@ -0,0 +1,131 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one
    + * or more contributor license agreements.  See the NOTICE file
    + * distributed with this work for additional information
    + * regarding copyright ownership.  The ASF licenses this file
    + * to you under the Apache License, Version 2.0 (the
    + * "License"); you may not use this file except in compliance
    + * with the License.  You may obtain a copy of the License at
    + *
    + *   http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing,
    + * software distributed under the License is distributed on an
    + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    + * KIND, either express or implied.  See the License for the
    + * specific language governing permissions and limitations
    + * under the License.
    + */
    +package hivemall.embedding;
    +
    +import hivemall.math.random.PRNG;
    +import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
    +
    +import javax.annotation.Nonnull;
    +import java.util.List;
    +
    +public final class CBoWModel extends AbstractWord2VecModel {
    +    protected CBoWModel(final int dim, final int win, final int neg, final int iter,
    +            final float startingLR, final long numTrainWords, final Int2FloatOpenHashTable S,
    +            final int[] aliasWordId) {
    +        super(dim, win, neg, iter, startingLR, numTrainWords, S, aliasWordId);
    +    }
    +
    +    protected void trainOnDoc(@Nonnull final int[] doc) {
    +        final int vecDim = dim;
    +        final int numNegative = neg;
    +        final PRNG _rnd = rnd;
    +        final Int2FloatOpenHashTable _S = S;
    +        final int[] _aliasWordId = aliasWordId;
    +        float label, gradient;
    +
    +        // reuse instance
    +        int windowSize, k, numContext, targetWord, inWord, positiveWord;
    +
    +        updateLearningRate();
    +
    +        int docLength = doc.length;
    +        for (int t = 0; t < iter; t++) {
    +            for (int positiveWordPosition = 0; positiveWordPosition < docLength; positiveWordPosition++) {
    +                windowSize = _rnd.nextInt(win) + 1;
    +
    +                numContext = windowSize * 2 + Math.min(0, positiveWordPosition - windowSize)
    +                        + Math.min(0, docLength - positiveWordPosition - windowSize - 1);
    +
    +                float[] gradVec = new float[vecDim];
    +                float[] averageVec = new float[vecDim];
    +
    +                // collect context words
    +                for (int contextPosition = positiveWordPosition - windowSize; contextPosition < positiveWordPosition
    +                        + windowSize + 1; contextPosition++) {
    +                    if (contextPosition == positiveWordPosition || contextPosition < 0
    +                            || contextPosition >= docLength) {
    +                        continue;
    +                    }
    +
    +                    inWord = doc[contextPosition];
    +
    +                    // average vector of input word vectors
    +                    if (!inputWeights.containsKey(inWord * vecDim)) {
    +                        initWordWeights(inWord);
    +                    }
    +
    +                    for (int i = 0; i < vecDim; i++) {
    +                        averageVec[i] += inputWeights.get(inWord * vecDim + i) / numContext;
    +                    }
    +                }
    +                positiveWord = doc[positiveWordPosition];
    +                // negative sampling
    +                for (int d = 0; d < numNegative + 1; d++) {
    +                    if (d == 0) {
    +                        targetWord = positiveWord;
    +                        label = 1.f;
    +                    } else {
    +                        do {
    +                            k = _rnd.nextInt(_S.size());
    +                            if (_S.get(k) > _rnd.nextDouble()) {
    +                                targetWord = k;
    +                            } else {
    +                                targetWord = _aliasWordId[k];
    +                            }
    +                        } while (targetWord == positiveWord);
    +                        label = 0.f;
    +                    }
    +
    +                    gradient = grad(label, averageVec, targetWord) * lr;
    +                    for (int i = 0; i < vecDim; i++) {
    +                        gradVec[i] += gradient * contextWeights.get(targetWord * vecDim + i);
    +                        contextWeights.put(targetWord * vecDim + i,
    +                            contextWeights.get(targetWord * vecDim + i) + gradient * averageVec[i]);
    +                    }
    +                }
    +
    +                // update inWord vector
    +                for (int contextPosition = positiveWordPosition - windowSize; contextPosition < positiveWordPosition
    +                        + windowSize + 1; contextPosition++) {
    +                    if (contextPosition == positiveWordPosition || contextPosition < 0
    +                            || contextPosition >= docLength) {
    +                        continue;
    +                    }
    +
    +                    inWord = doc[contextPosition];
    +                    for (int i = 0; i < vecDim; i++) {
    +                        inputWeights.put(inWord * vecDim + i, inputWeights.get(inWord * vecDim + i)
    +                                + gradVec[i]);
    +                    }
    +                }
    +            }
    +        }
    +
    +        wordCount += docLength * iter;
    +    }
    +
    +    private float grad(final float label, @Nonnull final float[] w, final int c) {
    +        float dotValue = 0.f;
    +        for (int i = 0; i < dim; i++) {
    +            dotValue += w[i] * contextWeights.get(c * dim + i);
    +        }
    +
    +        return (label - sigmoid(dotValue, sigmoidTable));
    --- End diff --
    
    remove redundant outermost `(`.


---

[GitHub] incubator-hivemall issue #116: [WIP][HIVEMALL-118] word2vec

Posted by coveralls <gi...@git.apache.org>.
Github user coveralls commented on the issue:

    https://github.com/apache/incubator-hivemall/pull/116
  
    
    [![Coverage Status](https://coveralls.io/builds/13472784/badge)](https://coveralls.io/builds/13472784)
    
    Coverage decreased (-0.6%) to 40.508% when pulling **0b163fade6f2d26ce918211c94a78c9a3b648cbe on nzw0301:skipgram** into **1e42387576fabbb326d451f4a00ac22d57828711 on apache:master**.



---

[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec

Posted by myui <gi...@git.apache.org>.
Github user myui commented on a diff in the pull request:

    https://github.com/apache/incubator-hivemall/pull/116#discussion_r141550040
  
    --- Diff: core/src/main/java/hivemall/embedding/AbstractWord2VecModel.java ---
    @@ -0,0 +1,125 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one
    + * or more contributor license agreements.  See the NOTICE file
    + * distributed with this work for additional information
    + * regarding copyright ownership.  The ASF licenses this file
    + * to you under the Apache License, Version 2.0 (the
    + * "License"); you may not use this file except in compliance
    + * with the License.  You may obtain a copy of the License at
    + *
    + *   http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing,
    + * software distributed under the License is distributed on an
    + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    + * KIND, either express or implied.  See the License for the
    + * specific language governing permissions and limitations
    + * under the License.
    + */
    +package hivemall.embedding;
    +
    +import hivemall.math.random.PRNG;
    +import hivemall.math.random.RandomNumberGeneratorFactory;
    +import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
    +
    +import javax.annotation.Nonnegative;
    +import javax.annotation.Nonnull;
    +import java.util.List;
    +
    +public abstract class AbstractWord2VecModel {
    +    // cached sigmoid function parameters
    +    protected static final int MAX_SIGMOID = 6;
    +    protected static final int SIGMOID_TABLE_SIZE = 1000;
    +    protected float[] sigmoidTable;
    +
    +
    +    @Nonnegative
    +    protected int dim;
    +    protected int win;
    +    protected int neg;
    +    protected int iter;
    +
    +    // learning rate parameters
    +    @Nonnegative
    +    protected float lr;
    +    @Nonnegative
    +    private float startingLR;
    +    @Nonnegative
    +    private long numTrainWords;
    +    @Nonnegative
    +    protected long wordCount;
    +    @Nonnegative
    +    private long lastWordCount;
    +
    +    protected PRNG rnd;
    +
    +    protected Int2FloatOpenHashTable contextWeights;
    +    protected Int2FloatOpenHashTable inputWeights;
    +    protected Int2FloatOpenHashTable S;
    +    protected int[] aliasWordId;
    +
    +    protected AbstractWord2VecModel(final int dim, final int win, final int neg, final int iter,
    +            final float startingLR, final long numTrainWords, final Int2FloatOpenHashTable S,
    +            final int[] aliasWordId) {
    +        this.win = win;
    +        this.neg = neg;
    +        this.iter = iter;
    +        this.dim = dim;
    +        this.startingLR = this.lr = startingLR;
    +        this.numTrainWords = numTrainWords;
    +
    +        // alias sampler for negative sampling
    +        this.S = S;
    +        this.aliasWordId = aliasWordId;
    +
    +        this.wordCount = 0L;
    +        this.lastWordCount = 0L;
    +        this.rnd = RandomNumberGeneratorFactory.createPRNG(1001);
    +
    +        this.sigmoidTable = initSigmoidTable();
    +
    +        // TODO how to estimate size
    +        this.inputWeights = new Int2FloatOpenHashTable(10578 * dim);
    --- End diff --
    
    What's `10578`?


---

[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec

Posted by myui <gi...@git.apache.org>.
Github user myui commented on a diff in the pull request:

    https://github.com/apache/incubator-hivemall/pull/116#discussion_r140413586
  
    --- Diff: core/src/main/java/hivemall/unsupervised/AbstractWord2vecModel.java ---
    @@ -0,0 +1,125 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one
    + * or more contributor license agreements.  See the NOTICE file
    + * distributed with this work for additional information
    + * regarding copyright ownership.  The ASF licenses this file
    + * to you under the Apache License, Version 2.0 (the
    + * "License"); you may not use this file except in compliance
    + * with the License.  You may obtain a copy of the License at
    + *
    + *   http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing,
    + * software distributed under the License is distributed on an
    + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    + * KIND, either express or implied.  See the License for the
    + * specific language governing permissions and limitations
    + * under the License.
    + */
    +package hivemall.unsupervised;
    +
    +import hivemall.math.random.PRNG;
    +import hivemall.math.random.RandomNumberGeneratorFactory;
    +import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
    +
    +import javax.annotation.Nonnegative;
    +import javax.annotation.Nonnull;
    +
    +public abstract class AbstractWord2vecModel {
    --- End diff --
    
    Please rename `Word2vec` to `Word2Vec` as seen in [spark](https://github.com/apache/spark/blob/master/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala).


---

[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec

Posted by myui <gi...@git.apache.org>.
Github user myui commented on a diff in the pull request:

    https://github.com/apache/incubator-hivemall/pull/116#discussion_r141545257
  
    --- Diff: docs/gitbook/embedding/word2vec.md ---
    @@ -0,0 +1,399 @@
    +<!--
    +  Licensed to the Apache Software Foundation (ASF) under one
    +  or more contributor license agreements.  See the NOTICE file
    +  distributed with this work for additional information
    +  regarding copyright ownership.  The ASF licenses this file
    +  to you under the Apache License, Version 2.0 (the
    +  "License"); you may not use this file except in compliance
    +  with the License.  You may obtain a copy of the License at
    +
    +    http://www.apache.org/licenses/LICENSE-2.0
    +
    +  Unless required by applicable law or agreed to in writing,
    +  software distributed under the License is distributed on an
    +  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    +  KIND, either express or implied.  See the License for the
    +  specific language governing permissions and limitations
    +  under the License.
    +-->
    +
    +Word Embedding is a powerful tool for many tasks,
    +e.g. finding similar words,
    +feature vectors for supervised machine learning task and word analogy,
    +such as `king - man + woman =~ queen`.
    +In word embedding,
    +each word represents a low dimension and dense vector.
    +**Skip-Gram** and **Continuous Bag-of-words** (CBoW) are the most popular algorithms to obtain good word embeddings (a.k.a word2vec).
    +
    +The papers introduce the method are as follows:
    +
    +- T. Mikolov, et al., [Distributed Representations of Words and Phrases and Their Compositionality
    +](http://papers.nips.cc/paper/5021-distributed-representations-of-words-and-phrases-and-their-compositionality.pdf). NIPS, 2013.
    +- T. Mikolov, et al., [Efficient Estimation of Word Representations in Vector Space](https://arxiv.org/abs/1301.3781). ICLR, 2013.
    +
    +Hivemall provides two type algorithms: Skip-gram and CBoW with negative sampling.
    +Hivemall enables you to train your sequence data such as,
    +but not limited to, documents based on word2vec.
    +This article gives usage instructions of the feature.
    +
    +<!-- toc -->
    +
    +> #### Note
    +> This feature is supported from Hivemall v0.5-rc.? or later.
    +
    +# Prepare document data
    +
    +Assume that you already have `docs` table which contains many documents as string format with unique index:
    +
    +```sql
    +select * FROM docs;
    +```
    +
    +| docId | doc |
    +|:----: |:----|
    +|   0   | "Alice was beginning to get very tired of sitting by her sister on the bank ..." |
    +|  ...  | ... |
    +
    +First, each document is split into words by tokenize function like a [`tokenize`](../misc/tokenizer.html).
    +
    +```sql
    +drop table docs_words;
    +create table docs_words as
    +  select
    +    docid,
    +    tokenize(doc, true) as words
    +  FROM
    +    docs
    +;
    +```
    +
    +This table shows tokenized document.
    +
    +| docId | doc |
    +|:----: |:----|
    +|   0   | ["alice", "was", "beginning", "to", "get", "very", "tired", "of", "sitting", "by", "her", "sister", "on", "the", "bank", ...] |
    +|  ...  | ... |
    +
    +Then, you count frequency up per word and remove low frequency words from the vocabulary.
    +To remove low frequency words is optional preprocessing, but this process is effective to train word vector fastly.
    +
    +```sql
    +set hivevar:mincount=5;
    +
    +drop table freq;
    +create table freq as
    +select
    +  row_number() over () - 1 as wordid,
    +  word,
    +  freq
    +from (
    +  select
    +    word,
    +    COUNT(*) as freq
    +  from
    +    docs_words
    +  LATERAL VIEW explode(words) lTable as word
    +  group by
    +    word
    +) t
    +where freq >= ${mincount}
    +;
    +```
    +
    +Hivemall's word2vec supports two type words; string and int.
    +String type tends to use huge memory during training.
    +On the other hand, int type tends to use less memory.
    +If you train on small dataset, we recommend using string type,
    +because memory usage can be ignored and HiveQL is more simple.
    +If you train on large dataset, we recommend using int type,
    +because it saves memory during training.
    +
    +# Create sub-sampling table
    +
    +Sub-sampling table is stored a sub-sampling probability per word.
    +
    +The sub-sampling probability of word $$w_i$$ is computed by the following equation:
    +
    +$$
    +\begin{aligned}
    +f(w_i) = \sqrt{\frac{\mathrm{sample}}{freq(w_i)/\sum freq(w)}} + \frac{\mathrm{sample}}{freq(w_i)/\sum freq(w)}
    +\end{aligned}
    +$$
    +
    +During word2vec training,
    +not sub-sampled words are ignored.
    +It works to train fastly and to consider the imbalance the rare words and frequent words by reducing frequent words.
    +The smaller `sample` value set,
    +the fewer words are used during training.
    +
    +```sql
    +set hivevar:sample=1e-4;
    +
    +drop table subsampling_table;
    +create table subsampling_table as
    +with stats as (
    +  select
    +    sum(freq) as numTrainWords
    +  FROM
    +    freq
    +)
    +select
    +  l.wordid,
    +  l.word,
    +  sqrt(${sample}/(l.freq/r.numTrainWords)) + ${sample}/(l.freq/r.numTrainWords) as p
    +from
    +  freq l
    +cross join
    +  stats r
    +;
    +```
    +
    +```sql
    +select * FROM subsampling_table order by p;
    +```
    +
    +| wordid | word | p |
    +|:----: | :----: |:----:|
    +| 48645 | the  | 0.04013665|
    +| 11245 | of   | 0.052463654|
    +| 16368 | and  | 0.06555538|
    +| 61938 | 00   | 0.068162076|
    +| 19977 | in   | 0.071441144|
    +| 83599 | 0    | 0.07528994|
    +| 95017 | a    | 0.07559573|
    +| 1225  | to   | 0.07953133|
    +| 37062 | 0000 | 0.08779001|
    +| 58246 | is   | 0.09049763|
    +|  ...  | ...  |... |
    +
    +The first row shows that 4% of `the` are used in the documents during training.
    +
    +# Delete low frequency words and high frequency words from `docs_words`
    +
    +To reduce useless words from corpus,
    +low frequency words and high frequency words are deleted.
    +And, to avoid loading long document on memory, a  document is split into some sub-documents.
    +
    +```sql
    +set hivevar:maxlength=1500;
    +SET hivevar:seed=31;
    +
    +drop table train_docs;
    +create table train_docs as
    +  with docs_exploded as (
    +    select
    +      docid,
    +      word,
    +      pos % ${maxlength} as pos,
    +      pos div ${maxlength} as splitid,
    +      rand(${seed}) as rnd
    +    from
    +      docs_words LATERAL VIEW posexplode(words) t as pos, word
    +  )
    +select
    +  l.docid,
    +  -- to_ordered_list(l.word, l.pos) as words
    +  to_ordered_list(r2.wordid, l.pos) as words,
    +from
    +  docs_exploded l
    +  LEFT SEMI join freq r on (l.word = r.word)
    +  join subsampling_table r2 on (l.word = r2.word)
    +where
    +  r2.p > l.rnd
    +group by
    +  l.docid, l.splitid
    +;
    +```
    +
    +If you store string word in `train_docs` table,
    +please replace `to_ordered_list(r2.wordid, l.pos) as words` with  `to_ordered_list(l.word, l.pos) as words`.
    +
    +# Create negative sampling table
    +
    +Negative sampling is an approximate function of [softmax function](https://en.wikipedia.org/wiki/Softmax_function).
    +Here, `negative_table` is used to store word sampling probability for negative sampling.
    +`z` is a hyperparameter of noise distribution for negative sampling.
    +During word2vec training,
    +words sampled from this distribution are used for negative examples.
    +Noise distribution is the unigram distribution raised to the 3/4rd power.
    +
    +$$
    +\begin{aligned}
    +p(w_i) = \frac{freq(w_i)^{\mathrm{z}}}{\sum freq(w)^{\mathrm{z}}}
    +\end{aligned}
    +$$
    +
    +To avoid using huge memory space for negative sampling like original implementation and remain to sample fastly from this distribution,
    +Hivemall uses [Alias method](https://en.wikipedia.org/wiki/Alias_method).
    +
    +This method has proposed in papers below:
    +
    +- A. J. Walker, New Fast Method for Generating Discrete Random Numbers with Arbitrary Frequency Distributions, in Electronics Letters 10, no. 8, pp. 127-128, 1974.
    +- A. J. Walker, An Efficient Method for Generating Discrete Random Variables with General Distributions. ACM Transactions on Mathematical Software 3, no. 3, pp. 253-256, 1977.
    +
    +```sql
    +set hivevar:z=3/4;
    +
    +drop table negative_table;
    +create table negative_table as
    +select
    +  collect_list(array(word, p, other)) as negative_table
    +from (
    +  select
    +    alias_table(to_map(word, negative)) as (word, p, other)
    +  from
    +    (
    +      select
    +        word,
    +        -- wordid as word,
    +        pow(freq, ${z}) as negative
    +      from
    +        freq
    +    ) t
    +) t1
    +;
    +```
    +
    +`alias_table` function returns the records like following.
    +
    +| word | p | other |
    +|:----: | :----: |:----:|
    +| leopold | 0.6556492 | 0000 |
    +| slep | 0.09060383 | leopold |
    +| valentinian | 0.76077825 | belarusian |
    +| slew | 0.90569097 | colin |
    +| lucien | 0.86329675 | overland |
    +| equitable | 0.7270946 | farms |
    +| insurers | 0.2367955 | israel |
    +| lucier | 0.14855136 | supplements |
    +| lieve | 0.12075222 | separatist |
    +| skyhawks | 0.14079945 | steamed |
    +| ... | ... | ... |
    +
    +To sample negative word from this `negative_table`,
    +
    +1. Sample record int index `i` from $$[0 \ldots \mathrm{num\_alias\_table\_records}]$$.
    +2. Sample float value `r` from $$[0.0 \ldots 1.0]$$ .
    +3. If `r` < `p` of `i` th record, return `word` `i` th record, else return `other` of `i` th record.
    +
    +Here, to use it in training function of word2vec, 
    +`alias_table`'s return records are stored into one list in the `negative_table`.
    +
    +# Train word2vec
    +
    +Hivemall provides `train_word2vec` function to train word vector by word2vec algorithms.
    +The default model is `"skipgram"`.
    +
    +> #### Note
    +> You must pass `n` argumet to the number of words in training documents: `select sum(size(words)) from train_docs;`.
    +
    +## Train Skip-Gram
    +
    +In skip-gram model,
    +word vectors are trained to predict the nearby words.
    +For example, given a sentence like a `"alice", "was", "beginning", "to"`,
    +`"was"` vector is learnt to predict `"alice"` ,`"beginning"` and `"to"`.
    +
    +```sql
    +select sum(size(words)) from train_docs;
    +set hivevar:n=418953; -- previous query return value
    +
    +drop table skipgram;
    +create table skipgram as
    +select
    +  train_word2vec(
    +    r.negative_table,
    +    l.words,
    +    "-n ${n} -win 5 -neg 15 -iter 5 -dim 100 -model skipgram"
    +  )
    +from
    +  train_docs l
    +  cross join negative_table r
    +;
    +```
    +
    +When word is treated as int istead of string,
    +you may need to transform wordid of int to word of string by `join` statement.
    +
    +```sql
    +drop table skipgram;
    +
    +create table skipgram as
    +select
    +  r.word, t.i, t.wi
    +from (
    +  select
    +    train_word2vec(
    +      r.negative_table,
    +      l.wordsint,
    +      "-n 418953 -win 5 -neg 15 -iter 5"
    +    ) as (wordid, i, wi)
    +  from
    +    train_docs l
    +  cross join
    +    negative_table r
    +) t
    +join freq r on (t.wordid = r.wordid)
    +;
    +```
    +
    +## Train CBoW
    +
    +In CBoW model,
    --- End diff --
    
    remove line break after `,`


---

[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec

Posted by myui <gi...@git.apache.org>.
Github user myui commented on a diff in the pull request:

    https://github.com/apache/incubator-hivemall/pull/116#discussion_r141546846
  
    --- Diff: core/src/main/java/hivemall/embedding/CBoWModel.java ---
    @@ -0,0 +1,131 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one
    + * or more contributor license agreements.  See the NOTICE file
    + * distributed with this work for additional information
    + * regarding copyright ownership.  The ASF licenses this file
    + * to you under the Apache License, Version 2.0 (the
    + * "License"); you may not use this file except in compliance
    + * with the License.  You may obtain a copy of the License at
    + *
    + *   http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing,
    + * software distributed under the License is distributed on an
    + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    + * KIND, either express or implied.  See the License for the
    + * specific language governing permissions and limitations
    + * under the License.
    + */
    +package hivemall.embedding;
    +
    +import hivemall.math.random.PRNG;
    +import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
    +
    +import javax.annotation.Nonnull;
    +import java.util.List;
    +
    +public final class CBoWModel extends AbstractWord2VecModel {
    +    protected CBoWModel(final int dim, final int win, final int neg, final int iter,
    --- End diff --
    
    add a blank line before constructor.


---

[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec

Posted by myui <gi...@git.apache.org>.
Github user myui commented on a diff in the pull request:

    https://github.com/apache/incubator-hivemall/pull/116#discussion_r141543245
  
    --- Diff: core/src/main/java/hivemall/embedding/AbstractWord2VecModel.java ---
    @@ -0,0 +1,125 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one
    + * or more contributor license agreements.  See the NOTICE file
    + * distributed with this work for additional information
    + * regarding copyright ownership.  The ASF licenses this file
    + * to you under the Apache License, Version 2.0 (the
    + * "License"); you may not use this file except in compliance
    + * with the License.  You may obtain a copy of the License at
    + *
    + *   http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing,
    + * software distributed under the License is distributed on an
    + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    + * KIND, either express or implied.  See the License for the
    + * specific language governing permissions and limitations
    + * under the License.
    + */
    +package hivemall.embedding;
    +
    +import hivemall.math.random.PRNG;
    +import hivemall.math.random.RandomNumberGeneratorFactory;
    +import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
    +
    +import javax.annotation.Nonnegative;
    +import javax.annotation.Nonnull;
    +import java.util.List;
    +
    +public abstract class AbstractWord2VecModel {
    +    // cached sigmoid function parameters
    +    protected static final int MAX_SIGMOID = 6;
    +    protected static final int SIGMOID_TABLE_SIZE = 1000;
    +    protected float[] sigmoidTable;
    +
    +
    +    @Nonnegative
    +    protected int dim;
    +    protected int win;
    +    protected int neg;
    +    protected int iter;
    +
    +    // learning rate parameters
    +    @Nonnegative
    +    protected float lr;
    +    @Nonnegative
    +    private float startingLR;
    +    @Nonnegative
    +    private long numTrainWords;
    +    @Nonnegative
    +    protected long wordCount;
    +    @Nonnegative
    +    private long lastWordCount;
    +
    +    protected PRNG rnd;
    +
    +    protected Int2FloatOpenHashTable contextWeights;
    +    protected Int2FloatOpenHashTable inputWeights;
    +    protected Int2FloatOpenHashTable S;
    +    protected int[] aliasWordId;
    +
    +    protected AbstractWord2VecModel(final int dim, final int win, final int neg, final int iter,
    +            final float startingLR, final long numTrainWords, final Int2FloatOpenHashTable S,
    +            final int[] aliasWordId) {
    +        this.win = win;
    +        this.neg = neg;
    +        this.iter = iter;
    +        this.dim = dim;
    +        this.startingLR = this.lr = startingLR;
    +        this.numTrainWords = numTrainWords;
    +
    +        // alias sampler for negative sampling
    +        this.S = S;
    +        this.aliasWordId = aliasWordId;
    +
    +        this.wordCount = 0L;
    +        this.lastWordCount = 0L;
    +        this.rnd = RandomNumberGeneratorFactory.createPRNG(1001);
    +
    +        this.sigmoidTable = initSigmoidTable();
    +
    +        // TODO how to estimate size
    +        this.inputWeights = new Int2FloatOpenHashTable(10578 * dim);
    +        this.inputWeights.defaultReturnValue(0.f);
    +        this.contextWeights = new Int2FloatOpenHashTable(10578 * dim);
    +        this.contextWeights.defaultReturnValue(0.f);
    +    }
    +
    +    private static float[] initSigmoidTable() {
    +        float[] sigmoidTable = new float[SIGMOID_TABLE_SIZE];
    +        for (int i = 0; i < SIGMOID_TABLE_SIZE; i++) {
    +            float x = ((float) i / SIGMOID_TABLE_SIZE * 2 - 1) * (float) MAX_SIGMOID;
    +            sigmoidTable[i] = 1.f / ((float) Math.exp(-x) + 1.f);
    +        }
    +        return sigmoidTable;
    +    }
    +
    +    protected void initWordWeights(final int wordId) {
    +        for (int i = 0; i < dim; i++) {
    +            inputWeights.put(wordId * dim + i, ((float) rnd.nextDouble() - 0.5f) / dim);
    +        }
    +    }
    +
    +    protected static float sigmoid(final float v, final float[] sigmoidTable) {
    +        if (v > MAX_SIGMOID) {
    +            return 1.f;
    +        } else if (v < -MAX_SIGMOID) {
    +            return 0.f;
    +        } else {
    +            return sigmoidTable[(int) ((v + MAX_SIGMOID) * (SIGMOID_TABLE_SIZE / MAX_SIGMOID / 2))];
    +        }
    +    }
    +
    +    protected void updateLearningRate() {
    +        // TODO: valid lr?
    --- End diff --
    
    remove this TODO comment and blank lines.


---

[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec

Posted by myui <gi...@git.apache.org>.
Github user myui commented on a diff in the pull request:

    https://github.com/apache/incubator-hivemall/pull/116#discussion_r141547708
  
    --- Diff: core/src/main/java/hivemall/embedding/Word2VecUDTF.java ---
    @@ -0,0 +1,364 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one
    + * or more contributor license agreements.  See the NOTICE file
    + * distributed with this work for additional information
    + * regarding copyright ownership.  The ASF licenses this file
    + * to you under the Apache License, Version 2.0 (the
    + * "License"); you may not use this file except in compliance
    + * with the License.  You may obtain a copy of the License at
    + *
    + *   http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing,
    + * software distributed under the License is distributed on an
    + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    + * KIND, either express or implied.  See the License for the
    + * specific language governing permissions and limitations
    + * under the License.
    + */
    +package hivemall.embedding;
    +
    +import hivemall.UDTFWithOptions;
    +import hivemall.utils.collections.IMapIterator;
    +import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
    +import hivemall.utils.collections.maps.OpenHashTable;
    +import hivemall.utils.hadoop.HiveUtils;
    +import hivemall.utils.lang.Primitives;
    +
    +import org.apache.commons.cli.CommandLine;
    +import org.apache.commons.cli.Options;
    +import org.apache.hadoop.hive.ql.exec.Description;
    +import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
    +import org.apache.hadoop.hive.ql.metadata.HiveException;
    +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
    +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
    +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
    +import org.apache.hadoop.io.FloatWritable;
    +import org.apache.hadoop.io.IntWritable;
    +import org.apache.hadoop.io.Text;
    +
    +import javax.annotation.Nonnegative;
    +import javax.annotation.Nonnull;
    +import java.util.List;
    +import java.util.Arrays;
    +import java.util.ArrayList;
    +
    +@Description(
    +        name = "train_word2vec",
    +        value = "_FUNC_(array<array<float | string>> negative_table, array<int | string> doc [, const string options]) - Returns a prediction model")
    +public class Word2VecUDTF extends UDTFWithOptions {
    +    protected transient AbstractWord2VecModel model;
    +    @Nonnegative
    +    private float startingLR;
    +    @Nonnegative
    +    private long numTrainWords;
    +    private OpenHashTable<String, Integer> word2index;
    +
    +    @Nonnegative
    +    private int dim;
    +    @Nonnegative
    +    private int win;
    +    @Nonnegative
    +    private int neg;
    +    @Nonnegative
    +    private int iter;
    +    private boolean skipgram;
    +    private boolean isStringInput;
    +
    +    private Int2FloatOpenHashTable S;
    +    private int[] aliasWordIds;
    +
    +    private ListObjectInspector negativeTableOI;
    +    private ListObjectInspector negativeTableElementListOI;
    +    private PrimitiveObjectInspector negativeTableElementOI;
    +
    +    private ListObjectInspector docOI;
    +    private PrimitiveObjectInspector wordOI;
    +
    +    @Override
    +    public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
    +        final int numArgs = argOIs.length;
    +
    +        if (numArgs != 3) {
    +            throw new UDFArgumentException(getClass().getSimpleName()
    +                    + " takes 3 arguments:  [, constant string options]: "
    +                    + Arrays.toString(argOIs));
    +        }
    +
    +        processOptions(argOIs);
    +
    +        this.negativeTableOI = HiveUtils.asListOI(argOIs[0]);
    +        this.negativeTableElementListOI = HiveUtils.asListOI(negativeTableOI.getListElementObjectInspector());
    +        this.docOI = HiveUtils.asListOI(argOIs[1]);
    +
    +        this.isStringInput = HiveUtils.isStringListOI(argOIs[1]);
    +
    +        if (isStringInput) {
    +            this.negativeTableElementOI = HiveUtils.asStringOI(negativeTableElementListOI.getListElementObjectInspector());
    +            this.wordOI = HiveUtils.asStringOI(docOI.getListElementObjectInspector());
    +        } else {
    +            this.negativeTableElementOI = HiveUtils.asFloatingPointOI(negativeTableElementListOI.getListElementObjectInspector());
    +            this.wordOI = HiveUtils.asIntCompatibleOI(docOI.getListElementObjectInspector());
    +        }
    +
    +        List<String> fieldNames = new ArrayList<>();
    +        List<ObjectInspector> fieldOIs = new ArrayList<>();
    +
    +        fieldNames.add("word");
    +        fieldNames.add("i");
    +        fieldNames.add("wi");
    +
    +        if (isStringInput) {
    +            fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
    +        } else {
    +            fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
    +        }
    +
    +        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
    +        fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
    +
    +        this.model = null;
    +        this.word2index = null;
    +        this.S = null;
    +        this.aliasWordIds = null;
    +
    +        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    +    }
    +
    +    @Override
    +    public void process(Object[] args) throws HiveException {
    +        if (model == null) {
    +            parseNegativeTable(args[0]);
    +            this.model = createModel();
    +        }
    +
    +        List<?> rawDoc = docOI.getList(args[1]);
    +
    +        // parse rawDoc
    +        final int docLength = rawDoc.size();
    +        final int[] doc = new int[docLength];
    +        if (isStringInput) {
    +            for (int i = 0; i < docLength; i++) {
    +                doc[i] = getWordId(PrimitiveObjectInspectorUtils.getString(rawDoc.get(i), wordOI));
    +            }
    +        } else {
    +            for (int i = 0; i < docLength; i++) {
    +                doc[i] = PrimitiveObjectInspectorUtils.getInt(rawDoc.get(i), wordOI);
    +            }
    +        }
    +
    +        model.trainOnDoc(doc);
    +    }
    +
    +    @Override
    +    protected Options getOptions() {
    +        Options opts = new Options();
    +        opts.addOption("n", "numTrainWords", true,
    +            "The number of words in the documents. It is used to update learning rate");
    +        opts.addOption("dim", "dimension", true, "The number of vector dimension [default: 100]");
    +        opts.addOption("win", "window", true, "Context window size [default: 5]");
    +        opts.addOption("neg", "negative", true,
    +            "The number of negative sampled words per word [default: 5]");
    +        opts.addOption("iter", "iteration", true, "The number of iterations [default: 5]");
    +        opts.addOption("model", "modelName", true,
    +            "The model name of word2vec: skipgram or cbow [default: skipgram]");
    +        opts.addOption(
    +            "lr",
    +            "learningRate",
    +            true,
    +            "Initial learning rate of SGD. The default value depends on model [default: 0.025 (skipgram), 0.05 (cbow)]");
    +
    +        return opts;
    +    }
    +
    +    @Override
    +    protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
    +        CommandLine cl = null;
    +        int win = 5;
    +        int neg = 5;
    +        int iter = 5;
    +        int dim = 100;
    +        long numTrainWords = 0L;
    +        String modelName = "skipgram";
    +        float lr = 0.025f;
    +
    +        if (argOIs.length >= 3) {
    +            String rawArgs = HiveUtils.getConstString(argOIs[2]);
    +            cl = parseOptions(rawArgs);
    +
    +            numTrainWords = Primitives.parseLong(cl.getOptionValue("n"), numTrainWords);
    +            if (numTrainWords <= 0) {
    +                throw new UDFArgumentException("Argument `int numTrainWords` must be positive: "
    +                        + numTrainWords);
    +            }
    +
    +            dim = Primitives.parseInt(cl.getOptionValue("dim"), dim);
    +            if (dim <= 0.d) {
    +                throw new UDFArgumentException("Argument `int dim` must be positive: " + dim);
    +            }
    +
    +            win = Primitives.parseInt(cl.getOptionValue("win"), win);
    +            if (win <= 0) {
    +                throw new UDFArgumentException("Argument `int win` must be positive: " + win);
    +            }
    +
    +            neg = Primitives.parseInt(cl.getOptionValue("neg"), neg);
    +            if (neg < 0) {
    +                throw new UDFArgumentException("Argument `int neg` must be non-negative: " + neg);
    +            }
    +
    +            iter = Primitives.parseInt(cl.getOptionValue("iter"), iter);
    +            if (iter <= 0) {
    +                throw new UDFArgumentException("Argument `int iter` must be non-negative: " + iter);
    +            }
    +
    +            modelName = cl.getOptionValue("model", modelName);
    +            if (!(modelName.equals("skipgram") || modelName.equals("cbow"))) {
    --- End diff --
    
    `"skipgram".equals(modelName)` is null safe.


---

[GitHub] incubator-hivemall issue #116: [WIP][HIVEMALL-118] word2vec

Posted by coveralls <gi...@git.apache.org>.
Github user coveralls commented on the issue:

    https://github.com/apache/incubator-hivemall/pull/116
  
    
    [![Coverage Status](https://coveralls.io/builds/13431616/badge)](https://coveralls.io/builds/13431616)
    
    Coverage decreased (-0.9%) to 40.049% when pulling **f19d732122d29a00e421d7f1ad0d0ca93f242c1e on nzw0301:skipgram** into **c2b95783cf9d6fc1646a48ac928e96152eab98c6 on apache:master**.



---

[GitHub] incubator-hivemall issue #116: [WIP][HIVEMALL-118] word2vec

Posted by coveralls <gi...@git.apache.org>.
Github user coveralls commented on the issue:

    https://github.com/apache/incubator-hivemall/pull/116
  
    
    [![Coverage Status](https://coveralls.io/builds/13452116/badge)](https://coveralls.io/builds/13452116)
    
    Coverage decreased (-0.9%) to 40.065% when pulling **2415589bb3a8eda3e23a765b10e01fde6c70a298 on nzw0301:skipgram** into **c2b95783cf9d6fc1646a48ac928e96152eab98c6 on apache:master**.



---

[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec

Posted by myui <gi...@git.apache.org>.
Github user myui commented on a diff in the pull request:

    https://github.com/apache/incubator-hivemall/pull/116#discussion_r141543209
  
    --- Diff: core/src/main/java/hivemall/embedding/AbstractWord2VecModel.java ---
    @@ -0,0 +1,125 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one
    + * or more contributor license agreements.  See the NOTICE file
    + * distributed with this work for additional information
    + * regarding copyright ownership.  The ASF licenses this file
    + * to you under the Apache License, Version 2.0 (the
    + * "License"); you may not use this file except in compliance
    + * with the License.  You may obtain a copy of the License at
    + *
    + *   http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing,
    + * software distributed under the License is distributed on an
    + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    + * KIND, either express or implied.  See the License for the
    + * specific language governing permissions and limitations
    + * under the License.
    + */
    +package hivemall.embedding;
    +
    +import hivemall.math.random.PRNG;
    +import hivemall.math.random.RandomNumberGeneratorFactory;
    +import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
    +
    +import javax.annotation.Nonnegative;
    +import javax.annotation.Nonnull;
    +import java.util.List;
    +
    +public abstract class AbstractWord2VecModel {
    +    // cached sigmoid function parameters
    +    protected static final int MAX_SIGMOID = 6;
    +    protected static final int SIGMOID_TABLE_SIZE = 1000;
    +    protected float[] sigmoidTable;
    +
    +
    +    @Nonnegative
    +    protected int dim;
    +    protected int win;
    +    protected int neg;
    +    protected int iter;
    +
    +    // learning rate parameters
    +    @Nonnegative
    +    protected float lr;
    +    @Nonnegative
    +    private float startingLR;
    +    @Nonnegative
    +    private long numTrainWords;
    +    @Nonnegative
    +    protected long wordCount;
    +    @Nonnegative
    +    private long lastWordCount;
    +
    +    protected PRNG rnd;
    +
    +    protected Int2FloatOpenHashTable contextWeights;
    +    protected Int2FloatOpenHashTable inputWeights;
    +    protected Int2FloatOpenHashTable S;
    +    protected int[] aliasWordId;
    +
    +    protected AbstractWord2VecModel(final int dim, final int win, final int neg, final int iter,
    +            final float startingLR, final long numTrainWords, final Int2FloatOpenHashTable S,
    +            final int[] aliasWordId) {
    +        this.win = win;
    +        this.neg = neg;
    +        this.iter = iter;
    +        this.dim = dim;
    +        this.startingLR = this.lr = startingLR;
    +        this.numTrainWords = numTrainWords;
    +
    +        // alias sampler for negative sampling
    +        this.S = S;
    +        this.aliasWordId = aliasWordId;
    +
    +        this.wordCount = 0L;
    +        this.lastWordCount = 0L;
    +        this.rnd = RandomNumberGeneratorFactory.createPRNG(1001);
    +
    +        this.sigmoidTable = initSigmoidTable();
    +
    +        // TODO how to estimate size
    +        this.inputWeights = new Int2FloatOpenHashTable(10578 * dim);
    +        this.inputWeights.defaultReturnValue(0.f);
    +        this.contextWeights = new Int2FloatOpenHashTable(10578 * dim);
    +        this.contextWeights.defaultReturnValue(0.f);
    +    }
    +
    +    private static float[] initSigmoidTable() {
    +        float[] sigmoidTable = new float[SIGMOID_TABLE_SIZE];
    +        for (int i = 0; i < SIGMOID_TABLE_SIZE; i++) {
    +            float x = ((float) i / SIGMOID_TABLE_SIZE * 2 - 1) * (float) MAX_SIGMOID;
    +            sigmoidTable[i] = 1.f / ((float) Math.exp(-x) + 1.f);
    +        }
    +        return sigmoidTable;
    +    }
    +
    +    protected void initWordWeights(final int wordId) {
    +        for (int i = 0; i < dim; i++) {
    +            inputWeights.put(wordId * dim + i, ((float) rnd.nextDouble() - 0.5f) / dim);
    +        }
    +    }
    +
    +    protected static float sigmoid(final float v, final float[] sigmoidTable) {
    --- End diff --
    
    `@Nonnull` for sigmoidTable


---

[GitHub] incubator-hivemall issue #116: [WIP][HIVEMALL-118] word2vec

Posted by coveralls <gi...@git.apache.org>.
Github user coveralls commented on the issue:

    https://github.com/apache/incubator-hivemall/pull/116
  
    
    [![Coverage Status](https://coveralls.io/builds/13372775/badge)](https://coveralls.io/builds/13372775)
    
    Coverage decreased (-0.8%) to 40.149% when pulling **39d11236d100a92d54cc46d8ffce4bd89670217f on nzw0301:skipgram** into **c2b95783cf9d6fc1646a48ac928e96152eab98c6 on apache:master**.



---

[GitHub] incubator-hivemall issue #116: [WIP][HIVEMALL-118] word2vec

Posted by coveralls <gi...@git.apache.org>.
Github user coveralls commented on the issue:

    https://github.com/apache/incubator-hivemall/pull/116
  
    
    [![Coverage Status](https://coveralls.io/builds/13366601/badge)](https://coveralls.io/builds/13366601)
    
    Coverage decreased (-0.7%) to 40.181% when pulling **7a5fd547caef5d1af512422dc75dd0efdf5b9466 on nzw0301:skipgram** into **c2b95783cf9d6fc1646a48ac928e96152eab98c6 on apache:master**.



---

[GitHub] incubator-hivemall issue #116: [WIP][HIVEMALL-118] word2vec

Posted by myui <gi...@git.apache.org>.
Github user myui commented on the issue:

    https://github.com/apache/incubator-hivemall/pull/116
  
    `What type of PR is it? => Improvement` should be `Feature`.



---

[GitHub] incubator-hivemall issue #116: [WIP][HIVEMALL-118] word2vec

Posted by coveralls <gi...@git.apache.org>.
Github user coveralls commented on the issue:

    https://github.com/apache/incubator-hivemall/pull/116
  
    
    [![Coverage Status](https://coveralls.io/builds/13450431/badge)](https://coveralls.io/builds/13450431)
    
    Coverage decreased (-0.9%) to 40.065% when pulling **d12ba32469a35653ee6a499157ad236bfcaa21cf on nzw0301:skipgram** into **c2b95783cf9d6fc1646a48ac928e96152eab98c6 on apache:master**.



---

[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec

Posted by myui <gi...@git.apache.org>.
Github user myui commented on a diff in the pull request:

    https://github.com/apache/incubator-hivemall/pull/116#discussion_r141545448
  
    --- Diff: core/src/main/java/hivemall/embedding/Word2VecUDTF.java ---
    @@ -0,0 +1,364 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one
    + * or more contributor license agreements.  See the NOTICE file
    + * distributed with this work for additional information
    + * regarding copyright ownership.  The ASF licenses this file
    + * to you under the Apache License, Version 2.0 (the
    + * "License"); you may not use this file except in compliance
    + * with the License.  You may obtain a copy of the License at
    + *
    + *   http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing,
    + * software distributed under the License is distributed on an
    + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    + * KIND, either express or implied.  See the License for the
    + * specific language governing permissions and limitations
    + * under the License.
    + */
    +package hivemall.embedding;
    +
    +import hivemall.UDTFWithOptions;
    +import hivemall.utils.collections.IMapIterator;
    +import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
    +import hivemall.utils.collections.maps.OpenHashTable;
    +import hivemall.utils.hadoop.HiveUtils;
    +import hivemall.utils.lang.Primitives;
    +
    +import org.apache.commons.cli.CommandLine;
    +import org.apache.commons.cli.Options;
    +import org.apache.hadoop.hive.ql.exec.Description;
    +import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
    +import org.apache.hadoop.hive.ql.metadata.HiveException;
    +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
    +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
    +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
    +import org.apache.hadoop.io.FloatWritable;
    +import org.apache.hadoop.io.IntWritable;
    +import org.apache.hadoop.io.Text;
    +
    +import javax.annotation.Nonnegative;
    +import javax.annotation.Nonnull;
    +import java.util.List;
    +import java.util.Arrays;
    +import java.util.ArrayList;
    +
    +@Description(
    +        name = "train_word2vec",
    +        value = "_FUNC_(array<array<float | string>> negative_table, array<int | string> doc [, const string options]) - Returns a prediction model")
    +public class Word2VecUDTF extends UDTFWithOptions {
    +    protected transient AbstractWord2VecModel model;
    +    @Nonnegative
    +    private float startingLR;
    +    @Nonnegative
    +    private long numTrainWords;
    +    private OpenHashTable<String, Integer> word2index;
    +
    +    @Nonnegative
    +    private int dim;
    +    @Nonnegative
    +    private int win;
    +    @Nonnegative
    +    private int neg;
    +    @Nonnegative
    +    private int iter;
    +    private boolean skipgram;
    +    private boolean isStringInput;
    +
    +    private Int2FloatOpenHashTable S;
    +    private int[] aliasWordIds;
    +
    +    private ListObjectInspector negativeTableOI;
    +    private ListObjectInspector negativeTableElementListOI;
    +    private PrimitiveObjectInspector negativeTableElementOI;
    +
    +    private ListObjectInspector docOI;
    +    private PrimitiveObjectInspector wordOI;
    +
    +    @Override
    +    public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
    +        final int numArgs = argOIs.length;
    +
    +        if (numArgs != 3) {
    +            throw new UDFArgumentException(getClass().getSimpleName()
    +                    + " takes 3 arguments:  [, constant string options]: "
    +                    + Arrays.toString(argOIs));
    +        }
    +
    +        processOptions(argOIs);
    +
    +        this.negativeTableOI = HiveUtils.asListOI(argOIs[0]);
    +        this.negativeTableElementListOI = HiveUtils.asListOI(negativeTableOI.getListElementObjectInspector());
    +        this.docOI = HiveUtils.asListOI(argOIs[1]);
    +
    +        this.isStringInput = HiveUtils.isStringListOI(argOIs[1]);
    +
    +        if (isStringInput) {
    +            this.negativeTableElementOI = HiveUtils.asStringOI(negativeTableElementListOI.getListElementObjectInspector());
    +            this.wordOI = HiveUtils.asStringOI(docOI.getListElementObjectInspector());
    +        } else {
    +            this.negativeTableElementOI = HiveUtils.asFloatingPointOI(negativeTableElementListOI.getListElementObjectInspector());
    +            this.wordOI = HiveUtils.asIntCompatibleOI(docOI.getListElementObjectInspector());
    +        }
    +
    +        List<String> fieldNames = new ArrayList<>();
    +        List<ObjectInspector> fieldOIs = new ArrayList<>();
    +
    +        fieldNames.add("word");
    +        fieldNames.add("i");
    +        fieldNames.add("wi");
    +
    +        if (isStringInput) {
    +            fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
    +        } else {
    +            fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
    +        }
    +
    +        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
    +        fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
    +
    +        this.model = null;
    +        this.word2index = null;
    +        this.S = null;
    +        this.aliasWordIds = null;
    +
    +        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    +    }
    +
    +    @Override
    +    public void process(Object[] args) throws HiveException {
    +        if (model == null) {
    +            parseNegativeTable(args[0]);
    +            this.model = createModel();
    +        }
    +
    +        List<?> rawDoc = docOI.getList(args[1]);
    +
    +        // parse rawDoc
    +        final int docLength = rawDoc.size();
    +        final int[] doc = new int[docLength];
    +        if (isStringInput) {
    +            for (int i = 0; i < docLength; i++) {
    +                doc[i] = getWordId(PrimitiveObjectInspectorUtils.getString(rawDoc.get(i), wordOI));
    +            }
    +        } else {
    +            for (int i = 0; i < docLength; i++) {
    +                doc[i] = PrimitiveObjectInspectorUtils.getInt(rawDoc.get(i), wordOI);
    +            }
    +        }
    +
    +        model.trainOnDoc(doc);
    +    }
    +
    +    @Override
    +    protected Options getOptions() {
    +        Options opts = new Options();
    +        opts.addOption("n", "numTrainWords", true,
    +            "The number of words in the documents. It is used to update learning rate");
    +        opts.addOption("dim", "dimension", true, "The number of vector dimension [default: 100]");
    +        opts.addOption("win", "window", true, "Context window size [default: 5]");
    +        opts.addOption("neg", "negative", true,
    +            "The number of negative sampled words per word [default: 5]");
    +        opts.addOption("iter", "iteration", true, "The number of iterations [default: 5]");
    +        opts.addOption("model", "modelName", true,
    +            "The model name of word2vec: skipgram or cbow [default: skipgram]");
    +        opts.addOption(
    +            "lr",
    +            "learningRate",
    +            true,
    +            "Initial learning rate of SGD. The default value depends on model [default: 0.025 (skipgram), 0.05 (cbow)]");
    +
    +        return opts;
    +    }
    +
    +    @Override
    +    protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
    +        CommandLine cl = null;
    +        int win = 5;
    +        int neg = 5;
    +        int iter = 5;
    +        int dim = 100;
    +        long numTrainWords = 0L;
    +        String modelName = "skipgram";
    +        float lr = 0.025f;
    +
    +        if (argOIs.length >= 3) {
    +            String rawArgs = HiveUtils.getConstString(argOIs[2]);
    +            cl = parseOptions(rawArgs);
    +
    +            numTrainWords = Primitives.parseLong(cl.getOptionValue("n"), numTrainWords);
    +            if (numTrainWords <= 0) {
    +                throw new UDFArgumentException("Argument `int numTrainWords` must be positive: "
    +                        + numTrainWords);
    +            }
    +
    +            dim = Primitives.parseInt(cl.getOptionValue("dim"), dim);
    +            if (dim <= 0.d) {
    +                throw new UDFArgumentException("Argument `int dim` must be positive: " + dim);
    +            }
    +
    +            win = Primitives.parseInt(cl.getOptionValue("win"), win);
    +            if (win <= 0) {
    +                throw new UDFArgumentException("Argument `int win` must be positive: " + win);
    +            }
    +
    +            neg = Primitives.parseInt(cl.getOptionValue("neg"), neg);
    +            if (neg < 0) {
    +                throw new UDFArgumentException("Argument `int neg` must be non-negative: " + neg);
    +            }
    +
    +            iter = Primitives.parseInt(cl.getOptionValue("iter"), iter);
    +            if (iter <= 0) {
    +                throw new UDFArgumentException("Argument `int iter` must be non-negative: " + iter);
    +            }
    +
    +            modelName = cl.getOptionValue("model", modelName);
    +            if (!(modelName.equals("skipgram") || modelName.equals("cbow"))) {
    +                throw new UDFArgumentException("Argument `string model` must be skipgram or cbow: "
    +                        + modelName);
    +            }
    +
    +            if (modelName.equals("cbow")) {
    +                lr = 0.05f;
    +            }
    +
    +            lr = Primitives.parseFloat(cl.getOptionValue("lr"), lr);
    +            if (lr <= 0.f) {
    +                throw new UDFArgumentException("Argument `float lr` must be positive: " + lr);
    +            }
    +        }
    +
    +        this.numTrainWords = numTrainWords;
    +        this.win = win;
    +        this.neg = neg;
    +        this.iter = iter;
    +        this.dim = dim;
    +        this.skipgram = modelName.equals("skipgram");
    +        this.startingLR = lr;
    +        return cl;
    +    }
    +
    +    public void close() throws HiveException {
    +        if (model != null) {
    +            forwardModel();
    +            this.model = null;
    +            this.word2index = null;
    +            this.S = null;
    +        }
    +    }
    +
    +    private void forwardModel() throws HiveException {
    +        if (isStringInput) {
    +            final Text word = new Text();
    +            final IntWritable dimIndex = new IntWritable();
    +            final FloatWritable value = new FloatWritable();
    +
    +            final Object[] result = new Object[3];
    +            result[0] = word;
    +            result[1] = dimIndex;
    +            result[2] = value;
    +
    +            IMapIterator<String, Integer> iter = word2index.entries();
    +            while (iter.next() != -1) {
    +                int wordId = iter.getValue();
    +                if (!model.inputWeights.containsKey(wordId * dim)){
    +                    continue;
    +                }
    +
    +                word.set(iter.getKey());
    +
    +                for (int i = 0; i < dim; i++) {
    +                    dimIndex.set(i);
    +                    value.set(model.inputWeights.get(wordId * dim + i));
    +                    forward(result);
    +                }
    +            }
    +        } else {
    +            final IntWritable word = new IntWritable();
    +            final IntWritable dimIndex = new IntWritable();
    +            final FloatWritable value = new FloatWritable();
    +
    +            final Object[] result = new Object[3];
    +            result[0] = word;
    +            result[1] = dimIndex;
    +            result[2] = value;
    +
    +            for (int wordId = 0; wordId < aliasWordIds.length; wordId++) {
    +                if (!model.inputWeights.containsKey(wordId * dim)){
    +                    break;
    +                }
    +                word.set(wordId);
    +                for (int i = 0; i < dim; i++) {
    +                    dimIndex.set(i);
    +                    value.set(model.inputWeights.get(wordId * dim + i));
    +                    forward(result);
    +                }
    +            }
    +        }
    +    }
    +
    +    private int getWordId(@Nonnull final String word) {
    +        if (word2index.containsKey(word)) {
    +            return word2index.get(word);
    +        } else {
    +            int w = word2index.size();
    +            word2index.put(word, w);
    +            return w;
    +        }
    +    }
    +
    +    private void parseNegativeTable(@Nonnull Object listObj) {
    +        int aliasSize = negativeTableOI.getListLength(listObj);
    +        Int2FloatOpenHashTable S = new Int2FloatOpenHashTable(aliasSize);
    +        int[] aliasWordIds = new int[aliasSize];
    +
    --- End diff --
    
    put `final` to ^ 3 local variables unchanged in the loop.


---

[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec

Posted by myui <gi...@git.apache.org>.
Github user myui commented on a diff in the pull request:

    https://github.com/apache/incubator-hivemall/pull/116#discussion_r141556886
  
    --- Diff: core/src/main/java/hivemall/embedding/SkipGramModel.java ---
    @@ -0,0 +1,119 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one
    + * or more contributor license agreements.  See the NOTICE file
    + * distributed with this work for additional information
    + * regarding copyright ownership.  The ASF licenses this file
    + * to you under the Apache License, Version 2.0 (the
    + * "License"); you may not use this file except in compliance
    + * with the License.  You may obtain a copy of the License at
    + *
    + *   http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing,
    + * software distributed under the License is distributed on an
    + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    + * KIND, either express or implied.  See the License for the
    + * specific language governing permissions and limitations
    + * under the License.
    + */
    +package hivemall.embedding;
    +
    +import hivemall.math.random.PRNG;
    +import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
    +
    +import javax.annotation.Nonnull;
    +import java.util.List;
    +
    +public final class SkipGramModel extends AbstractWord2VecModel {
    +    protected SkipGramModel(final int dim, final int win, final int neg, final int iter,
    --- End diff --
    
    Lot's of hyperparameters in constructor.
    
    Consider using Hyperparameter class as seen in https://github.com/apache/incubator-hivemall/blob/master/core/src/main/java/hivemall/fm/FMHyperParameters.java


---

[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec

Posted by myui <gi...@git.apache.org>.
Github user myui commented on a diff in the pull request:

    https://github.com/apache/incubator-hivemall/pull/116#discussion_r141542968
  
    --- Diff: core/src/main/java/hivemall/embedding/AbstractWord2VecModel.java ---
    @@ -0,0 +1,125 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one
    + * or more contributor license agreements.  See the NOTICE file
    + * distributed with this work for additional information
    + * regarding copyright ownership.  The ASF licenses this file
    + * to you under the Apache License, Version 2.0 (the
    + * "License"); you may not use this file except in compliance
    + * with the License.  You may obtain a copy of the License at
    + *
    + *   http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing,
    + * software distributed under the License is distributed on an
    + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    + * KIND, either express or implied.  See the License for the
    + * specific language governing permissions and limitations
    + * under the License.
    + */
    +package hivemall.embedding;
    +
    +import hivemall.math.random.PRNG;
    +import hivemall.math.random.RandomNumberGeneratorFactory;
    +import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
    +
    +import javax.annotation.Nonnegative;
    +import javax.annotation.Nonnull;
    +import java.util.List;
    +
    +public abstract class AbstractWord2VecModel {
    +    // cached sigmoid function parameters
    +    protected static final int MAX_SIGMOID = 6;
    +    protected static final int SIGMOID_TABLE_SIZE = 1000;
    +    protected float[] sigmoidTable;
    +
    +
    +    @Nonnegative
    +    protected int dim;
    +    protected int win;
    --- End diff --
    
    `@Nonnegative` for each variable (win, neg, iter).


---

[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec

Posted by myui <gi...@git.apache.org>.
Github user myui commented on a diff in the pull request:

    https://github.com/apache/incubator-hivemall/pull/116#discussion_r141556621
  
    --- Diff: core/src/main/java/hivemall/embedding/Word2VecUDTF.java ---
    @@ -0,0 +1,364 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one
    + * or more contributor license agreements.  See the NOTICE file
    + * distributed with this work for additional information
    + * regarding copyright ownership.  The ASF licenses this file
    + * to you under the Apache License, Version 2.0 (the
    + * "License"); you may not use this file except in compliance
    + * with the License.  You may obtain a copy of the License at
    + *
    + *   http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing,
    + * software distributed under the License is distributed on an
    + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    + * KIND, either express or implied.  See the License for the
    + * specific language governing permissions and limitations
    + * under the License.
    + */
    +package hivemall.embedding;
    +
    +import hivemall.UDTFWithOptions;
    +import hivemall.utils.collections.IMapIterator;
    +import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
    +import hivemall.utils.collections.maps.OpenHashTable;
    +import hivemall.utils.hadoop.HiveUtils;
    +import hivemall.utils.lang.Primitives;
    +
    +import org.apache.commons.cli.CommandLine;
    +import org.apache.commons.cli.Options;
    +import org.apache.hadoop.hive.ql.exec.Description;
    +import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
    +import org.apache.hadoop.hive.ql.metadata.HiveException;
    +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
    +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
    +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
    +import org.apache.hadoop.io.FloatWritable;
    +import org.apache.hadoop.io.IntWritable;
    +import org.apache.hadoop.io.Text;
    +
    +import javax.annotation.Nonnegative;
    +import javax.annotation.Nonnull;
    +import java.util.List;
    +import java.util.Arrays;
    +import java.util.ArrayList;
    +
    +@Description(
    +        name = "train_word2vec",
    +        value = "_FUNC_(array<array<float | string>> negative_table, array<int | string> doc [, const string options]) - Returns a prediction model")
    +public class Word2VecUDTF extends UDTFWithOptions {
    +    protected transient AbstractWord2VecModel model;
    +    @Nonnegative
    +    private float startingLR;
    +    @Nonnegative
    +    private long numTrainWords;
    +    private OpenHashTable<String, Integer> word2index;
    +
    +    @Nonnegative
    +    private int dim;
    +    @Nonnegative
    +    private int win;
    +    @Nonnegative
    +    private int neg;
    +    @Nonnegative
    +    private int iter;
    +    private boolean skipgram;
    +    private boolean isStringInput;
    +
    +    private Int2FloatOpenHashTable S;
    +    private int[] aliasWordIds;
    +
    +    private ListObjectInspector negativeTableOI;
    +    private ListObjectInspector negativeTableElementListOI;
    +    private PrimitiveObjectInspector negativeTableElementOI;
    +
    +    private ListObjectInspector docOI;
    +    private PrimitiveObjectInspector wordOI;
    +
    +    @Override
    +    public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
    +        final int numArgs = argOIs.length;
    +
    +        if (numArgs != 3) {
    +            throw new UDFArgumentException(getClass().getSimpleName()
    +                    + " takes 3 arguments:  [, constant string options]: "
    +                    + Arrays.toString(argOIs));
    +        }
    +
    +        processOptions(argOIs);
    +
    +        this.negativeTableOI = HiveUtils.asListOI(argOIs[0]);
    +        this.negativeTableElementListOI = HiveUtils.asListOI(negativeTableOI.getListElementObjectInspector());
    +        this.docOI = HiveUtils.asListOI(argOIs[1]);
    +
    +        this.isStringInput = HiveUtils.isStringListOI(argOIs[1]);
    +
    +        if (isStringInput) {
    +            this.negativeTableElementOI = HiveUtils.asStringOI(negativeTableElementListOI.getListElementObjectInspector());
    +            this.wordOI = HiveUtils.asStringOI(docOI.getListElementObjectInspector());
    +        } else {
    +            this.negativeTableElementOI = HiveUtils.asFloatingPointOI(negativeTableElementListOI.getListElementObjectInspector());
    +            this.wordOI = HiveUtils.asIntCompatibleOI(docOI.getListElementObjectInspector());
    +        }
    +
    +        List<String> fieldNames = new ArrayList<>();
    +        List<ObjectInspector> fieldOIs = new ArrayList<>();
    +
    +        fieldNames.add("word");
    +        fieldNames.add("i");
    +        fieldNames.add("wi");
    +
    +        if (isStringInput) {
    +            fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
    +        } else {
    +            fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
    +        }
    +
    +        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
    +        fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
    +
    +        this.model = null;
    +        this.word2index = null;
    +        this.S = null;
    +        this.aliasWordIds = null;
    +
    +        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    +    }
    +
    +    @Override
    +    public void process(Object[] args) throws HiveException {
    +        if (model == null) {
    +            parseNegativeTable(args[0]);
    +            this.model = createModel();
    +        }
    +
    +        List<?> rawDoc = docOI.getList(args[1]);
    +
    +        // parse rawDoc
    +        final int docLength = rawDoc.size();
    +        final int[] doc = new int[docLength];
    +        if (isStringInput) {
    +            for (int i = 0; i < docLength; i++) {
    +                doc[i] = getWordId(PrimitiveObjectInspectorUtils.getString(rawDoc.get(i), wordOI));
    +            }
    +        } else {
    +            for (int i = 0; i < docLength; i++) {
    +                doc[i] = PrimitiveObjectInspectorUtils.getInt(rawDoc.get(i), wordOI);
    +            }
    +        }
    +
    +        model.trainOnDoc(doc);
    +    }
    +
    +    @Override
    +    protected Options getOptions() {
    +        Options opts = new Options();
    +        opts.addOption("n", "numTrainWords", true,
    +            "The number of words in the documents. It is used to update learning rate");
    +        opts.addOption("dim", "dimension", true, "The number of vector dimension [default: 100]");
    +        opts.addOption("win", "window", true, "Context window size [default: 5]");
    +        opts.addOption("neg", "negative", true,
    +            "The number of negative sampled words per word [default: 5]");
    +        opts.addOption("iter", "iteration", true, "The number of iterations [default: 5]");
    +        opts.addOption("model", "modelName", true,
    +            "The model name of word2vec: skipgram or cbow [default: skipgram]");
    +        opts.addOption(
    +            "lr",
    --- End diff --
    
    remain `learningRate` for longOpt and use `eta0` for initialLearningRate.


---

[GitHub] incubator-hivemall issue #116: [WIP][HIVEMALL-118] word2vec

Posted by nzw0301 <gi...@git.apache.org>.
Github user nzw0301 commented on the issue:

    https://github.com/apache/incubator-hivemall/pull/116
  
    @myui I resolved conflicts.


---

[GitHub] incubator-hivemall issue #116: [WIP][HIVEMALL-118] word2vec

Posted by coveralls <gi...@git.apache.org>.
Github user coveralls commented on the issue:

    https://github.com/apache/incubator-hivemall/pull/116
  
    
    [![Coverage Status](https://coveralls.io/builds/13449529/badge)](https://coveralls.io/builds/13449529)
    
    Coverage decreased (-0.9%) to 40.061% when pulling **af5b5becb55e58fd9db355be218035518973844a on nzw0301:skipgram** into **c2b95783cf9d6fc1646a48ac928e96152eab98c6 on apache:master**.



---

[GitHub] incubator-hivemall issue #116: [WIP][HIVEMALL-118] word2vec

Posted by coveralls <gi...@git.apache.org>.
Github user coveralls commented on the issue:

    https://github.com/apache/incubator-hivemall/pull/116
  
    
    [![Coverage Status](https://coveralls.io/builds/13387498/badge)](https://coveralls.io/builds/13387498)
    
    Coverage decreased (-0.8%) to 40.077% when pulling **bbdb561cc1bf3194128034fe194555bb6c167144 on nzw0301:skipgram** into **c2b95783cf9d6fc1646a48ac928e96152eab98c6 on apache:master**.



---

[GitHub] incubator-hivemall issue #116: [WIP][HIVEMALL-118] word2vec

Posted by coveralls <gi...@git.apache.org>.
Github user coveralls commented on the issue:

    https://github.com/apache/incubator-hivemall/pull/116
  
    
    [![Coverage Status](https://coveralls.io/builds/13429656/badge)](https://coveralls.io/builds/13429656)
    
    Coverage decreased (-0.6%) to 40.372% when pulling **d1b4270861c277109f35c8675e8297ef081f1dee on nzw0301:skipgram** into **c2b95783cf9d6fc1646a48ac928e96152eab98c6 on apache:master**.



---

[GitHub] incubator-hivemall issue #116: [WIP][HIVEMALL-118] word2vec

Posted by coveralls <gi...@git.apache.org>.
Github user coveralls commented on the issue:

    https://github.com/apache/incubator-hivemall/pull/116
  
    
    [![Coverage Status](https://coveralls.io/builds/13454878/badge)](https://coveralls.io/builds/13454878)
    
    Coverage decreased (-0.5%) to 40.383% when pulling **da564b8cea0bd028d3f822ed750513cd28ff45c7 on nzw0301:skipgram** into **c2b95783cf9d6fc1646a48ac928e96152eab98c6 on apache:master**.



---

[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec

Posted by myui <gi...@git.apache.org>.
Github user myui commented on a diff in the pull request:

    https://github.com/apache/incubator-hivemall/pull/116#discussion_r141556391
  
    --- Diff: core/src/main/java/hivemall/embedding/AbstractWord2VecModel.java ---
    @@ -0,0 +1,125 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one
    + * or more contributor license agreements.  See the NOTICE file
    + * distributed with this work for additional information
    + * regarding copyright ownership.  The ASF licenses this file
    + * to you under the Apache License, Version 2.0 (the
    + * "License"); you may not use this file except in compliance
    + * with the License.  You may obtain a copy of the License at
    + *
    + *   http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing,
    + * software distributed under the License is distributed on an
    + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    + * KIND, either express or implied.  See the License for the
    + * specific language governing permissions and limitations
    + * under the License.
    + */
    +package hivemall.embedding;
    +
    +import hivemall.math.random.PRNG;
    +import hivemall.math.random.RandomNumberGeneratorFactory;
    +import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
    +
    +import javax.annotation.Nonnegative;
    +import javax.annotation.Nonnull;
    +import java.util.List;
    +
    +public abstract class AbstractWord2VecModel {
    +    // cached sigmoid function parameters
    +    protected static final int MAX_SIGMOID = 6;
    +    protected static final int SIGMOID_TABLE_SIZE = 1000;
    +    protected float[] sigmoidTable;
    +
    +
    +    @Nonnegative
    +    protected int dim;
    +    protected int win;
    +    protected int neg;
    +    protected int iter;
    +
    +    // learning rate parameters
    +    @Nonnegative
    +    protected float lr;
    +    @Nonnegative
    +    private float startingLR;
    +    @Nonnegative
    +    private long numTrainWords;
    +    @Nonnegative
    +    protected long wordCount;
    +    @Nonnegative
    +    private long lastWordCount;
    +
    +    protected PRNG rnd;
    +
    +    protected Int2FloatOpenHashTable contextWeights;
    +    protected Int2FloatOpenHashTable inputWeights;
    +    protected Int2FloatOpenHashTable S;
    +    protected int[] aliasWordId;
    +
    +    protected AbstractWord2VecModel(final int dim, final int win, final int neg, final int iter,
    +            final float startingLR, final long numTrainWords, final Int2FloatOpenHashTable S,
    +            final int[] aliasWordId) {
    +        this.win = win;
    +        this.neg = neg;
    +        this.iter = iter;
    +        this.dim = dim;
    +        this.startingLR = this.lr = startingLR;
    +        this.numTrainWords = numTrainWords;
    +
    +        // alias sampler for negative sampling
    +        this.S = S;
    +        this.aliasWordId = aliasWordId;
    +
    +        this.wordCount = 0L;
    +        this.lastWordCount = 0L;
    +        this.rnd = RandomNumberGeneratorFactory.createPRNG(1001);
    +
    +        this.sigmoidTable = initSigmoidTable();
    +
    +        // TODO how to estimate size
    +        this.inputWeights = new Int2FloatOpenHashTable(10578 * dim);
    --- End diff --
    
    2^n or 1024 * 10 is more understandable.


---

[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec

Posted by myui <gi...@git.apache.org>.
Github user myui commented on a diff in the pull request:

    https://github.com/apache/incubator-hivemall/pull/116#discussion_r140725317
  
    --- Diff: core/src/main/java/hivemall/embedding/AbstractWord2VecModel.java ---
    @@ -0,0 +1,117 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one
    + * or more contributor license agreements.  See the NOTICE file
    + * distributed with this work for additional information
    + * regarding copyright ownership.  The ASF licenses this file
    + * to you under the Apache License, Version 2.0 (the
    + * "License"); you may not use this file except in compliance
    + * with the License.  You may obtain a copy of the License at
    + *
    + *   http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing,
    + * software distributed under the License is distributed on an
    + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    + * KIND, either express or implied.  See the License for the
    + * specific language governing permissions and limitations
    + * under the License.
    + */
    +package hivemall.embedding;
    +
    +import hivemall.math.random.PRNG;
    +import hivemall.math.random.RandomNumberGeneratorFactory;
    +import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
    +
    +import javax.annotation.Nonnegative;
    +import javax.annotation.Nonnull;
    +
    +public abstract class AbstractWord2VecModel {
    +    // cached sigmoid function parameters
    +    protected final int MAX_SIGMOID = 6;
    +    protected final int SIGMOID_TABLE_SIZE = 1000;
    +    protected float[] sigmoidTable;
    +
    +    // learning rate parameters
    +    @Nonnegative
    +    protected float lr;
    +    @Nonnegative
    +    private float startingLR;
    +    @Nonnegative
    +    private long numTrainWords;
    +    @Nonnegative
    +    protected long wordCount;
    +    @Nonnegative
    +    private long lastWordCount;
    +    @Nonnegative
    +    private long wordCountActual;
    +
    +    @Nonnegative
    +    protected int dim;
    +    private PRNG _rnd;
    +
    +    protected Int2FloatOpenHashTable contextWeights;
    +    protected Int2FloatOpenHashTable inputWeights;
    +
    +    protected AbstractWord2VecModel(final int dim, final float startingLR, final long numTrainWords) {
    +        this.dim = dim;
    +        this.startingLR = this.lr = startingLR;
    +        this.numTrainWords = numTrainWords;
    +
    +        this.wordCount = 0L;
    +        this.lastWordCount = 0L;
    +        this.wordCountActual = 0L;
    +        this._rnd = RandomNumberGeneratorFactory.createPRNG(1001);
    +
    +        this.sigmoidTable = initSigmoidTable(MAX_SIGMOID, SIGMOID_TABLE_SIZE);
    +
    +        // TODO how to estimate size
    +        this.inputWeights = new Int2FloatOpenHashTable(10578 * dim);
    +        this.inputWeights.defaultReturnValue(-0.f);
    +        this.contextWeights = new Int2FloatOpenHashTable(10578 * dim);
    +        this.contextWeights.defaultReturnValue(0.f);
    +    }
    +
    +    private static float[] initSigmoidTable(final int maxSigmoid, final int sigmoidTableSize) {
    +        float[] sigmoidTable = new float[sigmoidTableSize];
    +        for (int i = 0; i < sigmoidTableSize; i++) {
    +            float x = ((float) i / sigmoidTableSize * 2 - 1) * (float) maxSigmoid;
    +            sigmoidTable[i] = 1.f / ((float) Math.exp(-x) + 1.f);
    +        }
    +        return sigmoidTable;
    +    }
    +
    +    protected void initWordWeights(final int wordId) {
    +        for (int i = 0; i < dim; i++) {
    +            inputWeights.put(wordId * dim + i, ((float) _rnd.nextDouble() - 0.5f) / dim);
    +        }
    +    }
    +
    +    protected static float sigmoid(final float v, final int MAX_SIGMOID,
    --- End diff --
    
    No need to use Constants for argument: `final int MAX_SIGMOID, final int SIGMOID_TABLE_SIZE`


---

[GitHub] incubator-hivemall issue #116: [WIP][HIVEMALL-118] word2vec

Posted by coveralls <gi...@git.apache.org>.
Github user coveralls commented on the issue:

    https://github.com/apache/incubator-hivemall/pull/116
  
    
    [![Coverage Status](https://coveralls.io/builds/13411817/badge)](https://coveralls.io/builds/13411817)
    
    Coverage decreased (-0.9%) to 40.029% when pulling **e0945527c68da9e2c9ab6eb86a7eb7d66bf42aa7 on nzw0301:skipgram** into **c2b95783cf9d6fc1646a48ac928e96152eab98c6 on apache:master**.



---

[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec

Posted by myui <gi...@git.apache.org>.
Github user myui commented on a diff in the pull request:

    https://github.com/apache/incubator-hivemall/pull/116#discussion_r141543643
  
    --- Diff: core/src/main/java/hivemall/embedding/CBoWModel.java ---
    @@ -0,0 +1,131 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one
    + * or more contributor license agreements.  See the NOTICE file
    + * distributed with this work for additional information
    + * regarding copyright ownership.  The ASF licenses this file
    + * to you under the Apache License, Version 2.0 (the
    + * "License"); you may not use this file except in compliance
    + * with the License.  You may obtain a copy of the License at
    + *
    + *   http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing,
    + * software distributed under the License is distributed on an
    + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    + * KIND, either express or implied.  See the License for the
    + * specific language governing permissions and limitations
    + * under the License.
    + */
    +package hivemall.embedding;
    +
    +import hivemall.math.random.PRNG;
    +import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
    +
    +import javax.annotation.Nonnull;
    +import java.util.List;
    +
    +public final class CBoWModel extends AbstractWord2VecModel {
    +    protected CBoWModel(final int dim, final int win, final int neg, final int iter,
    +            final float startingLR, final long numTrainWords, final Int2FloatOpenHashTable S,
    +            final int[] aliasWordId) {
    +        super(dim, win, neg, iter, startingLR, numTrainWords, S, aliasWordId);
    +    }
    +
    +    protected void trainOnDoc(@Nonnull final int[] doc) {
    +        final int vecDim = dim;
    +        final int numNegative = neg;
    +        final PRNG _rnd = rnd;
    +        final Int2FloatOpenHashTable _S = S;
    +        final int[] _aliasWordId = aliasWordId;
    +        float label, gradient;
    +
    +        // reuse instance
    +        int windowSize, k, numContext, targetWord, inWord, positiveWord;
    +
    +        updateLearningRate();
    +
    +        int docLength = doc.length;
    +        for (int t = 0; t < iter; t++) {
    +            for (int positiveWordPosition = 0; positiveWordPosition < docLength; positiveWordPosition++) {
    +                windowSize = _rnd.nextInt(win) + 1;
    +
    +                numContext = windowSize * 2 + Math.min(0, positiveWordPosition - windowSize)
    +                        + Math.min(0, docLength - positiveWordPosition - windowSize - 1);
    +
    +                float[] gradVec = new float[vecDim];
    --- End diff --
    
    add `final` for `gradVec` and `averageVec`.


---

[GitHub] incubator-hivemall issue #116: [WIP][HIVEMALL-118] word2vec

Posted by coveralls <gi...@git.apache.org>.
Github user coveralls commented on the issue:

    https://github.com/apache/incubator-hivemall/pull/116
  
    
    [![Coverage Status](https://coveralls.io/builds/13472440/badge)](https://coveralls.io/builds/13472440)
    
    Coverage decreased (-0.6%) to 40.505% when pulling **8696f5ff668adf758d3545bab5885e51ce7d053e on nzw0301:skipgram** into **1e42387576fabbb326d451f4a00ac22d57828711 on apache:master**.



---

[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec

Posted by myui <gi...@git.apache.org>.
Github user myui commented on a diff in the pull request:

    https://github.com/apache/incubator-hivemall/pull/116#discussion_r141543514
  
    --- Diff: core/src/main/java/hivemall/embedding/CBoWModel.java ---
    @@ -0,0 +1,131 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one
    + * or more contributor license agreements.  See the NOTICE file
    + * distributed with this work for additional information
    + * regarding copyright ownership.  The ASF licenses this file
    + * to you under the Apache License, Version 2.0 (the
    + * "License"); you may not use this file except in compliance
    + * with the License.  You may obtain a copy of the License at
    + *
    + *   http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing,
    + * software distributed under the License is distributed on an
    + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    + * KIND, either express or implied.  See the License for the
    + * specific language governing permissions and limitations
    + * under the License.
    + */
    +package hivemall.embedding;
    +
    +import hivemall.math.random.PRNG;
    +import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
    +
    +import javax.annotation.Nonnull;
    +import java.util.List;
    +
    +public final class CBoWModel extends AbstractWord2VecModel {
    +    protected CBoWModel(final int dim, final int win, final int neg, final int iter,
    +            final float startingLR, final long numTrainWords, final Int2FloatOpenHashTable S,
    +            final int[] aliasWordId) {
    +        super(dim, win, neg, iter, startingLR, numTrainWords, S, aliasWordId);
    +    }
    +
    +    protected void trainOnDoc(@Nonnull final int[] doc) {
    +        final int vecDim = dim;
    +        final int numNegative = neg;
    +        final PRNG _rnd = rnd;
    +        final Int2FloatOpenHashTable _S = S;
    --- End diff --
    
    Member variable should be `_S` and local variable should be `S`.
    
    `_rnd`, `_aliasWordId` as well.


---

[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec

Posted by myui <gi...@git.apache.org>.
Github user myui commented on a diff in the pull request:

    https://github.com/apache/incubator-hivemall/pull/116#discussion_r141546656
  
    --- Diff: core/src/main/java/hivemall/embedding/AbstractWord2VecModel.java ---
    @@ -0,0 +1,125 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one
    + * or more contributor license agreements.  See the NOTICE file
    + * distributed with this work for additional information
    + * regarding copyright ownership.  The ASF licenses this file
    + * to you under the Apache License, Version 2.0 (the
    + * "License"); you may not use this file except in compliance
    + * with the License.  You may obtain a copy of the License at
    + *
    + *   http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing,
    + * software distributed under the License is distributed on an
    + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    + * KIND, either express or implied.  See the License for the
    + * specific language governing permissions and limitations
    + * under the License.
    + */
    +package hivemall.embedding;
    +
    +import hivemall.math.random.PRNG;
    +import hivemall.math.random.RandomNumberGeneratorFactory;
    +import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
    +
    +import javax.annotation.Nonnegative;
    +import javax.annotation.Nonnull;
    +import java.util.List;
    +
    +public abstract class AbstractWord2VecModel {
    +    // cached sigmoid function parameters
    +    protected static final int MAX_SIGMOID = 6;
    +    protected static final int SIGMOID_TABLE_SIZE = 1000;
    +    protected float[] sigmoidTable;
    +
    +
    +    @Nonnegative
    +    protected int dim;
    +    protected int win;
    +    protected int neg;
    +    protected int iter;
    +
    +    // learning rate parameters
    +    @Nonnegative
    +    protected float lr;
    +    @Nonnegative
    +    private float startingLR;
    +    @Nonnegative
    +    private long numTrainWords;
    +    @Nonnegative
    +    protected long wordCount;
    +    @Nonnegative
    +    private long lastWordCount;
    +
    +    protected PRNG rnd;
    +
    +    protected Int2FloatOpenHashTable contextWeights;
    +    protected Int2FloatOpenHashTable inputWeights;
    +    protected Int2FloatOpenHashTable S;
    +    protected int[] aliasWordId;
    +
    +    protected AbstractWord2VecModel(final int dim, final int win, final int neg, final int iter,
    +            final float startingLR, final long numTrainWords, final Int2FloatOpenHashTable S,
    +            final int[] aliasWordId) {
    +        this.win = win;
    +        this.neg = neg;
    +        this.iter = iter;
    +        this.dim = dim;
    +        this.startingLR = this.lr = startingLR;
    +        this.numTrainWords = numTrainWords;
    +
    +        // alias sampler for negative sampling
    +        this.S = S;
    +        this.aliasWordId = aliasWordId;
    +
    +        this.wordCount = 0L;
    +        this.lastWordCount = 0L;
    +        this.rnd = RandomNumberGeneratorFactory.createPRNG(1001);
    +
    +        this.sigmoidTable = initSigmoidTable();
    +
    +        // TODO how to estimate size
    +        this.inputWeights = new Int2FloatOpenHashTable(10578 * dim);
    +        this.inputWeights.defaultReturnValue(0.f);
    +        this.contextWeights = new Int2FloatOpenHashTable(10578 * dim);
    +        this.contextWeights.defaultReturnValue(0.f);
    +    }
    +
    +    private static float[] initSigmoidTable() {
    --- End diff --
    
    `@Nonnull` for return


---

[GitHub] incubator-hivemall issue #116: [WIP][HIVEMALL-118] word2vec

Posted by coveralls <gi...@git.apache.org>.
Github user coveralls commented on the issue:

    https://github.com/apache/incubator-hivemall/pull/116
  
    
    [![Coverage Status](https://coveralls.io/builds/13415230/badge)](https://coveralls.io/builds/13415230)
    
    Coverage decreased (-0.9%) to 40.028% when pulling **4abdb8f3d2632baa9b3ead928bc8bb0283027e20 on nzw0301:skipgram** into **c2b95783cf9d6fc1646a48ac928e96152eab98c6 on apache:master**.



---

[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec

Posted by myui <gi...@git.apache.org>.
Github user myui commented on a diff in the pull request:

    https://github.com/apache/incubator-hivemall/pull/116#discussion_r141547506
  
    --- Diff: core/src/main/java/hivemall/embedding/Word2VecUDTF.java ---
    @@ -0,0 +1,364 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one
    + * or more contributor license agreements.  See the NOTICE file
    + * distributed with this work for additional information
    + * regarding copyright ownership.  The ASF licenses this file
    + * to you under the Apache License, Version 2.0 (the
    + * "License"); you may not use this file except in compliance
    + * with the License.  You may obtain a copy of the License at
    + *
    + *   http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing,
    + * software distributed under the License is distributed on an
    + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    + * KIND, either express or implied.  See the License for the
    + * specific language governing permissions and limitations
    + * under the License.
    + */
    +package hivemall.embedding;
    +
    +import hivemall.UDTFWithOptions;
    +import hivemall.utils.collections.IMapIterator;
    +import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
    +import hivemall.utils.collections.maps.OpenHashTable;
    +import hivemall.utils.hadoop.HiveUtils;
    +import hivemall.utils.lang.Primitives;
    +
    +import org.apache.commons.cli.CommandLine;
    +import org.apache.commons.cli.Options;
    +import org.apache.hadoop.hive.ql.exec.Description;
    +import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
    +import org.apache.hadoop.hive.ql.metadata.HiveException;
    +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
    +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
    +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
    +import org.apache.hadoop.io.FloatWritable;
    +import org.apache.hadoop.io.IntWritable;
    +import org.apache.hadoop.io.Text;
    +
    +import javax.annotation.Nonnegative;
    +import javax.annotation.Nonnull;
    +import java.util.List;
    +import java.util.Arrays;
    +import java.util.ArrayList;
    +
    +@Description(
    +        name = "train_word2vec",
    +        value = "_FUNC_(array<array<float | string>> negative_table, array<int | string> doc [, const string options]) - Returns a prediction model")
    +public class Word2VecUDTF extends UDTFWithOptions {
    +    protected transient AbstractWord2VecModel model;
    +    @Nonnegative
    +    private float startingLR;
    +    @Nonnegative
    +    private long numTrainWords;
    +    private OpenHashTable<String, Integer> word2index;
    +
    +    @Nonnegative
    +    private int dim;
    +    @Nonnegative
    +    private int win;
    +    @Nonnegative
    +    private int neg;
    +    @Nonnegative
    +    private int iter;
    +    private boolean skipgram;
    +    private boolean isStringInput;
    +
    +    private Int2FloatOpenHashTable S;
    +    private int[] aliasWordIds;
    +
    +    private ListObjectInspector negativeTableOI;
    +    private ListObjectInspector negativeTableElementListOI;
    +    private PrimitiveObjectInspector negativeTableElementOI;
    +
    +    private ListObjectInspector docOI;
    +    private PrimitiveObjectInspector wordOI;
    +
    +    @Override
    +    public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
    +        final int numArgs = argOIs.length;
    +
    +        if (numArgs != 3) {
    +            throw new UDFArgumentException(getClass().getSimpleName()
    +                    + " takes 3 arguments:  [, constant string options]: "
    +                    + Arrays.toString(argOIs));
    +        }
    +
    +        processOptions(argOIs);
    +
    +        this.negativeTableOI = HiveUtils.asListOI(argOIs[0]);
    +        this.negativeTableElementListOI = HiveUtils.asListOI(negativeTableOI.getListElementObjectInspector());
    +        this.docOI = HiveUtils.asListOI(argOIs[1]);
    +
    +        this.isStringInput = HiveUtils.isStringListOI(argOIs[1]);
    +
    +        if (isStringInput) {
    +            this.negativeTableElementOI = HiveUtils.asStringOI(negativeTableElementListOI.getListElementObjectInspector());
    +            this.wordOI = HiveUtils.asStringOI(docOI.getListElementObjectInspector());
    +        } else {
    +            this.negativeTableElementOI = HiveUtils.asFloatingPointOI(negativeTableElementListOI.getListElementObjectInspector());
    +            this.wordOI = HiveUtils.asIntCompatibleOI(docOI.getListElementObjectInspector());
    +        }
    +
    +        List<String> fieldNames = new ArrayList<>();
    +        List<ObjectInspector> fieldOIs = new ArrayList<>();
    +
    +        fieldNames.add("word");
    +        fieldNames.add("i");
    +        fieldNames.add("wi");
    +
    +        if (isStringInput) {
    +            fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
    +        } else {
    +            fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
    +        }
    +
    +        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
    +        fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
    +
    +        this.model = null;
    +        this.word2index = null;
    +        this.S = null;
    +        this.aliasWordIds = null;
    +
    +        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    +    }
    +
    +    @Override
    +    public void process(Object[] args) throws HiveException {
    +        if (model == null) {
    +            parseNegativeTable(args[0]);
    +            this.model = createModel();
    +        }
    +
    +        List<?> rawDoc = docOI.getList(args[1]);
    +
    +        // parse rawDoc
    +        final int docLength = rawDoc.size();
    +        final int[] doc = new int[docLength];
    +        if (isStringInput) {
    +            for (int i = 0; i < docLength; i++) {
    +                doc[i] = getWordId(PrimitiveObjectInspectorUtils.getString(rawDoc.get(i), wordOI));
    +            }
    +        } else {
    +            for (int i = 0; i < docLength; i++) {
    +                doc[i] = PrimitiveObjectInspectorUtils.getInt(rawDoc.get(i), wordOI);
    +            }
    +        }
    +
    +        model.trainOnDoc(doc);
    +    }
    +
    +    @Override
    +    protected Options getOptions() {
    +        Options opts = new Options();
    +        opts.addOption("n", "numTrainWords", true,
    +            "The number of words in the documents. It is used to update learning rate");
    +        opts.addOption("dim", "dimension", true, "The number of vector dimension [default: 100]");
    +        opts.addOption("win", "window", true, "Context window size [default: 5]");
    +        opts.addOption("neg", "negative", true,
    +            "The number of negative sampled words per word [default: 5]");
    +        opts.addOption("iter", "iteration", true, "The number of iterations [default: 5]");
    +        opts.addOption("model", "modelName", true,
    +            "The model name of word2vec: skipgram or cbow [default: skipgram]");
    +        opts.addOption(
    +            "lr",
    --- End diff --
    
    consistent naming `eta0`, `learningRate` for the initial learning rate.


---

[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec

Posted by myui <gi...@git.apache.org>.
Github user myui commented on a diff in the pull request:

    https://github.com/apache/incubator-hivemall/pull/116#discussion_r141545805
  
    --- Diff: core/src/main/java/hivemall/embedding/AliasTableBuilderUDTF.java ---
    @@ -0,0 +1,203 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one
    + * or more contributor license agreements.  See the NOTICE file
    + * distributed with this work for additional information
    + * regarding copyright ownership.  The ASF licenses this file
    + * to you under the Apache License, Version 2.0 (the
    + * "License"); you may not use this file except in compliance
    + * with the License.  You may obtain a copy of the License at
    + *
    + *   http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing,
    + * software distributed under the License is distributed on an
    + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    + * KIND, either express or implied.  See the License for the
    + * specific language governing permissions and limitations
    + * under the License.
    + */
    +package hivemall.embedding;
    +
    +import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
    +import hivemall.utils.collections.maps.Int2IntOpenHashTable;
    +import hivemall.utils.hadoop.HiveUtils;
    +
    +import java.util.List;
    +import java.util.ArrayList;
    +import java.util.Map;
    +import java.util.Queue;
    +import java.util.ArrayDeque;
    +
    +import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
    +import org.apache.hadoop.hive.ql.metadata.HiveException;
    +import org.apache.hadoop.hive.ql.udf.generic.GenericUDTF;
    +import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
    +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
    +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
    +import org.apache.hadoop.io.FloatWritable;
    +import org.apache.hadoop.io.IntWritable;
    +import org.apache.hadoop.io.Text;
    +
    +import javax.annotation.Nonnull;
    +
    +public final class AliasTableBuilderUDTF extends GenericUDTF {
    +    private MapObjectInspector negativeTableOI;
    +    private PrimitiveObjectInspector negativeTableKeyOI;
    +    private PrimitiveObjectInspector negativeTableValueOI;
    +
    +    private int numVocab;
    +    private List<String> index2word;
    +    private Int2IntOpenHashTable A;
    +    private Int2FloatOpenHashTable S;
    +    private boolean isIntElement;
    +
    +    @Override
    +    public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
    +        if (!(argOIs.length >= 1)) {
    +            throw new UDFArgumentException(
    +                "_FUNC_(map<string, double>) takes at least one argument");
    +        }
    +
    +        this.negativeTableOI = HiveUtils.asMapOI(argOIs[0]);
    +        this.negativeTableValueOI = HiveUtils.asFloatingPointOI(negativeTableOI.getMapValueObjectInspector());
    +
    +        boolean isIntEmelentOI = HiveUtils.isIntOI((negativeTableOI.getMapKeyObjectInspector()));
    +
    +        if (isIntEmelentOI) {
    +            this.negativeTableKeyOI = HiveUtils.asIntCompatibleOI(negativeTableOI.getMapKeyObjectInspector());
    +        } else {
    +            this.negativeTableKeyOI = HiveUtils.asStringOI(negativeTableOI.getMapKeyObjectInspector());
    +        }
    +
    +        List<String> fieldNames = new ArrayList<>();
    +        List<ObjectInspector> fieldOIs = new ArrayList<>();
    +        fieldNames.add("word");
    +
    +        if (isIntEmelentOI) {
    +            fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
    +        } else {
    +            fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
    +        }
    +
    +        fieldNames.add("p");
    +        fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
    +
    +        fieldNames.add("other");
    +        if (isIntEmelentOI) {
    +            fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
    +        } else {
    +            fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
    +        }
    +
    +        this.isIntElement = isIntEmelentOI;
    +        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    +    }
    +
    +    @Override
    +    public void process(Object[] args) throws HiveException {
    +        if (!isIntElement) {
    +            index2word = new ArrayList<>();
    +        }
    +
    +        final List<Float> unnormalizedProb = new ArrayList<>();
    +        int numVocab = 0;
    +        float denom = 0.f;
    +        for (Map.Entry<?, ?> entry : negativeTableOI.getMap(args[0]).entrySet()) {
    +            if (!isIntElement) {
    +                String word = PrimitiveObjectInspectorUtils.getString(entry.getKey(),
    +                    negativeTableKeyOI);
    +                index2word.add(word);
    +            }
    +
    +            float v = PrimitiveObjectInspectorUtils.getFloat(entry.getValue(), negativeTableValueOI);
    +            unnormalizedProb.add(v);
    +            denom += v;
    +            numVocab++;
    +        }
    +
    +        this.numVocab = numVocab;
    +        createAliasTable(numVocab, denom, unnormalizedProb);
    +    }
    +
    +    private void createAliasTable(final int V, final float denom,
    +            final @Nonnull List<Float> unnormalizedProb) {
    +        Int2FloatOpenHashTable S = new Int2FloatOpenHashTable(V);
    +        Int2IntOpenHashTable A = new Int2IntOpenHashTable(V);
    +
    +        final Queue<Integer> higherBin = new ArrayDeque<>();
    +        final Queue<Integer> lowerBin = new ArrayDeque<>();
    +
    +        for (int i = 0; i < V; i++) {
    +            float v = V * unnormalizedProb.get(i) / denom;
    +            S.put(i, v);
    +            if (v > 1.f) {
    +                higherBin.add(i);
    +            } else {
    +                lowerBin.add(i);
    +            }
    +        }
    +
    +        while (lowerBin.size() > 0 && higherBin.size() > 0) {
    +            int low = lowerBin.remove();
    +            int high = higherBin.remove();
    +            A.put(low, high);
    +            S.put(high, S.get(high) - 1.f + S.get(low));
    +            if (S.get(high) < 1.f) {
    +                lowerBin.add(high);
    +            } else {
    +                higherBin.add(high);
    +            }
    +        }
    +        this.A = A;
    +        this.S = S;
    +    }
    +
    +    @Override
    +    public void close() throws HiveException {
    +        if (isIntElement) {
    +            IntWritable word = new IntWritable();
    +            FloatWritable pro = new FloatWritable();
    +            IntWritable otherWord = new IntWritable();
    +
    +            Object[] res = new Object[3];
    +            res[0] = word;
    +            res[1] = pro;
    +            res[2] = otherWord;
    +
    +            for (int i = 0; i < numVocab; i++) {
    +                word.set(i);
    +                pro.set(S.get(i));
    +                if (A.get(i) == -1) {
    +                    otherWord.set(0);
    +                } else {
    +                    otherWord.set(A.get(i));
    +                }
    +                forward(res);
    +            }
    +        } else {
    +            Text word = new Text();
    +            FloatWritable pro = new FloatWritable();
    +            Text otherWord = new Text();
    +
    +            Object[] res = new Object[3];
    --- End diff --
    
    put `final` for local variables that are unchanged in the for loop.


---

[GitHub] incubator-hivemall issue #116: [WIP][HIVEMALL-118] word2vec

Posted by coveralls <gi...@git.apache.org>.
Github user coveralls commented on the issue:

    https://github.com/apache/incubator-hivemall/pull/116
  
    
    [![Coverage Status](https://coveralls.io/builds/13365085/badge)](https://coveralls.io/builds/13365085)
    
    Coverage decreased (-0.7%) to 40.18% when pulling **7014e8552f81d59ab35c95d8fcf54c56c24ba2c9 on nzw0301:skipgram** into **c2b95783cf9d6fc1646a48ac928e96152eab98c6 on apache:master**.



---

[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec

Posted by myui <gi...@git.apache.org>.
Github user myui commented on a diff in the pull request:

    https://github.com/apache/incubator-hivemall/pull/116#discussion_r141547369
  
    --- Diff: core/src/main/java/hivemall/embedding/Word2VecUDTF.java ---
    @@ -0,0 +1,364 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one
    + * or more contributor license agreements.  See the NOTICE file
    + * distributed with this work for additional information
    + * regarding copyright ownership.  The ASF licenses this file
    + * to you under the Apache License, Version 2.0 (the
    + * "License"); you may not use this file except in compliance
    + * with the License.  You may obtain a copy of the License at
    + *
    + *   http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing,
    + * software distributed under the License is distributed on an
    + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    + * KIND, either express or implied.  See the License for the
    + * specific language governing permissions and limitations
    + * under the License.
    + */
    +package hivemall.embedding;
    +
    +import hivemall.UDTFWithOptions;
    +import hivemall.utils.collections.IMapIterator;
    +import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
    +import hivemall.utils.collections.maps.OpenHashTable;
    +import hivemall.utils.hadoop.HiveUtils;
    +import hivemall.utils.lang.Primitives;
    +
    +import org.apache.commons.cli.CommandLine;
    +import org.apache.commons.cli.Options;
    +import org.apache.hadoop.hive.ql.exec.Description;
    +import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
    +import org.apache.hadoop.hive.ql.metadata.HiveException;
    +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
    +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
    +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
    +import org.apache.hadoop.io.FloatWritable;
    +import org.apache.hadoop.io.IntWritable;
    +import org.apache.hadoop.io.Text;
    +
    +import javax.annotation.Nonnegative;
    +import javax.annotation.Nonnull;
    +import java.util.List;
    +import java.util.Arrays;
    +import java.util.ArrayList;
    +
    +@Description(
    +        name = "train_word2vec",
    +        value = "_FUNC_(array<array<float | string>> negative_table, array<int | string> doc [, const string options]) - Returns a prediction model")
    +public class Word2VecUDTF extends UDTFWithOptions {
    +    protected transient AbstractWord2VecModel model;
    +    @Nonnegative
    +    private float startingLR;
    +    @Nonnegative
    +    private long numTrainWords;
    +    private OpenHashTable<String, Integer> word2index;
    +
    +    @Nonnegative
    +    private int dim;
    +    @Nonnegative
    +    private int win;
    +    @Nonnegative
    +    private int neg;
    +    @Nonnegative
    +    private int iter;
    +    private boolean skipgram;
    +    private boolean isStringInput;
    +
    +    private Int2FloatOpenHashTable S;
    +    private int[] aliasWordIds;
    +
    +    private ListObjectInspector negativeTableOI;
    +    private ListObjectInspector negativeTableElementListOI;
    +    private PrimitiveObjectInspector negativeTableElementOI;
    +
    +    private ListObjectInspector docOI;
    +    private PrimitiveObjectInspector wordOI;
    +
    +    @Override
    +    public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
    +        final int numArgs = argOIs.length;
    +
    +        if (numArgs != 3) {
    +            throw new UDFArgumentException(getClass().getSimpleName()
    +                    + " takes 3 arguments:  [, constant string options]: "
    +                    + Arrays.toString(argOIs));
    +        }
    +
    +        processOptions(argOIs);
    +
    +        this.negativeTableOI = HiveUtils.asListOI(argOIs[0]);
    +        this.negativeTableElementListOI = HiveUtils.asListOI(negativeTableOI.getListElementObjectInspector());
    +        this.docOI = HiveUtils.asListOI(argOIs[1]);
    +
    +        this.isStringInput = HiveUtils.isStringListOI(argOIs[1]);
    +
    +        if (isStringInput) {
    +            this.negativeTableElementOI = HiveUtils.asStringOI(negativeTableElementListOI.getListElementObjectInspector());
    +            this.wordOI = HiveUtils.asStringOI(docOI.getListElementObjectInspector());
    +        } else {
    +            this.negativeTableElementOI = HiveUtils.asFloatingPointOI(negativeTableElementListOI.getListElementObjectInspector());
    +            this.wordOI = HiveUtils.asIntCompatibleOI(docOI.getListElementObjectInspector());
    +        }
    +
    +        List<String> fieldNames = new ArrayList<>();
    +        List<ObjectInspector> fieldOIs = new ArrayList<>();
    +
    +        fieldNames.add("word");
    +        fieldNames.add("i");
    +        fieldNames.add("wi");
    +
    +        if (isStringInput) {
    +            fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
    +        } else {
    +            fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
    +        }
    +
    +        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
    +        fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
    +
    +        this.model = null;
    +        this.word2index = null;
    +        this.S = null;
    +        this.aliasWordIds = null;
    +
    +        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    +    }
    +
    +    @Override
    +    public void process(Object[] args) throws HiveException {
    +        if (model == null) {
    +            parseNegativeTable(args[0]);
    +            this.model = createModel();
    +        }
    +
    +        List<?> rawDoc = docOI.getList(args[1]);
    +
    +        // parse rawDoc
    +        final int docLength = rawDoc.size();
    +        final int[] doc = new int[docLength];
    +        if (isStringInput) {
    +            for (int i = 0; i < docLength; i++) {
    +                doc[i] = getWordId(PrimitiveObjectInspectorUtils.getString(rawDoc.get(i), wordOI));
    +            }
    +        } else {
    +            for (int i = 0; i < docLength; i++) {
    +                doc[i] = PrimitiveObjectInspectorUtils.getInt(rawDoc.get(i), wordOI);
    +            }
    +        }
    +
    +        model.trainOnDoc(doc);
    +    }
    +
    +    @Override
    +    protected Options getOptions() {
    +        Options opts = new Options();
    +        opts.addOption("n", "numTrainWords", true,
    +            "The number of words in the documents. It is used to update learning rate");
    +        opts.addOption("dim", "dimension", true, "The number of vector dimension [default: 100]");
    +        opts.addOption("win", "window", true, "Context window size [default: 5]");
    +        opts.addOption("neg", "negative", true,
    +            "The number of negative sampled words per word [default: 5]");
    +        opts.addOption("iter", "iteration", true, "The number of iterations [default: 5]");
    --- End diff --
    
    consistent naming `"iters", "iterations"` as seen in SLIM.


---

[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec

Posted by myui <gi...@git.apache.org>.
Github user myui commented on a diff in the pull request:

    https://github.com/apache/incubator-hivemall/pull/116#discussion_r141544782
  
    --- Diff: core/src/main/java/hivemall/embedding/Word2VecUDTF.java ---
    @@ -0,0 +1,364 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one
    + * or more contributor license agreements.  See the NOTICE file
    + * distributed with this work for additional information
    + * regarding copyright ownership.  The ASF licenses this file
    + * to you under the Apache License, Version 2.0 (the
    + * "License"); you may not use this file except in compliance
    + * with the License.  You may obtain a copy of the License at
    + *
    + *   http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing,
    + * software distributed under the License is distributed on an
    + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    + * KIND, either express or implied.  See the License for the
    + * specific language governing permissions and limitations
    + * under the License.
    + */
    +package hivemall.embedding;
    +
    +import hivemall.UDTFWithOptions;
    +import hivemall.utils.collections.IMapIterator;
    +import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
    +import hivemall.utils.collections.maps.OpenHashTable;
    +import hivemall.utils.hadoop.HiveUtils;
    +import hivemall.utils.lang.Primitives;
    +
    +import org.apache.commons.cli.CommandLine;
    +import org.apache.commons.cli.Options;
    +import org.apache.hadoop.hive.ql.exec.Description;
    +import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
    +import org.apache.hadoop.hive.ql.metadata.HiveException;
    +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
    +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
    +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
    +import org.apache.hadoop.io.FloatWritable;
    +import org.apache.hadoop.io.IntWritable;
    +import org.apache.hadoop.io.Text;
    +
    +import javax.annotation.Nonnegative;
    +import javax.annotation.Nonnull;
    +import java.util.List;
    +import java.util.Arrays;
    +import java.util.ArrayList;
    +
    +@Description(
    +        name = "train_word2vec",
    +        value = "_FUNC_(array<array<float | string>> negative_table, array<int | string> doc [, const string options]) - Returns a prediction model")
    +public class Word2VecUDTF extends UDTFWithOptions {
    +    protected transient AbstractWord2VecModel model;
    +    @Nonnegative
    +    private float startingLR;
    +    @Nonnegative
    +    private long numTrainWords;
    +    private OpenHashTable<String, Integer> word2index;
    +
    +    @Nonnegative
    +    private int dim;
    +    @Nonnegative
    +    private int win;
    +    @Nonnegative
    +    private int neg;
    +    @Nonnegative
    +    private int iter;
    +    private boolean skipgram;
    +    private boolean isStringInput;
    +
    +    private Int2FloatOpenHashTable S;
    +    private int[] aliasWordIds;
    +
    +    private ListObjectInspector negativeTableOI;
    +    private ListObjectInspector negativeTableElementListOI;
    +    private PrimitiveObjectInspector negativeTableElementOI;
    +
    +    private ListObjectInspector docOI;
    +    private PrimitiveObjectInspector wordOI;
    +
    +    @Override
    +    public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
    +        final int numArgs = argOIs.length;
    +
    +        if (numArgs != 3) {
    +            throw new UDFArgumentException(getClass().getSimpleName()
    +                    + " takes 3 arguments:  [, constant string options]: "
    +                    + Arrays.toString(argOIs));
    +        }
    +
    +        processOptions(argOIs);
    +
    +        this.negativeTableOI = HiveUtils.asListOI(argOIs[0]);
    +        this.negativeTableElementListOI = HiveUtils.asListOI(negativeTableOI.getListElementObjectInspector());
    +        this.docOI = HiveUtils.asListOI(argOIs[1]);
    +
    +        this.isStringInput = HiveUtils.isStringListOI(argOIs[1]);
    +
    +        if (isStringInput) {
    +            this.negativeTableElementOI = HiveUtils.asStringOI(negativeTableElementListOI.getListElementObjectInspector());
    +            this.wordOI = HiveUtils.asStringOI(docOI.getListElementObjectInspector());
    +        } else {
    +            this.negativeTableElementOI = HiveUtils.asFloatingPointOI(negativeTableElementListOI.getListElementObjectInspector());
    +            this.wordOI = HiveUtils.asIntCompatibleOI(docOI.getListElementObjectInspector());
    +        }
    +
    +        List<String> fieldNames = new ArrayList<>();
    +        List<ObjectInspector> fieldOIs = new ArrayList<>();
    +
    +        fieldNames.add("word");
    +        fieldNames.add("i");
    +        fieldNames.add("wi");
    +
    +        if (isStringInput) {
    +            fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
    +        } else {
    +            fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
    +        }
    +
    +        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
    +        fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
    +
    +        this.model = null;
    +        this.word2index = null;
    +        this.S = null;
    +        this.aliasWordIds = null;
    +
    +        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    +    }
    +
    +    @Override
    +    public void process(Object[] args) throws HiveException {
    +        if (model == null) {
    +            parseNegativeTable(args[0]);
    +            this.model = createModel();
    +        }
    +
    +        List<?> rawDoc = docOI.getList(args[1]);
    +
    +        // parse rawDoc
    +        final int docLength = rawDoc.size();
    +        final int[] doc = new int[docLength];
    +        if (isStringInput) {
    +            for (int i = 0; i < docLength; i++) {
    +                doc[i] = getWordId(PrimitiveObjectInspectorUtils.getString(rawDoc.get(i), wordOI));
    +            }
    +        } else {
    +            for (int i = 0; i < docLength; i++) {
    +                doc[i] = PrimitiveObjectInspectorUtils.getInt(rawDoc.get(i), wordOI);
    +            }
    +        }
    +
    +        model.trainOnDoc(doc);
    +    }
    +
    +    @Override
    +    protected Options getOptions() {
    +        Options opts = new Options();
    +        opts.addOption("n", "numTrainWords", true,
    +            "The number of words in the documents. It is used to update learning rate");
    +        opts.addOption("dim", "dimension", true, "The number of vector dimension [default: 100]");
    +        opts.addOption("win", "window", true, "Context window size [default: 5]");
    +        opts.addOption("neg", "negative", true,
    +            "The number of negative sampled words per word [default: 5]");
    +        opts.addOption("iter", "iteration", true, "The number of iterations [default: 5]");
    +        opts.addOption("model", "modelName", true,
    +            "The model name of word2vec: skipgram or cbow [default: skipgram]");
    +        opts.addOption(
    +            "lr",
    +            "learningRate",
    +            true,
    +            "Initial learning rate of SGD. The default value depends on model [default: 0.025 (skipgram), 0.05 (cbow)]");
    +
    +        return opts;
    +    }
    +
    +    @Override
    +    protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
    +        CommandLine cl = null;
    +        int win = 5;
    +        int neg = 5;
    +        int iter = 5;
    +        int dim = 100;
    +        long numTrainWords = 0L;
    +        String modelName = "skipgram";
    +        float lr = 0.025f;
    +
    +        if (argOIs.length >= 3) {
    +            String rawArgs = HiveUtils.getConstString(argOIs[2]);
    +            cl = parseOptions(rawArgs);
    +
    +            numTrainWords = Primitives.parseLong(cl.getOptionValue("n"), numTrainWords);
    +            if (numTrainWords <= 0) {
    +                throw new UDFArgumentException("Argument `int numTrainWords` must be positive: "
    +                        + numTrainWords);
    +            }
    +
    +            dim = Primitives.parseInt(cl.getOptionValue("dim"), dim);
    +            if (dim <= 0.d) {
    +                throw new UDFArgumentException("Argument `int dim` must be positive: " + dim);
    +            }
    +
    +            win = Primitives.parseInt(cl.getOptionValue("win"), win);
    +            if (win <= 0) {
    +                throw new UDFArgumentException("Argument `int win` must be positive: " + win);
    +            }
    +
    +            neg = Primitives.parseInt(cl.getOptionValue("neg"), neg);
    +            if (neg < 0) {
    +                throw new UDFArgumentException("Argument `int neg` must be non-negative: " + neg);
    +            }
    +
    +            iter = Primitives.parseInt(cl.getOptionValue("iter"), iter);
    +            if (iter <= 0) {
    +                throw new UDFArgumentException("Argument `int iter` must be non-negative: " + iter);
    +            }
    +
    +            modelName = cl.getOptionValue("model", modelName);
    +            if (!(modelName.equals("skipgram") || modelName.equals("cbow"))) {
    +                throw new UDFArgumentException("Argument `string model` must be skipgram or cbow: "
    +                        + modelName);
    +            }
    +
    +            if (modelName.equals("cbow")) {
    +                lr = 0.05f;
    +            }
    +
    +            lr = Primitives.parseFloat(cl.getOptionValue("lr"), lr);
    +            if (lr <= 0.f) {
    +                throw new UDFArgumentException("Argument `float lr` must be positive: " + lr);
    +            }
    +        }
    +
    +        this.numTrainWords = numTrainWords;
    +        this.win = win;
    +        this.neg = neg;
    +        this.iter = iter;
    +        this.dim = dim;
    +        this.skipgram = modelName.equals("skipgram");
    +        this.startingLR = lr;
    +        return cl;
    +    }
    +
    +    public void close() throws HiveException {
    +        if (model != null) {
    +            forwardModel();
    +            this.model = null;
    +            this.word2index = null;
    +            this.S = null;
    +        }
    +    }
    +
    +    private void forwardModel() throws HiveException {
    +        if (isStringInput) {
    +            final Text word = new Text();
    +            final IntWritable dimIndex = new IntWritable();
    +            final FloatWritable value = new FloatWritable();
    +
    +            final Object[] result = new Object[3];
    +            result[0] = word;
    +            result[1] = dimIndex;
    +            result[2] = value;
    +
    +            IMapIterator<String, Integer> iter = word2index.entries();
    +            while (iter.next() != -1) {
    +                int wordId = iter.getValue();
    +                if (!model.inputWeights.containsKey(wordId * dim)){
    +                    continue;
    +                }
    +
    +                word.set(iter.getKey());
    +
    +                for (int i = 0; i < dim; i++) {
    +                    dimIndex.set(i);
    +                    value.set(model.inputWeights.get(wordId * dim + i));
    +                    forward(result);
    +                }
    +            }
    +        } else {
    +            final IntWritable word = new IntWritable();
    +            final IntWritable dimIndex = new IntWritable();
    +            final FloatWritable value = new FloatWritable();
    +
    +            final Object[] result = new Object[3];
    +            result[0] = word;
    +            result[1] = dimIndex;
    +            result[2] = value;
    +
    +            for (int wordId = 0; wordId < aliasWordIds.length; wordId++) {
    +                if (!model.inputWeights.containsKey(wordId * dim)){
    +                    break;
    +                }
    +                word.set(wordId);
    +                for (int i = 0; i < dim; i++) {
    +                    dimIndex.set(i);
    +                    value.set(model.inputWeights.get(wordId * dim + i));
    +                    forward(result);
    +                }
    +            }
    +        }
    +    }
    +
    +    private int getWordId(@Nonnull final String word) {
    +        if (word2index.containsKey(word)) {
    --- End diff --
    
    `word2index` is not ensured to be non-null.
    
    ```java
    private static int getWordId(@Nonnull final String word, @CheckNotNull OpenHashTable<String, Integer> word2Index) {
       Precondition.checkNotNull(word2index);
    ```


---

[GitHub] incubator-hivemall issue #116: [WIP][HIVEMALL-118] word2vec

Posted by coveralls <gi...@git.apache.org>.
Github user coveralls commented on the issue:

    https://github.com/apache/incubator-hivemall/pull/116
  
    
    [![Coverage Status](https://coveralls.io/builds/13411195/badge)](https://coveralls.io/builds/13411195)
    
    Coverage decreased (-0.5%) to 40.383% when pulling **7a2f4dbfeb89eee78e6e0526027b5b0ba9162c29 on nzw0301:skipgram** into **c2b95783cf9d6fc1646a48ac928e96152eab98c6 on apache:master**.



---

[GitHub] incubator-hivemall issue #116: [WIP][HIVEMALL-118] word2vec

Posted by coveralls <gi...@git.apache.org>.
Github user coveralls commented on the issue:

    https://github.com/apache/incubator-hivemall/pull/116
  
    
    [![Coverage Status](https://coveralls.io/builds/13438117/badge)](https://coveralls.io/builds/13438117)
    
    Coverage decreased (-0.8%) to 40.107% when pulling **f0abd4fd99dcace050d155533ebc2dc8768cfc79 on nzw0301:skipgram** into **c2b95783cf9d6fc1646a48ac928e96152eab98c6 on apache:master**.



---

[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec

Posted by myui <gi...@git.apache.org>.
Github user myui commented on a diff in the pull request:

    https://github.com/apache/incubator-hivemall/pull/116#discussion_r141545219
  
    --- Diff: docs/gitbook/embedding/word2vec.md ---
    @@ -0,0 +1,399 @@
    +<!--
    +  Licensed to the Apache Software Foundation (ASF) under one
    +  or more contributor license agreements.  See the NOTICE file
    +  distributed with this work for additional information
    +  regarding copyright ownership.  The ASF licenses this file
    +  to you under the Apache License, Version 2.0 (the
    +  "License"); you may not use this file except in compliance
    +  with the License.  You may obtain a copy of the License at
    +
    +    http://www.apache.org/licenses/LICENSE-2.0
    +
    +  Unless required by applicable law or agreed to in writing,
    +  software distributed under the License is distributed on an
    +  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    +  KIND, either express or implied.  See the License for the
    +  specific language governing permissions and limitations
    +  under the License.
    +-->
    +
    +Word Embedding is a powerful tool for many tasks,
    +e.g. finding similar words,
    +feature vectors for supervised machine learning task and word analogy,
    +such as `king - man + woman =~ queen`.
    +In word embedding,
    +each word represents a low dimension and dense vector.
    +**Skip-Gram** and **Continuous Bag-of-words** (CBoW) are the most popular algorithms to obtain good word embeddings (a.k.a word2vec).
    +
    +The papers introduce the method are as follows:
    +
    +- T. Mikolov, et al., [Distributed Representations of Words and Phrases and Their Compositionality
    +](http://papers.nips.cc/paper/5021-distributed-representations-of-words-and-phrases-and-their-compositionality.pdf). NIPS, 2013.
    +- T. Mikolov, et al., [Efficient Estimation of Word Representations in Vector Space](https://arxiv.org/abs/1301.3781). ICLR, 2013.
    +
    +Hivemall provides two type algorithms: Skip-gram and CBoW with negative sampling.
    +Hivemall enables you to train your sequence data such as,
    +but not limited to, documents based on word2vec.
    +This article gives usage instructions of the feature.
    +
    +<!-- toc -->
    +
    +> #### Note
    +> This feature is supported from Hivemall v0.5-rc.? or later.
    +
    +# Prepare document data
    +
    +Assume that you already have `docs` table which contains many documents as string format with unique index:
    +
    +```sql
    +select * FROM docs;
    +```
    +
    +| docId | doc |
    +|:----: |:----|
    +|   0   | "Alice was beginning to get very tired of sitting by her sister on the bank ..." |
    +|  ...  | ... |
    +
    +First, each document is split into words by tokenize function like a [`tokenize`](../misc/tokenizer.html).
    +
    +```sql
    +drop table docs_words;
    +create table docs_words as
    +  select
    +    docid,
    +    tokenize(doc, true) as words
    +  FROM
    +    docs
    +;
    +```
    +
    +This table shows tokenized document.
    +
    +| docId | doc |
    +|:----: |:----|
    +|   0   | ["alice", "was", "beginning", "to", "get", "very", "tired", "of", "sitting", "by", "her", "sister", "on", "the", "bank", ...] |
    +|  ...  | ... |
    +
    +Then, you count frequency up per word and remove low frequency words from the vocabulary.
    +To remove low frequency words is optional preprocessing, but this process is effective to train word vector fastly.
    +
    +```sql
    +set hivevar:mincount=5;
    +
    +drop table freq;
    +create table freq as
    +select
    +  row_number() over () - 1 as wordid,
    +  word,
    +  freq
    +from (
    +  select
    +    word,
    +    COUNT(*) as freq
    +  from
    +    docs_words
    +  LATERAL VIEW explode(words) lTable as word
    +  group by
    +    word
    +) t
    +where freq >= ${mincount}
    +;
    +```
    +
    +Hivemall's word2vec supports two type words; string and int.
    +String type tends to use huge memory during training.
    +On the other hand, int type tends to use less memory.
    +If you train on small dataset, we recommend using string type,
    +because memory usage can be ignored and HiveQL is more simple.
    +If you train on large dataset, we recommend using int type,
    +because it saves memory during training.
    +
    +# Create sub-sampling table
    +
    +Sub-sampling table is stored a sub-sampling probability per word.
    +
    +The sub-sampling probability of word $$w_i$$ is computed by the following equation:
    +
    +$$
    +\begin{aligned}
    +f(w_i) = \sqrt{\frac{\mathrm{sample}}{freq(w_i)/\sum freq(w)}} + \frac{\mathrm{sample}}{freq(w_i)/\sum freq(w)}
    +\end{aligned}
    +$$
    +
    +During word2vec training,
    +not sub-sampled words are ignored.
    +It works to train fastly and to consider the imbalance the rare words and frequent words by reducing frequent words.
    +The smaller `sample` value set,
    +the fewer words are used during training.
    +
    +```sql
    +set hivevar:sample=1e-4;
    +
    +drop table subsampling_table;
    +create table subsampling_table as
    +with stats as (
    +  select
    +    sum(freq) as numTrainWords
    +  FROM
    +    freq
    +)
    +select
    +  l.wordid,
    +  l.word,
    +  sqrt(${sample}/(l.freq/r.numTrainWords)) + ${sample}/(l.freq/r.numTrainWords) as p
    +from
    +  freq l
    +cross join
    +  stats r
    +;
    +```
    +
    +```sql
    +select * FROM subsampling_table order by p;
    +```
    +
    +| wordid | word | p |
    +|:----: | :----: |:----:|
    +| 48645 | the  | 0.04013665|
    +| 11245 | of   | 0.052463654|
    +| 16368 | and  | 0.06555538|
    +| 61938 | 00   | 0.068162076|
    +| 19977 | in   | 0.071441144|
    +| 83599 | 0    | 0.07528994|
    +| 95017 | a    | 0.07559573|
    +| 1225  | to   | 0.07953133|
    +| 37062 | 0000 | 0.08779001|
    +| 58246 | is   | 0.09049763|
    +|  ...  | ...  |... |
    +
    +The first row shows that 4% of `the` are used in the documents during training.
    +
    +# Delete low frequency words and high frequency words from `docs_words`
    +
    +To reduce useless words from corpus,
    +low frequency words and high frequency words are deleted.
    +And, to avoid loading long document on memory, a  document is split into some sub-documents.
    +
    +```sql
    +set hivevar:maxlength=1500;
    +SET hivevar:seed=31;
    +
    +drop table train_docs;
    +create table train_docs as
    +  with docs_exploded as (
    +    select
    +      docid,
    +      word,
    +      pos % ${maxlength} as pos,
    +      pos div ${maxlength} as splitid,
    +      rand(${seed}) as rnd
    +    from
    +      docs_words LATERAL VIEW posexplode(words) t as pos, word
    +  )
    +select
    +  l.docid,
    +  -- to_ordered_list(l.word, l.pos) as words
    +  to_ordered_list(r2.wordid, l.pos) as words,
    +from
    +  docs_exploded l
    +  LEFT SEMI join freq r on (l.word = r.word)
    +  join subsampling_table r2 on (l.word = r2.word)
    +where
    +  r2.p > l.rnd
    +group by
    +  l.docid, l.splitid
    +;
    +```
    +
    +If you store string word in `train_docs` table,
    +please replace `to_ordered_list(r2.wordid, l.pos) as words` with  `to_ordered_list(l.word, l.pos) as words`.
    +
    +# Create negative sampling table
    +
    +Negative sampling is an approximate function of [softmax function](https://en.wikipedia.org/wiki/Softmax_function).
    +Here, `negative_table` is used to store word sampling probability for negative sampling.
    +`z` is a hyperparameter of noise distribution for negative sampling.
    +During word2vec training,
    +words sampled from this distribution are used for negative examples.
    +Noise distribution is the unigram distribution raised to the 3/4rd power.
    +
    +$$
    +\begin{aligned}
    +p(w_i) = \frac{freq(w_i)^{\mathrm{z}}}{\sum freq(w)^{\mathrm{z}}}
    +\end{aligned}
    +$$
    +
    +To avoid using huge memory space for negative sampling like original implementation and remain to sample fastly from this distribution,
    +Hivemall uses [Alias method](https://en.wikipedia.org/wiki/Alias_method).
    +
    +This method has proposed in papers below:
    +
    +- A. J. Walker, New Fast Method for Generating Discrete Random Numbers with Arbitrary Frequency Distributions, in Electronics Letters 10, no. 8, pp. 127-128, 1974.
    +- A. J. Walker, An Efficient Method for Generating Discrete Random Variables with General Distributions. ACM Transactions on Mathematical Software 3, no. 3, pp. 253-256, 1977.
    +
    +```sql
    +set hivevar:z=3/4;
    +
    +drop table negative_table;
    +create table negative_table as
    +select
    +  collect_list(array(word, p, other)) as negative_table
    +from (
    +  select
    +    alias_table(to_map(word, negative)) as (word, p, other)
    +  from
    +    (
    +      select
    +        word,
    +        -- wordid as word,
    +        pow(freq, ${z}) as negative
    +      from
    +        freq
    +    ) t
    +) t1
    +;
    +```
    +
    +`alias_table` function returns the records like following.
    +
    +| word | p | other |
    +|:----: | :----: |:----:|
    +| leopold | 0.6556492 | 0000 |
    +| slep | 0.09060383 | leopold |
    +| valentinian | 0.76077825 | belarusian |
    +| slew | 0.90569097 | colin |
    +| lucien | 0.86329675 | overland |
    +| equitable | 0.7270946 | farms |
    +| insurers | 0.2367955 | israel |
    +| lucier | 0.14855136 | supplements |
    +| lieve | 0.12075222 | separatist |
    +| skyhawks | 0.14079945 | steamed |
    +| ... | ... | ... |
    +
    +To sample negative word from this `negative_table`,
    +
    +1. Sample record int index `i` from $$[0 \ldots \mathrm{num\_alias\_table\_records}]$$.
    +2. Sample float value `r` from $$[0.0 \ldots 1.0]$$ .
    +3. If `r` < `p` of `i` th record, return `word` `i` th record, else return `other` of `i` th record.
    +
    +Here, to use it in training function of word2vec, 
    +`alias_table`'s return records are stored into one list in the `negative_table`.
    +
    +# Train word2vec
    +
    +Hivemall provides `train_word2vec` function to train word vector by word2vec algorithms.
    +The default model is `"skipgram"`.
    +
    +> #### Note
    +> You must pass `n` argumet to the number of words in training documents: `select sum(size(words)) from train_docs;`.
    +
    +## Train Skip-Gram
    +
    +In skip-gram model,
    +word vectors are trained to predict the nearby words.
    +For example, given a sentence like a `"alice", "was", "beginning", "to"`,
    +`"was"` vector is learnt to predict `"alice"` ,`"beginning"` and `"to"`.
    +
    +```sql
    +select sum(size(words)) from train_docs;
    +set hivevar:n=418953; -- previous query return value
    +
    +drop table skipgram;
    +create table skipgram as
    +select
    +  train_word2vec(
    +    r.negative_table,
    +    l.words,
    +    "-n ${n} -win 5 -neg 15 -iter 5 -dim 100 -model skipgram"
    +  )
    +from
    +  train_docs l
    +  cross join negative_table r
    +;
    +```
    +
    +When word is treated as int istead of string,
    +you may need to transform wordid of int to word of string by `join` statement.
    +
    +```sql
    +drop table skipgram;
    +
    +create table skipgram as
    +select
    +  r.word, t.i, t.wi
    +from (
    +  select
    +    train_word2vec(
    +      r.negative_table,
    +      l.wordsint,
    +      "-n 418953 -win 5 -neg 15 -iter 5"
    +    ) as (wordid, i, wi)
    +  from
    +    train_docs l
    +  cross join
    +    negative_table r
    +) t
    +join freq r on (t.wordid = r.wordid)
    +;
    +```
    +
    +## Train CBoW
    +
    +In CBoW model,
    +word vectors are trained to be predicted the nearby words.
    +For example, given a sentence like a `"alice", "was", "beginning", "to"`,
    +`"alice"` ,`"beginning"` and `"to"` vectors are learnt to predict `"was"` vector.
    +
    +```sql
    +drop table cbow;
    +
    +create table cbow as
    +select
    +  train_word2vec(
    +    r.negative_table,
    +    l.words,
    +    "-n 418953 -win 5 -neg 15 -iter 5 -model cbow"
    --- End diff --
    
    -n 418953 should be -n ${n}


---

[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec

Posted by myui <gi...@git.apache.org>.
Github user myui commented on a diff in the pull request:

    https://github.com/apache/incubator-hivemall/pull/116#discussion_r141545337
  
    --- Diff: docs/gitbook/embedding/word2vec.md ---
    @@ -0,0 +1,399 @@
    +<!--
    +  Licensed to the Apache Software Foundation (ASF) under one
    +  or more contributor license agreements.  See the NOTICE file
    +  distributed with this work for additional information
    +  regarding copyright ownership.  The ASF licenses this file
    +  to you under the Apache License, Version 2.0 (the
    +  "License"); you may not use this file except in compliance
    +  with the License.  You may obtain a copy of the License at
    +
    +    http://www.apache.org/licenses/LICENSE-2.0
    +
    +  Unless required by applicable law or agreed to in writing,
    +  software distributed under the License is distributed on an
    +  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    +  KIND, either express or implied.  See the License for the
    +  specific language governing permissions and limitations
    +  under the License.
    +-->
    +
    +Word Embedding is a powerful tool for many tasks,
    +e.g. finding similar words,
    +feature vectors for supervised machine learning task and word analogy,
    +such as `king - man + woman =~ queen`.
    +In word embedding,
    +each word represents a low dimension and dense vector.
    +**Skip-Gram** and **Continuous Bag-of-words** (CBoW) are the most popular algorithms to obtain good word embeddings (a.k.a word2vec).
    +
    +The papers introduce the method are as follows:
    +
    +- T. Mikolov, et al., [Distributed Representations of Words and Phrases and Their Compositionality
    +](http://papers.nips.cc/paper/5021-distributed-representations-of-words-and-phrases-and-their-compositionality.pdf). NIPS, 2013.
    +- T. Mikolov, et al., [Efficient Estimation of Word Representations in Vector Space](https://arxiv.org/abs/1301.3781). ICLR, 2013.
    +
    +Hivemall provides two type algorithms: Skip-gram and CBoW with negative sampling.
    +Hivemall enables you to train your sequence data such as,
    +but not limited to, documents based on word2vec.
    +This article gives usage instructions of the feature.
    +
    +<!-- toc -->
    +
    +> #### Note
    +> This feature is supported from Hivemall v0.5-rc.? or later.
    +
    +# Prepare document data
    +
    +Assume that you already have `docs` table which contains many documents as string format with unique index:
    +
    +```sql
    +select * FROM docs;
    +```
    +
    +| docId | doc |
    +|:----: |:----|
    +|   0   | "Alice was beginning to get very tired of sitting by her sister on the bank ..." |
    +|  ...  | ... |
    +
    +First, each document is split into words by tokenize function like a [`tokenize`](../misc/tokenizer.html).
    +
    +```sql
    +drop table docs_words;
    +create table docs_words as
    +  select
    +    docid,
    +    tokenize(doc, true) as words
    +  FROM
    +    docs
    +;
    +```
    +
    +This table shows tokenized document.
    +
    +| docId | doc |
    +|:----: |:----|
    +|   0   | ["alice", "was", "beginning", "to", "get", "very", "tired", "of", "sitting", "by", "her", "sister", "on", "the", "bank", ...] |
    +|  ...  | ... |
    +
    +Then, you count frequency up per word and remove low frequency words from the vocabulary.
    +To remove low frequency words is optional preprocessing, but this process is effective to train word vector fastly.
    +
    +```sql
    +set hivevar:mincount=5;
    +
    +drop table freq;
    +create table freq as
    +select
    +  row_number() over () - 1 as wordid,
    +  word,
    +  freq
    +from (
    +  select
    +    word,
    +    COUNT(*) as freq
    +  from
    +    docs_words
    +  LATERAL VIEW explode(words) lTable as word
    +  group by
    +    word
    +) t
    +where freq >= ${mincount}
    +;
    +```
    +
    +Hivemall's word2vec supports two type words; string and int.
    +String type tends to use huge memory during training.
    +On the other hand, int type tends to use less memory.
    +If you train on small dataset, we recommend using string type,
    +because memory usage can be ignored and HiveQL is more simple.
    +If you train on large dataset, we recommend using int type,
    +because it saves memory during training.
    +
    +# Create sub-sampling table
    +
    +Sub-sampling table is stored a sub-sampling probability per word.
    +
    +The sub-sampling probability of word $$w_i$$ is computed by the following equation:
    +
    +$$
    +\begin{aligned}
    +f(w_i) = \sqrt{\frac{\mathrm{sample}}{freq(w_i)/\sum freq(w)}} + \frac{\mathrm{sample}}{freq(w_i)/\sum freq(w)}
    +\end{aligned}
    +$$
    +
    +During word2vec training,
    --- End diff --
    
    remove line break after `,`.


---

[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec

Posted by myui <gi...@git.apache.org>.
Github user myui commented on a diff in the pull request:

    https://github.com/apache/incubator-hivemall/pull/116#discussion_r141543986
  
    --- Diff: core/src/main/java/hivemall/embedding/Word2VecUDTF.java ---
    @@ -0,0 +1,364 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one
    + * or more contributor license agreements.  See the NOTICE file
    + * distributed with this work for additional information
    + * regarding copyright ownership.  The ASF licenses this file
    + * to you under the Apache License, Version 2.0 (the
    + * "License"); you may not use this file except in compliance
    + * with the License.  You may obtain a copy of the License at
    + *
    + *   http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing,
    + * software distributed under the License is distributed on an
    + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    + * KIND, either express or implied.  See the License for the
    + * specific language governing permissions and limitations
    + * under the License.
    + */
    +package hivemall.embedding;
    +
    +import hivemall.UDTFWithOptions;
    +import hivemall.utils.collections.IMapIterator;
    +import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
    +import hivemall.utils.collections.maps.OpenHashTable;
    +import hivemall.utils.hadoop.HiveUtils;
    +import hivemall.utils.lang.Primitives;
    +
    +import org.apache.commons.cli.CommandLine;
    +import org.apache.commons.cli.Options;
    +import org.apache.hadoop.hive.ql.exec.Description;
    +import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
    +import org.apache.hadoop.hive.ql.metadata.HiveException;
    +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
    +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
    +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
    +import org.apache.hadoop.io.FloatWritable;
    +import org.apache.hadoop.io.IntWritable;
    +import org.apache.hadoop.io.Text;
    +
    +import javax.annotation.Nonnegative;
    +import javax.annotation.Nonnull;
    +import java.util.List;
    +import java.util.Arrays;
    +import java.util.ArrayList;
    +
    +@Description(
    +        name = "train_word2vec",
    +        value = "_FUNC_(array<array<float | string>> negative_table, array<int | string> doc [, const string options]) - Returns a prediction model")
    +public class Word2VecUDTF extends UDTFWithOptions {
    +    protected transient AbstractWord2VecModel model;
    +    @Nonnegative
    +    private float startingLR;
    +    @Nonnegative
    +    private long numTrainWords;
    +    private OpenHashTable<String, Integer> word2index;
    +
    +    @Nonnegative
    +    private int dim;
    +    @Nonnegative
    +    private int win;
    +    @Nonnegative
    +    private int neg;
    +    @Nonnegative
    +    private int iter;
    +    private boolean skipgram;
    +    private boolean isStringInput;
    +
    +    private Int2FloatOpenHashTable S;
    +    private int[] aliasWordIds;
    +
    +    private ListObjectInspector negativeTableOI;
    +    private ListObjectInspector negativeTableElementListOI;
    +    private PrimitiveObjectInspector negativeTableElementOI;
    +
    +    private ListObjectInspector docOI;
    +    private PrimitiveObjectInspector wordOI;
    +
    +    @Override
    +    public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
    +        final int numArgs = argOIs.length;
    +
    +        if (numArgs != 3) {
    +            throw new UDFArgumentException(getClass().getSimpleName()
    +                    + " takes 3 arguments:  [, constant string options]: "
    +                    + Arrays.toString(argOIs));
    +        }
    +
    +        processOptions(argOIs);
    +
    +        this.negativeTableOI = HiveUtils.asListOI(argOIs[0]);
    +        this.negativeTableElementListOI = HiveUtils.asListOI(negativeTableOI.getListElementObjectInspector());
    +        this.docOI = HiveUtils.asListOI(argOIs[1]);
    +
    +        this.isStringInput = HiveUtils.isStringListOI(argOIs[1]);
    +
    +        if (isStringInput) {
    +            this.negativeTableElementOI = HiveUtils.asStringOI(negativeTableElementListOI.getListElementObjectInspector());
    +            this.wordOI = HiveUtils.asStringOI(docOI.getListElementObjectInspector());
    +        } else {
    +            this.negativeTableElementOI = HiveUtils.asFloatingPointOI(negativeTableElementListOI.getListElementObjectInspector());
    +            this.wordOI = HiveUtils.asIntCompatibleOI(docOI.getListElementObjectInspector());
    +        }
    +
    +        List<String> fieldNames = new ArrayList<>();
    +        List<ObjectInspector> fieldOIs = new ArrayList<>();
    +
    +        fieldNames.add("word");
    +        fieldNames.add("i");
    +        fieldNames.add("wi");
    +
    +        if (isStringInput) {
    +            fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
    +        } else {
    +            fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
    +        }
    +
    +        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
    +        fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
    +
    +        this.model = null;
    +        this.word2index = null;
    +        this.S = null;
    +        this.aliasWordIds = null;
    +
    +        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    +    }
    +
    +    @Override
    +    public void process(Object[] args) throws HiveException {
    +        if (model == null) {
    +            parseNegativeTable(args[0]);
    +            this.model = createModel();
    +        }
    +
    +        List<?> rawDoc = docOI.getList(args[1]);
    +
    +        // parse rawDoc
    +        final int docLength = rawDoc.size();
    +        final int[] doc = new int[docLength];
    +        if (isStringInput) {
    +            for (int i = 0; i < docLength; i++) {
    +                doc[i] = getWordId(PrimitiveObjectInspectorUtils.getString(rawDoc.get(i), wordOI));
    +            }
    +        } else {
    +            for (int i = 0; i < docLength; i++) {
    +                doc[i] = PrimitiveObjectInspectorUtils.getInt(rawDoc.get(i), wordOI);
    --- End diff --
    
    `rawDoc.get(i)` may return null.


---

[GitHub] incubator-hivemall issue #116: [WIP][HIVEMALL-118] word2vec

Posted by coveralls <gi...@git.apache.org>.
Github user coveralls commented on the issue:

    https://github.com/apache/incubator-hivemall/pull/116
  
    
    [![Coverage Status](https://coveralls.io/builds/13383603/badge)](https://coveralls.io/builds/13383603)
    
    Coverage decreased (-0.8%) to 40.138% when pulling **e50756111378e6a173aba7574e73435364ff42d0 on nzw0301:skipgram** into **c2b95783cf9d6fc1646a48ac928e96152eab98c6 on apache:master**.



---

[GitHub] incubator-hivemall issue #116: [WIP][HIVEMALL-118] word2vec

Posted by coveralls <gi...@git.apache.org>.
Github user coveralls commented on the issue:

    https://github.com/apache/incubator-hivemall/pull/116
  
    
    [![Coverage Status](https://coveralls.io/builds/13434811/badge)](https://coveralls.io/builds/13434811)
    
    Coverage decreased (-0.8%) to 40.149% when pulling **2b66e5e3dbf1408c0719735dcbdf05555938f3bb on nzw0301:skipgram** into **c2b95783cf9d6fc1646a48ac928e96152eab98c6 on apache:master**.



---

[GitHub] incubator-hivemall issue #116: [WIP][HIVEMALL-118] word2vec

Posted by coveralls <gi...@git.apache.org>.
Github user coveralls commented on the issue:

    https://github.com/apache/incubator-hivemall/pull/116
  
    
    [![Coverage Status](https://coveralls.io/builds/13371457/badge)](https://coveralls.io/builds/13371457)
    
    Coverage decreased (-0.8%) to 40.156% when pulling **ad2b2911b5a6ebb7b43b4981bf5ff4424425a292 on nzw0301:skipgram** into **c2b95783cf9d6fc1646a48ac928e96152eab98c6 on apache:master**.



---

[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec

Posted by myui <gi...@git.apache.org>.
Github user myui commented on a diff in the pull request:

    https://github.com/apache/incubator-hivemall/pull/116#discussion_r140413234
  
    --- Diff: core/src/main/java/hivemall/unsupervised/AbstractWord2vecModel.java ---
    @@ -0,0 +1,125 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one
    + * or more contributor license agreements.  See the NOTICE file
    + * distributed with this work for additional information
    + * regarding copyright ownership.  The ASF licenses this file
    + * to you under the Apache License, Version 2.0 (the
    + * "License"); you may not use this file except in compliance
    + * with the License.  You may obtain a copy of the License at
    + *
    + *   http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing,
    + * software distributed under the License is distributed on an
    + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    + * KIND, either express or implied.  See the License for the
    + * specific language governing permissions and limitations
    + * under the License.
    + */
    +package hivemall.unsupervised;
    --- End diff --
    
    Please move package from `hivemall.unsupervised` to `hivemall.embedding`.


---

[GitHub] incubator-hivemall issue #116: [WIP][HIVEMALL-118] word2vec

Posted by myui <gi...@git.apache.org>.
Github user myui commented on the issue:

    https://github.com/apache/incubator-hivemall/pull/116
  
    @nzw0301 Please rebase to master resolving ^ conflicts.


---

[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec

Posted by myui <gi...@git.apache.org>.
Github user myui commented on a diff in the pull request:

    https://github.com/apache/incubator-hivemall/pull/116#discussion_r140725446
  
    --- Diff: core/src/main/java/hivemall/embedding/AbstractWord2VecModel.java ---
    @@ -0,0 +1,117 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one
    + * or more contributor license agreements.  See the NOTICE file
    + * distributed with this work for additional information
    + * regarding copyright ownership.  The ASF licenses this file
    + * to you under the Apache License, Version 2.0 (the
    + * "License"); you may not use this file except in compliance
    + * with the License.  You may obtain a copy of the License at
    + *
    + *   http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing,
    + * software distributed under the License is distributed on an
    + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    + * KIND, either express or implied.  See the License for the
    + * specific language governing permissions and limitations
    + * under the License.
    + */
    +package hivemall.embedding;
    +
    +import hivemall.math.random.PRNG;
    +import hivemall.math.random.RandomNumberGeneratorFactory;
    +import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
    +
    +import javax.annotation.Nonnegative;
    +import javax.annotation.Nonnull;
    +
    +public abstract class AbstractWord2VecModel {
    +    // cached sigmoid function parameters
    +    protected final int MAX_SIGMOID = 6;
    --- End diff --
    
    Constants should be `static final`.


---