Posted to commits@mxnet.apache.org by zh...@apache.org on 2018/03/27 18:28:26 UTC

[incubator-mxnet] 01/02: Add gluon.text vocab/embedding demo (#18)

This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch nlp_toolkit
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

commit 14d7499431e4e90efdebc832250f316c81dd019b
Author: Aston Zhang <22...@users.noreply.github.com>
AuthorDate: Mon Mar 26 23:03:14 2018 -0700

    Add gluon.text vocab/embedding demo (#18)
    
    * Add word embedding example
    
    * clean
    
    * Add text descriptions
---
 example/gluon/word_embedding.ipynb   | 1049 ++++++++++++++++++++++++++++++++++
 python/mxnet/gluon/text/embedding.py |    6 +
 2 files changed, 1055 insertions(+)

diff --git a/example/gluon/word_embedding.ipynb b/example/gluon/word_embedding.ipynb
new file mode 100644
index 0000000..f3c3217
--- /dev/null
+++ b/example/gluon/word_embedding.ipynb
@@ -0,0 +1,1049 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Using Pre-trained Word Embeddings"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Here we introduce how to use pre-trained word embeddings via `mxnet.gluon.text`. \n",
+    "\n",
+    "The used GloVe and fastText word embeddings in this tutorial are from the following sources:\n",
+    "\n",
+    "* GloVe project website:https://nlp.stanford.edu/projects/glove/\n",
+    "* fastText project website:https://fasttext.cc/\n",
+    "\n",
+    "Let us first import the following packages."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2018-03-27T00:03:34.447895Z",
+     "start_time": "2018-03-27T00:03:33.503038Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "from mxnet import gluon\n",
+    "from mxnet import nd\n",
+    "from mxnet.gluon import text\n",
+    "from collections import Counter"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Creating Vocabulary with Word Embeddings\n",
+    "\n",
+    "As a common use case, let us index words, attach pre-trained word embeddings for them, and use such embeddings in `gluon` in just a few lines of code.\n",
+    "\n",
+    "### Creating Vocabulary from Data Sets\n",
+    "\n",
+    "To begin with, suppose that we have a simple text data set in the string format. We can count word frequency in the data set."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2018-03-27T00:03:34.453636Z",
+     "start_time": "2018-03-27T00:03:34.449760Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "data = \" hello world \\n hello nice world \\n hi world \\n\"\n",
+    "counter = text.utils.count_tokens_from_str(data)"
+   ]
+  },
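+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As a quick sanity check, we can peek at the resulting counts (a sketch, assuming `counter` behaves like `collections.Counter`)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Inspect the word counts; 'world' should appear 3 times, 'hello' twice,\n",
+    "# and 'nice' and 'hi' once each.\n",
+    "print(sorted(counter.items()))"
+   ]
+  },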
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The obtained `counter` has key-value pairs whose keys are words and values are word frequencies. This allows us to filter out infrequent words via `Vocabulary` arguments such as `max_size` and `min_freq`. Suppose that we want to build indices for all the keys in counter. We need a `Vocabulary` instance with counter as its argument."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2018-03-27T00:03:34.459747Z",
+     "start_time": "2018-03-27T00:03:34.456473Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "vocab = text.vocab.Vocabulary(counter)"
+   ]
+  },
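+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As a minimal sketch of the filtering arguments mentioned above (assuming `min_freq` keeps only tokens whose frequency is at least the given value), we could build a smaller vocabulary like this."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Hypothetical usage: index only tokens appearing at least twice, so\n",
+    "# 'nice' and 'hi' (one occurrence each) would map to the unknown token.\n",
+    "vocab_min2 = text.vocab.Vocabulary(counter, min_freq=2)"
+   ]
+  },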
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "To attach word embedding to indexed words in `vocab`, let us go on to create a fastText word embedding instance by specifying the embedding name `fasttext` and the pre-trained file name `wiki.simple.vec`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2018-03-27T00:03:53.199585Z",
+     "start_time": "2018-03-27T00:03:34.462702Z"
+    }
+   },
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/Users/astonz/WorkDocs/Programs/git_repo/mxnet/python/mxnet/gluon/text/embedding.py:264: UserWarning: At line 1 of the pre-trained token embedding file: token 111051 with 1-dimensional vector [300.0] is likely a header and is skipped.\n",
+      "  'skipped.' % (line_num, token, elems))\n"
+     ]
+    }
+   ],
+   "source": [
+    "fasttext_simple = text.embedding.create('fasttext', file_name='wiki.simple.vec')"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "So we can attach word embedding `fasttext_simple` to indexed words in `vocab`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2018-03-27T00:03:53.214582Z",
+     "start_time": "2018-03-27T00:03:53.201953Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "vocab.set_embedding(fasttext_simple)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "To see other pre-trained file names under the fastText word embedding, we can use `text.embedding.get_file_names`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2018-03-27T00:03:53.240556Z",
+     "start_time": "2018-03-27T00:03:53.217839Z"
+    }
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "['crawl-300d-2M.vec',\n",
+       " 'wiki.aa.vec',\n",
+       " 'wiki.ab.vec',\n",
+       " 'wiki.ace.vec',\n",
+       " 'wiki.ady.vec']"
+      ]
+     },
+     "execution_count": 6,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "text.embedding.get_file_names('fasttext')[:5]"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The created vocabulary `vocab` includes four different words and a special unknown token. Let us check the size of `vocab`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2018-03-27T00:03:53.250542Z",
+     "start_time": "2018-03-27T00:03:53.243313Z"
+    }
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "5"
+      ]
+     },
+     "execution_count": 7,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "len(vocab)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "By default, the vector of any token that is unknown to `vocab` is a zero vector. Its length is equal to the vector dimension of the fastText word embeddings: 300."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2018-03-27T00:03:53.262146Z",
+     "start_time": "2018-03-27T00:03:53.253051Z"
+    }
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "(300,)"
+      ]
+     },
+     "execution_count": 8,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "vocab.embedding['beautiful'].shape"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The first five elements of the vector of any unknown token are zeros."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2018-03-27T00:03:53.273198Z",
+     "start_time": "2018-03-27T00:03:53.264987Z"
+    }
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "\n",
+       "[ 0.  0.  0.  0.  0.]\n",
+       "<NDArray 5 @cpu(0)>"
+      ]
+     },
+     "execution_count": 9,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "vocab.embedding['beautiful'][:5]"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Let us check the shape of the vectors of words 'hello' and 'world' from `vocab`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2018-03-27T00:03:53.283862Z",
+     "start_time": "2018-03-27T00:03:53.276282Z"
+    }
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "(2, 300)"
+      ]
+     },
+     "execution_count": 10,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "vocab.embedding['hello', 'world'].shape"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2018-03-26T23:29:07.340108Z",
+     "start_time": "2018-03-26T23:29:07.334790Z"
+    }
+   },
+   "source": [
+    "We can access the first five elements of the vectors of 'hello' and 'world'."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2018-03-27T00:03:53.296482Z",
+     "start_time": "2018-03-27T00:03:53.287022Z"
+    }
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "\n",
+       "[[ 0.39567     0.21454    -0.035389   -0.24299    -0.095645  ]\n",
+       " [ 0.10444    -0.10858     0.27212     0.13299    -0.33164999]]\n",
+       "<NDArray 2x5 @cpu(0)>"
+      ]
+     },
+     "execution_count": 11,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "vocab.embedding['hello', 'world'][:, :5]"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Using Pre-trained Word Embeddings in  `gluon.nn.Embedding`\n",
+    "\n",
+    "To demonstrate how to use pre-trained word embeddings in the `gluon` package, let us first obtain indices of the words 'hello' and 'world'."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2018-03-27T00:03:53.306574Z",
+     "start_time": "2018-03-27T00:03:53.300400Z"
+    }
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "[2, 1]"
+      ]
+     },
+     "execution_count": 12,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "vocab['hello', 'world']"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We can obtain the vectors for the words 'hello' and 'world' by specifying their indices (2 and 1) and the weight matrix `vocab.embedding.idx_to_vec` in `gluon.nn.Embedding`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2018-03-27T00:03:53.327785Z",
+     "start_time": "2018-03-27T00:03:53.309979Z"
+    }
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "\n",
+       "[[ 0.39567     0.21454    -0.035389   -0.24299    -0.095645  ]\n",
+       " [ 0.10444    -0.10858     0.27212     0.13299    -0.33164999]]\n",
+       "<NDArray 2x5 @cpu(0)>"
+      ]
+     },
+     "execution_count": 13,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "input_dim, output_dim = vocab.embedding.idx_to_vec.shape\n",
+    "layer = gluon.nn.Embedding(input_dim, output_dim)\n",
+    "layer.initialize()\n",
+    "layer.weight.set_data(vocab.embedding.idx_to_vec)\n",
+    "layer(nd.array([2, 1]))[:, :5]"
+   ]
+  },
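+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "If we want the pre-trained vectors to stay fixed during training, one option (a sketch, assuming the standard Gluon `Parameter.grad_req` attribute) is to disable gradient computation for the embedding weight."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Assumption: setting grad_req to 'null' freezes this parameter, so the\n",
+    "# pre-trained embedding weights are not updated by the optimizer.\n",
+    "layer.weight.grad_req = 'null'"
+   ]
+  },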
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Creating Vocabulary from Pre-trained Word Embeddings\n",
+    "\n",
+    "We can also create vocabulary by using vocabulary of pre-trained word embeddings, such as GloVe. Below are a few pre-trained file names under the GloVe word embedding."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2018-03-27T00:03:53.338638Z",
+     "start_time": "2018-03-27T00:03:53.330822Z"
+    }
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "['glove.42B.300d.txt',\n",
+       " 'glove.6B.50d.txt',\n",
+       " 'glove.6B.100d.txt',\n",
+       " 'glove.6B.200d.txt',\n",
+       " 'glove.6B.300d.txt']"
+      ]
+     },
+     "execution_count": 14,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "text.embedding.get_file_names('glove')[:5]"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "For simplicity of demonstration, we use a smaller word embedding file, such as the 50-dimensional one. "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2018-03-27T00:04:04.229138Z",
+     "start_time": "2018-03-27T00:03:53.341827Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "glove_6b50d = text.embedding.create('glove', file_name='glove.6B.50d.txt')"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Now we create vocabulary by using all the tokens from `glove_6b50d`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2018-03-27T00:04:06.032364Z",
+     "start_time": "2018-03-27T00:04:04.231212Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "vocab = text.vocab.Vocabulary(Counter(glove_6b50d.idx_to_token))\n",
+    "vocab.set_embedding(glove_6b50d)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Below shows the size of `vocab` including a special unknown token."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 17,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2018-03-27T00:04:06.042843Z",
+     "start_time": "2018-03-27T00:04:06.034933Z"
+    }
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "400001"
+      ]
+     },
+     "execution_count": 17,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "len(vocab.idx_to_token)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We can access attributes of `vocab`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 18,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2018-03-27T00:04:06.056449Z",
+     "start_time": "2018-03-27T00:04:06.046106Z"
+    }
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "71421\n",
+      "beautiful\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(vocab['beautiful'])\n",
+    "print(vocab.idx_to_token[71421])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Applications of Word Embeddings\n",
+    "\n",
+    "To apply word embeddings, we need to define cosine similarity. It can compare similarity of two vectors."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 19,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2018-03-27T00:04:06.067188Z",
+     "start_time": "2018-03-27T00:04:06.059379Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "from mxnet import nd\n",
+    "def cos_sim(x, y):\n",
+    "    return nd.dot(x, y) / (nd.norm(x) * nd.norm(y))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The range of cosine similarity between two vectors is between -1 and 1. The larger the value, the similarity between two vectors."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 20,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2018-03-27T00:04:06.272263Z",
+     "start_time": "2018-03-27T00:04:06.070098Z"
+    }
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "[ 1.]\n",
+      "<NDArray 1 @cpu(0)>\n",
+      "\n",
+      "[-1.]\n",
+      "<NDArray 1 @cpu(0)>\n"
+     ]
+    }
+   ],
+   "source": [
+    "x = nd.array([1, 2])\n",
+    "y = nd.array([10, 20])\n",
+    "z = nd.array([-1, -2])\n",
+    "\n",
+    "print(cos_sim(x, y))\n",
+    "print(cos_sim(x, z))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Word Similarity\n",
+    "\n",
+    "Given an input word, we can find the nearest $k$ words from the vocabulary (400,000 words excluding the unknown token) by similarity. The similarity between any pair of words can be represented by the cosine similarity of their vectors."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 21,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2018-03-27T00:04:06.292283Z",
+     "start_time": "2018-03-27T00:04:06.274721Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "def norm_vecs_by_row(x):\n",
+    "    return x / nd.sqrt(nd.sum(x * x, axis=1)).reshape((-1,1))\n",
+    "\n",
+    "def get_knn(vocab, k, word):\n",
+    "    word_vec = vocab.embedding[word].reshape((-1, 1))\n",
+    "    vocab_vecs = norm_vecs_by_row(vocab.embedding.idx_to_vec)\n",
+    "    dot_prod = nd.dot(vocab_vecs, word_vec)\n",
+    "    indices = nd.topk(dot_prod.reshape((len(vocab), )), k=k+2, ret_typ='indices')\n",
+    "    indices = [int(i.asscalar()) for i in indices]\n",
+    "    # Remove unknown and input tokens.\n",
+    "    return vocab.to_tokens(indices[2:])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Let us find the 5 most similar words of 'baby' from the vocabulary (size: 400,000 words)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 22,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2018-03-27T00:04:06.687950Z",
+     "start_time": "2018-03-27T00:04:06.295771Z"
+    }
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "['babies', 'boy', 'girl', 'newborn', 'pregnant']"
+      ]
+     },
+     "execution_count": 22,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "get_knn(vocab, 5, 'baby')"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We can verify the cosine similarity of vectors of 'baby' and 'babies'."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 23,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2018-03-27T00:04:06.698920Z",
+     "start_time": "2018-03-27T00:04:06.691103Z"
+    }
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "\n",
+       "[ 0.83871299]\n",
+       "<NDArray 1 @cpu(0)>"
+      ]
+     },
+     "execution_count": 23,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "cos_sim(vocab.embedding['baby'], vocab.embedding['babies'])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Let us find the 5 most similar words of 'computers' from the vocabulary."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 24,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2018-03-27T00:04:07.084357Z",
+     "start_time": "2018-03-27T00:04:06.702292Z"
+    }
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "['computer', 'phones', 'pcs', 'machines', 'devices']"
+      ]
+     },
+     "execution_count": 24,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "get_knn(vocab, 5, 'computers')"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Let us find the 5 most similar words of 'run' from the vocabulary."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 25,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2018-03-27T00:04:07.504323Z",
+     "start_time": "2018-03-27T00:04:07.087221Z"
+    }
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "['running', 'runs', 'went', 'start', 'ran']"
+      ]
+     },
+     "execution_count": 25,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "get_knn(vocab, 5, 'run')"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Let us find the 5 most similar words of 'beautiful' from the vocabulary."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 26,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2018-03-27T00:04:07.967072Z",
+     "start_time": "2018-03-27T00:04:07.507039Z"
+    }
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "['lovely', 'gorgeous', 'wonderful', 'charming', 'beauty']"
+      ]
+     },
+     "execution_count": 26,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "get_knn(vocab, 5, 'beautiful')"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Word Analogy\n",
+    "\n",
+    "We can also apply pre-trained word embeddings to the word analogy problem. For instance, \"man : woman :: son : daughter\" is an analogy. The word analogy completion problem is defined as: for analogy 'a : b :: c : d', given teh first three words 'a', 'b', 'c', find 'd'. The idea is to find the most similar word vector for vec('c') + (vec('b')-vec('a')).\n",
+    "\n",
+    "In this example, we will find words by analogy from the 400,000 indexed words in `vocab`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 27,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2018-03-27T00:04:08.040101Z",
+     "start_time": "2018-03-27T00:04:07.973776Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "def get_top_k_by_analogy(vocab, k, word1, word2, word3):\n",
+    "    word_vecs = vocab.embedding[word1, word2, word3]\n",
+    "    word_diff = (word_vecs[1] - word_vecs[0] + word_vecs[2]).reshape((-1, 1))\n",
+    "    vocab_vecs = norm_vecs_by_row(vocab.embedding.idx_to_vec)\n",
+    "    dot_prod = nd.dot(vocab_vecs, word_diff)\n",
+    "    indices = nd.topk(dot_prod.reshape((len(vocab), )), k=k+1, ret_typ='indices')\n",
+    "    indices = [int(i.asscalar()) for i in indices]\n",
+    "\n",
+    "    # Filter out unknown tokens.\n",
+    "    if vocab.to_tokens(indices[0]) == vocab.unknown_token:\n",
+    "        return vocab.to_tokens(indices[1:])\n",
+    "    else:\n",
+    "        return vocab.to_tokens(indices[:-1])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Complete word analogy 'man : woman :: son :'."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 28,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2018-03-27T00:04:08.519697Z",
+     "start_time": "2018-03-27T00:04:08.051060Z"
+    }
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "['daughter']"
+      ]
+     },
+     "execution_count": 28,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "get_top_k_by_analogy(vocab, 1, 'man', 'woman', 'son')"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Let us verify the cosine similarity between vec('son')+vec('woman')-vec('man') and vec('daughter')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 29,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2018-03-27T00:04:08.535690Z",
+     "start_time": "2018-03-27T00:04:08.522548Z"
+    }
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "\n",
+       "[ 0.9658342]\n",
+       "<NDArray 1 @cpu(0)>"
+      ]
+     },
+     "execution_count": 29,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "def cos_sim_word_analogy(vocab, word1, word2, word3, word4):\n",
+    "    words = [word1, word2, word3, word4]\n",
+    "    vecs = vocab.embedding[words]\n",
+    "    return cos_sim(vecs[1] - vecs[0] + vecs[2], vecs[3])\n",
+    "\n",
+    "cos_sim_word_analogy(vocab, 'man', 'woman', 'son', 'daughter')"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Complete word analogy 'beijing : china :: tokyo : '."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 30,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2018-03-27T00:04:08.939664Z",
+     "start_time": "2018-03-27T00:04:08.538918Z"
+    }
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "['japan']"
+      ]
+     },
+     "execution_count": 30,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "get_top_k_by_analogy(vocab, 1, 'beijing', 'china', 'tokyo')"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Complete word analogy 'bad : worst :: big : '."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 31,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2018-03-27T00:04:09.319291Z",
+     "start_time": "2018-03-27T00:04:08.942078Z"
+    }
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "['biggest']"
+      ]
+     },
+     "execution_count": 31,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "get_top_k_by_analogy(vocab, 1, 'bad', 'worst', 'big')"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Complete word analogy 'do : did :: go :'."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 32,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2018-03-27T00:04:09.735225Z",
+     "start_time": "2018-03-27T00:04:09.323663Z"
+    }
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "['went']"
+      ]
+     },
+     "execution_count": 32,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "get_top_k_by_analogy(vocab, 1, 'do', 'did', 'go')"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.6.1"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/python/mxnet/gluon/text/embedding.py b/python/mxnet/gluon/text/embedding.py
index 1839212..fcbc6df 100644
--- a/python/mxnet/gluon/text/embedding.py
+++ b/python/mxnet/gluon/text/embedding.py
@@ -155,6 +155,8 @@ class TokenEmbedding(object):
 
     Properties
     ----------
+    idx_to_token : list of str
+        A list of indexed tokens where the list indices and the token indices are aligned.
     idx_to_vec : mxnet.ndarray.NDArray
         For all the indexed tokens in this embedding, this NDArray maps each token's index to an
         embedding vector.
@@ -285,6 +287,10 @@ class TokenEmbedding(object):
             self._idx_to_vec[C.UNKNOWN_IDX] = nd.array(loaded_unknown_vec)
 
     @property
+    def idx_to_token(self):
+        return self._idx_to_token
+
+    @property
     def idx_to_vec(self):
         return self._idx_to_vec
 

-- 
To stop receiving notification emails like this one, please contact
zhasheng@apache.org.