You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2018/01/13 21:54:37 UTC
[incubator-mxnet] branch master updated: resubmit #8763 #9394 to
contrib (#9406)
This is an automated email from the ASF dual-hosted git repository.
jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new e4957c3 resubmit #8763 #9394 to contrib (#9406)
e4957c3 is described below
commit e4957c3692bdae20665f8245b592b17ef79bc946
Author: Sheng Zha <sz...@users.noreply.github.com>
AuthorDate: Sat Jan 13 13:54:31 2018 -0800
resubmit #8763 #9394 to contrib (#9406)
---
python/mxnet/contrib/__init__.py | 2 +
python/mxnet/contrib/{ => text}/__init__.py | 14 +-
python/mxnet/contrib/text/constants.py | 344 +++++++++++++
python/mxnet/contrib/text/embedding.py | 669 +++++++++++++++++++++++++
python/mxnet/contrib/text/glossary.py | 142 ++++++
python/mxnet/contrib/text/indexer.py | 231 +++++++++
python/mxnet/contrib/text/utils.py | 79 +++
python/mxnet/registry.py | 17 +
tests/python/unittest/test_contrib_text.py | 727 ++++++++++++++++++++++++++++
9 files changed, 2216 insertions(+), 9 deletions(-)
diff --git a/python/mxnet/contrib/__init__.py b/python/mxnet/contrib/__init__.py
index 2730bc4..21c7771 100644
--- a/python/mxnet/contrib/__init__.py
+++ b/python/mxnet/contrib/__init__.py
@@ -26,3 +26,5 @@ from . import ndarray as nd
from . import autograd
from . import tensorboard
+
+from . import text
diff --git a/python/mxnet/contrib/__init__.py b/python/mxnet/contrib/text/__init__.py
similarity index 81%
copy from python/mxnet/contrib/__init__.py
copy to python/mxnet/contrib/text/__init__.py
index 2730bc4..fff2b94 100644
--- a/python/mxnet/contrib/__init__.py
+++ b/python/mxnet/contrib/text/__init__.py
@@ -16,13 +16,9 @@
# under the License.
# coding: utf-8
-"""Experimental contributions"""
+"""This module includes utilities for indexing and embedding text."""
-from . import symbol
-from . import ndarray
-
-from . import symbol as sym
-from . import ndarray as nd
-
-from . import autograd
-from . import tensorboard
+from . import utils
+from . import indexer
+from . import embedding
+from . import glossary
diff --git a/python/mxnet/contrib/text/constants.py b/python/mxnet/contrib/text/constants.py
new file mode 100644
index 0000000..b69e5d9
--- /dev/null
+++ b/python/mxnet/contrib/text/constants.py
@@ -0,0 +1,344 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+
+"""Read text files and load embeddings."""
+from __future__ import absolute_import
+from __future__ import print_function
+
+UNKNOWN_IDX = 0
+
+APACHE_REPO_URL = 'https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/'
+
+GLOVE_PRETRAINED_FILE_SHA1 = \
+ {'glove.42B.300d.zip': 'f8e722b39578f776927465b71b231bae2ae8776a',
+ 'glove.6B.zip': 'b64e54f1877d2f735bdd000c1d7d771e25c7dfdc',
+ 'glove.840B.300d.zip': '8084fbacc2dee3b1fd1ca4cc534cbfff3519ed0d',
+ 'glove.twitter.27B.zip': 'dce69c404025a8312c323197347695e81fd529fc'}
+
+GLOVE_PRETRAINED_ARCHIVE_SHA1 = \
+ {'glove.42B.300d.txt': '876767977d6bd4d947c0f84d44510677bc94612a',
+ 'glove.6B.50d.txt': '21bf566a9d27f84d253e0cd4d4be9dcc07976a6d',
+ 'glove.6B.100d.txt': '16b1dbfaf35476790bd9df40c83e2dfbd05312f1',
+ 'glove.6B.200d.txt': '17d0355ddaa253e298ede39877d1be70f99d9148',
+ 'glove.6B.300d.txt': '646443dd885090927f8215ecf7a677e9f703858d',
+ 'glove.840B.300d.txt': '294b9f37fa64cce31f9ebb409c266fc379527708',
+ 'glove.twitter.27B.25d.txt':
+ '767d80889d8c8a22ae7cd25e09d0650a6ff0a502',
+ 'glove.twitter.27B.50d.txt':
+ '9585f4be97e286339bf0112d0d3aa7c15a3e864d',
+ 'glove.twitter.27B.100d.txt':
+ '1bbeab8323c72332bd46ada0fc3c99f2faaa8ca8',
+ 'glove.twitter.27B.200d.txt':
+ '7921c77a53aa5977b1d9ce3a7c4430cbd9d1207a'}
+
+FAST_TEXT_FILE_SHA1 = \
+ {'wiki.ab.vec': '9d89a403a9a866d3da8dd8cfab849f59ee499343',
+ 'wiki.ace.vec': '85d00074f7a08626f39da6a0c8a5cfa250096ab9',
+ 'wiki.ady.vec': '9d17d74f0348224cdebf8a831e61af0825f8952d',
+ 'wiki.aa.vec': '5cce30fc85471572c498f278bbe495184577363e',
+ 'wiki.af.vec': '999e64bcd8dab8de42cb1feceeca360def35324d',
+ 'wiki.ak.vec': '6092b8af335c2dc93e8df2bbf1d715f01e637bb4',
+ 'wiki.sq.vec': 'd07ffed553f5eb4756d0a1548a7ba9a51a52f7c6',
+ 'wiki.als.vec': '96052e96870695cca50857b5fde5f9f42219139a',
+ 'wiki.am.vec': 'dff7fcdd8f5ba0638ab9e1758a89800766156d72',
+ 'wiki.ang.vec': 'a7c30e02422d97d23a0701279c5c1c03159130a5',
+ 'wiki.ar.vec': 'c46e2142f799cc385bd25f0c0a8943ca565505a4',
+ 'wiki.an.vec': '5b4c2b1de5c04e4e0be83841410ca84c47305d21',
+ 'wiki.arc.vec': 'fd3ad743103f80cde9cfc048d7ca509e50efb35a',
+ 'wiki.hy.vec': '21f9259d04cfd22db446a45d3622af225f00cf20',
+ 'wiki.roa_rup.vec': 'e31a44353cd84b976586c8df35a2ab58318120f0',
+ 'wiki.as.vec': 'cad5883b5147cbe6cdbf604f65cabdb675a59258',
+ 'wiki.ast.vec': '89a90357101953b7c292697fd050c00fe5c38ac5',
+ 'wiki.av.vec': '99976a63ca8c4231f808fd4314f0433db35e290d',
+ 'wiki.ay.vec': 'be359dad25b2c742d3abfa94c5f5db13f86c730e',
+ 'wiki.az.vec': '9581d55d9056ad398a153c37b502f3a07867d091',
+ 'wiki.bm.vec': 'f36a19c95e90865f6518d4487e59f363b47bd865',
+ 'wiki.bjn.vec': '5f134cf288e8042dcd048a3ee76159aab42c7288',
+ 'wiki.map_bms.vec': 'e7deab5fdd38fa3331b1bcb4a16432b38c512e21',
+ 'wiki.ba.vec': '22147ee16b2d163cc88d09a035264fd0c10dab68',
+ 'wiki.eu.vec': '5e72f4ef93666971fea5d2180b354e0a0821ba91',
+ 'wiki.bar.vec': '96130f1f2e5bffdd06c202ad4472e5234020980a',
+ 'wiki.be.vec': '6cf81322cd7b046a7f02ec4c4960ad27045383fa',
+ 'wiki.bn.vec': '6fc3bfd9af455719f55bee0bea31b11afc70cf06',
+ 'wiki.bh.vec': 'ab2d29017afa015c49566a6d9bf75393c23ac4c0',
+ 'wiki.bpy.vec': 'c2bb15487c4bdb8fa869772694300ae1fee73896',
+ 'wiki.bi.vec': '15785220cd6e6c86cc87e7d3f3322a5541a4fe5d',
+ 'wiki.bs.vec': 'c4943a290819ceae1611dd11179b40aab0df0471',
+ 'wiki.br.vec': 'df44e16abd2017e2a1b6c6588ee02779b19907f6',
+ 'wiki.bug.vec': '942d8f7dadde5faa33aa72862501434f48e29f60',
+ 'wiki.bg.vec': '7c1cc6d0c52b038e4b7173259b0c009f242cf486',
+ 'wiki.my.vec': 'e7c7989e32b23ca1a9caf534cc65ecaf9e1b9112',
+ 'wiki.bxr.vec': 'eaf767690c6b194605ae778719212e3874873d4c',
+ 'wiki.zh_yue.vec': 'd2ac1ab9eb1a908797644f83f259c90cb3c1a350',
+ 'wiki.ca.vec': 'f5971edee11c939f6a7accfd33a9a45caa54141a',
+ 'wiki.ceb.vec': 'b8516a55537b8f80c927d77d95cdf7e4ff849a05',
+ 'wiki.bcl.vec': 'd4117b5c443438ddfa608b10a5be2c2501817e7e',
+ 'wiki.ch.vec': '46803f3a1734f6a7b0d8cb053bbb86a6915d02e9',
+ 'wiki.cbk_zam.vec': '6fef47b4559eec402ce371de20dfb018acd6347d',
+ 'wiki.ce.vec': '1d94b0168a773895b23889f7f07d7cf56c11a360',
+ 'wiki.chr.vec': '8501bf86b41074ed6c8d15b9209ef7ce83122e70',
+ 'wiki.chy.vec': '26c87688551ffe3a0c7a5952e894306651e62131',
+ 'wiki.ny.vec': '4e066fe113630fdfbcaf8844cc4ad64388db98d0',
+ 'wiki.zh.vec': '117ab34faa80e381641fbabf3a24bc8cfba44050',
+ 'wiki.cho.vec': 'cec6778f025fa9ae4134046c6c3a6291bd9c63f9',
+ 'wiki.cv.vec': '9cdb0bee5a0fea030def85597dba7108f21b0424',
+ 'wiki.zh_classical.vec': '840981c83dd8e5cb02d1cd695e2fe0870941316c',
+ 'wiki.kw.vec': 'f9eaa35a7e4f077f6de85c7801f74582f91b52c1',
+ 'wiki.co.vec': 'af876a918594e5541207bc12f17bfc4268df7b93',
+ 'wiki.cr.vec': '61dd9f044b7dfa56dcf1c3c07c7504c569420528',
+ 'wiki.crh.vec': 'c0d2310a1207fcacc94b25b149420b33bf835015',
+ 'wiki.hr.vec': '0c96f9af092cf8a84b03aec1426cd23921671489',
+ 'wiki.cs.vec': 'f3ec1502aeee6a550d8cf784273fa62f61419a4e',
+ 'wiki.da.vec': '526947dab1ffbc1465c7a766f2bca4de50676b08',
+ 'wiki.dv.vec': 'e135ba97c711a021bc3317db2b95db5212c17658',
+ 'wiki.nl.vec': 'd796ee27e37b7d1d464e03c265c31ab62b52533e',
+ 'wiki.nds_nl.vec': '1cd96d12e78e5cd3f65ca2773a17696bda387b9f',
+ 'wiki.dz.vec': '4cc1c6cf4aa4676d40a3145d5d4623569e430f89',
+ 'wiki.pa.vec': '4939d0db77a5b28d7d5aab0fab4f999d93b2053e',
+ 'wiki.arz.vec': '5e904087043b91f4945dd708f4230fdf51360132',
+ 'wiki.eml.vec': 'de6be7a2ffdda226eec730dd54b4c614bd7f5dca',
+ 'wiki.en.vec': 'c1e418f144ceb332b4328d27addf508731fa87df',
+ 'wiki.myv.vec': '7de0927fd3d65677de7f770b3bd57c73b58df85d',
+ 'wiki.eo.vec': 'b56998fd69f66755b722a9481a9bdaf10f62c9aa',
+ 'wiki.et.vec': '64d56b66c02d5e49b1b66a85854d67d2dd9ebd41',
+ 'wiki.ee.vec': 'f2212f58ec082321bc9b93873cd22702d0a64d64',
+ 'wiki.ext.vec': '456c5632b13a0f136cd180ebe2dda67b83f78397',
+ 'wiki.fo.vec': 'eead8ddc7bb74b12b16784723abf802bb51f844d',
+ 'wiki.hif.vec': '49697cf784814d3f1a47559724028e0fc0940d36',
+ 'wiki.fj.vec': 'c70fca34a7e43143600c54d7bf199b88846ac6f2',
+ 'wiki.fi.vec': '91d19baae994d7e556b5b5938be2dc6013f9c706',
+ 'wiki.frp.vec': '0eb70a613ccf807c7308c1f62535f0606465029d',
+ 'wiki.fr.vec': 'b092229005a65d8683a4112852fe6eb8161a6917',
+ 'wiki.fur.vec': 'd4a595cffa1abcdcf4229ba15277179ce5d20bc6',
+ 'wiki.ff.vec': '57ea8febb24ba8ac4421ec97ed8918d44c69f42c',
+ 'wiki.gag.vec': 'c82ec7a5d081f0673661824f4fc34345dee255f0',
+ 'wiki.gl.vec': '8888bb8f3d70b36729b9ae479fe3765e0c083862',
+ 'wiki.gan.vec': 'aeea01c2c4a7c44d6e8c31845560baf43d7afb9c',
+ 'wiki.ka.vec': '8b92b73f27f9b77818211e053a33985589de7c62',
+ 'wiki.de.vec': '2ed2696afe55f023b0040b238d9a47e5fedfe48b',
+ 'wiki.glk.vec': '20a7759075916e10531f5b3577302353cef565cd',
+ 'wiki.gom.vec': '5a1193d9e5d49d06354c14e2b7c01bea176e13f1',
+ 'wiki.got.vec': 'cc5aaf4c305f4f1f788b4829e644823f8495a23a',
+ 'wiki.el.vec': '6f034271390feaa6f9d7d16f933ddef637755979',
+ 'wiki.kl.vec': '390406cc33e02f86cfaf7ae273193679924f7413',
+ 'wiki.gn.vec': '98594af7897c5a1f35885ddecc77556a7e7ae981',
+ 'wiki.gu.vec': 'f9e13452eb63d92bea44c7c3db8fba9945c7000e',
+ 'wiki.ht.vec': '5039dfb58a074ac046813f2dae81159be8c5213f',
+ 'wiki.hak.vec': '9e83512d34c7f81739492bf0abbb25ff1ef88573',
+ 'wiki.ha.vec': '677a24efeeb1bcb8c0a931407775f18b18e875ae',
+ 'wiki.haw.vec': 'c23a50529dc010401c99833c8f990c1b27843db3',
+ 'wiki.he.vec': '55534560247394669e3f5c169136770c93bc2708',
+ 'wiki.hz.vec': '7605e06dd708920f73a80473816a8d684c116bd8',
+ 'wiki.mrj.vec': 'aa1c1ecba1ffd6b42c8d9659a8a04ab328ae1650',
+ 'wiki.hi.vec': '8049bb8604bc049d48bd934e27b0e184c480a413',
+ 'wiki.ho.vec': 'ef6b84d508d4d0a4c4cf90facaca1eebe62b2187',
+ 'wiki.hu.vec': 'cd777e9efca3d4bd97c89f01690cfa4840d9c46f',
+ 'wiki.is.vec': 'ae0b018f92b3e218f2dacb2045a8f0a0446788a5',
+ 'wiki.io.vec': 'af0c480c5872bff31d82e767c1116da2a6be0c00',
+ 'wiki.ig.vec': 'd2d1643b4fb1a18a4d002cf2969073f7f201b3b2',
+ 'wiki.ilo.vec': 'c0e43835a3f4e0033ea5d7c6ff189982b2f26a05',
+ 'wiki.id.vec': 'c49d5c9bec89114599427f6c12a5bda2e5523dfd',
+ 'wiki.ia.vec': '2a348dc924638efc20c34785852b0837364aed76',
+ 'wiki.ie.vec': '01b0d11c0e7397418e73853d220e97bdcf7a8961',
+ 'wiki.iu.vec': 'ed77a1d7b0faeeb1352b1c4fc1e69971e1e21174',
+ 'wiki.ik.vec': '4d5d4f7a6426720e07d0faeb51b5babfa4acf44a',
+ 'wiki.ga.vec': 'caaa5b2167a499893313ac1aa38416a6a0fe9a24',
+ 'wiki.it.vec': 'ac4a985e85ffae48047034e2603d804bf126caa9',
+ 'wiki.jam.vec': '6d51e384c56330097c2531fdbf4e74418909e388',
+ 'wiki.ja.vec': '7a2b1af1e46d795410692a002e40fa3085135f69',
+ 'wiki.jv.vec': '2ff7927d3ff04b8208133497b3778ede00ea463f',
+ 'wiki.kbd.vec': 'f5b8dbe47a7fae702232b5680b070ef6e865539e',
+ 'wiki.kab.vec': 'e3b73d41267d8d4cd42f6cc5a0c05dc4e021bf74',
+ 'wiki.xal.vec': 'b738222d84cb8c8fdb2b30a7219aa5d3bdc2f61c',
+ 'wiki.kn.vec': '32763f4f860f0d081f3aabf3e7d17b7858e7d877',
+ 'wiki.kr.vec': 'c919463e96e4fe36dd5bd73be0c5cd144d4d4f91',
+ 'wiki.pam.vec': '8fbd31e70d0ca0c61eb1a152efaa8ecb29180967',
+ 'wiki.krc.vec': '0c6ef043d51e5f337a309804f1db180fa0bb2cb8',
+ 'wiki.kaa.vec': 'd990d3b9bd511d2d630f923099a6b9110231b2ed',
+ 'wiki.ks.vec': 'f0a69830a3f661c107503772cc6bd5e345f0c8d6',
+ 'wiki.csb.vec': '649cb2692f08414987c875dc331022567d367497',
+ 'wiki.kk.vec': '6343b2b31bad2e13d03a110b91c38fab4adc01cd',
+ 'wiki.km.vec': '64f7fff1df90b1f7241b232e901f76223a3719e0',
+ 'wiki.ki.vec': 'c4e373e2ea13f7fa1e95b0733365e4b3fc8b2cc8',
+ 'wiki.rw.vec': 'af2ec410da6519a86ba21004c8b4c7fde768a91c',
+ 'wiki.ky.vec': '13b0ae3f23822317a0243bd9182105c631c834b3',
+ 'wiki.rn.vec': '9df628e8c25d928d3e9d127b624f79fd99ff8f4e',
+ 'wiki.kv.vec': '164dc44d701b9d606a45f0b0446076adc3858dca',
+ 'wiki.koi.vec': '4001f0617fe0fdd3b22116b304f497b7b16c6e4c',
+ 'wiki.kg.vec': '379575f4c6e1c4b73e311ddf01b7a85afd047d7c',
+ 'wiki.ko.vec': '042c85a788c2778cca538cf716b8a78f0d7fa823',
+ 'wiki.kj.vec': 'adf29c1a3aa5dd53d85e04d68aa11a26c0eaf6c8',
+ 'wiki.ku.vec': '4d3a2401527dd9ba6be2b0cd31f6cd3edebadce9',
+ 'wiki.ckb.vec': 'adb2fef309f1d93f429442b9c16c1564192c58f3',
+ 'wiki.lad.vec': 'c510e520cde97050bf1cbeb36f2b90e6348ceed4',
+ 'wiki.lbe.vec': 'e72e5ea054334580f72fbe446a726d2b4962851d',
+ 'wiki.lo.vec': '7c83f82b80c49b8eab21f62ecdb3681b8bda40a6',
+ 'wiki.ltg.vec': 'ec2f13d1290bd54afcaa74569e66e43e9bfef264',
+ 'wiki.la.vec': '9ea6286a0581084533db8d6ee96e0b7d15166543',
+ 'wiki.lv.vec': 'ef6b549f96e22718f513d47a611d3d6bc001a164',
+ 'wiki.lez.vec': '8e579b984a500ad89fc66767bfd7319766bd669b',
+ 'wiki.lij.vec': '4ff5bb405c820e4119f0636efc301da15a08c00a',
+ 'wiki.li.vec': '0fb9ec4ac93676d8ef651692062bc3d7f6ae0843',
+ 'wiki.ln.vec': '70b6a286b42958e25cb80824e0d8f1aee2de6dde',
+ 'wiki.lt.vec': '58d3ebef24e5e31be1a8318b45c08ebb16ad775a',
+ 'wiki.olo.vec': 'cbadb4cada4dc579d0becdac93dfb479d76bf6c8',
+ 'wiki.jbo.vec': 'c90481946aa4b6b304528292612ae620f6549f3e',
+ 'wiki.lmo.vec': 'a89414d9ceee4823622258f18936f67faf7e06e7',
+ 'wiki.nds.vec': '7bf293149c08226e05bcf0442ac6e601162b9ffd',
+ 'wiki.dsb.vec': 'e49a647a441fbf011ac5411dd6005e8725b9a65d',
+ 'wiki.lg.vec': 'b096f5248dfbb343dc4696c97ea253510e1c4ef9',
+ 'wiki.lb.vec': 'b146f23628c84e64314a35a5b6cc65a33777e22d',
+ 'wiki.mk.vec': '85a3d3f13fa88ffde023d2326c65bdded4983dff',
+ 'wiki.mai.vec': '7f513ff36e485b19f91f83b30c32dd82e9e497f6',
+ 'wiki.mg.vec': '0808252740909d6129f672584311263e7b2adadc',
+ 'wiki.ms.vec': '458e1a079799a54cdc0a7b78c7fa1729d2683a6d',
+ 'wiki.ml.vec': '2b70fe76e8cf199a18551de782784a21e8db0b66',
+ 'wiki.mt.vec': '81f4c1d84dd4cc4276d59cb903fcc9aba46be981',
+ 'wiki.gv.vec': '993a7ee31bdacc91763dad656aa6c2947b873473',
+ 'wiki.mi.vec': 'e8acf9c7c2ab840a192c563aa776201a88e4ca89',
+ 'wiki.mr.vec': '2cd6cf88bfdfb24850d345749ce0cfea8d65829e',
+ 'wiki.mh.vec': '8c5dbbcb8ad08b9c8b39151fa56d553d116d1b5a',
+ 'wiki.mzn.vec': 'aefad49237808acab99e1ca8eeaaf531666f261d',
+ 'wiki.mhr.vec': '39f62e292336cabc364f0d1913540b881b406393',
+ 'wiki.cdo.vec': '95e8196bf76323dbabab1b8a49ba4d677af3ccea',
+ 'wiki.zh_min_nan.vec': 'f91ccb013e200bb7ed560082ddf4bdd9c2f315bb',
+ 'wiki.min.vec': '3bb0fa596cf27a1d165c55684bebdc8d40cb8ad7',
+ 'wiki.xmf.vec': 'dc1923cfd1a7002d5d60426b60e6756854ab4a14',
+ 'wiki.mwl.vec': '3d10a218242b94fcc3981aa3beb012b701827a55',
+ 'wiki.mdf.vec': 'b16099ce0283a241339716eac41cfd99fdea7f36',
+ 'wiki.mo.vec': '9824ebe366bc52d84e66d1c0cc72b5f7ebb46110',
+ 'wiki.mn.vec': '7cef7ecdf9d98484d9b598b25d0e717dba6acfd9',
+ 'wiki.mus.vec': 'bb94534fdeee4df77ae3e27c252c8874f69a307d',
+ 'wiki.nah.vec': 'c52e01cf4479fb7ec91ef39f298e8f97aeb6496e',
+ 'wiki.na.vec': 'fbe1444b21e1a5885a619cf2a8607fcefca3c8db',
+ 'wiki.nv.vec': 'f5a6ea213bfe95c82cb22b53b4965df8b67ffeab',
+ 'wiki.ng.vec': '8577634e236133980243be0a6fb3c02ad2bb5290',
+ 'wiki.nap.vec': '6c9bd8ce1e85ee679b25189fd6f6d36afb119b6c',
+ 'wiki.ne.vec': '1045d7876f947cd4602d9ca79f7c4323a5d3a52d',
+ 'wiki.new.vec': '51f6c0b4ef1aee9fad4ab1cb69a7479db35e39a5',
+ 'wiki.pih.vec': 'a6a867cef441a06c926205daa9e405aaf58e8f63',
+ 'wiki.nrm.vec': 'b4cb941b126b26fa045c5fc75a490a31a969101c',
+ 'wiki.frr.vec': 'cde62af939cb2de35e341cef2c74813802a58ed4',
+ 'wiki.lrc.vec': 'c1ae4fb79a19d44bfe8f601f0a30fbec841fa612',
+ 'wiki.se.vec': 'f46b35ee6b893c2f12dd1b929bbc2b8120cbcd8d',
+ 'wiki.nso.vec': 'a906271509c2b343df35d1471509492bbfa883aa',
+ 'wiki.no.vec': 'd52e8019d7cc48569c8c3b514d2b1bd10261b5c0',
+ 'wiki.nn.vec': '35aeab89ffeca0377accbbd3bf18b81913c75448',
+ 'wiki.nov.vec': '5455c6e8463b1c43dd073e3e177702fb9a1dd834',
+ 'wiki.ii.vec': '755a6b8ffa664e342c2ab72847af343c47f46c70',
+ 'wiki.oc.vec': 'cc1833492899d75571148c2c305591f53d63f0b1',
+ 'wiki.cu.vec': 'e8eb72eb7fbc224b62ed32dbd897c8c7f6cc5c0a',
+ 'wiki.or.vec': 'a6b120fe536b6c0133b077dca0043c3bc97eef0b',
+ 'wiki.om.vec': '91789a8d9f9284f7e71e4bb8d9a60eae4af4adca',
+ 'wiki.os.vec': '791b26cc300e9a1f0a08c7b2213a264e41ce30d6',
+ 'wiki.pfl.vec': '0ad9b7f3ae13f909f12835107432fee4c4ed3031',
+ 'wiki.pi.vec': '07a5d05e5363e8b8b132220a71de4bdc0a623cfc',
+ 'wiki.pag.vec': '03f71faf060c4eb33802275279967349c0337553',
+ 'wiki.pap.vec': '8cd98267cc55a4f9de80212e29651ddf7a9e83fd',
+ 'wiki.ps.vec': '64f1bec5d5b937289199ceae2e1da6557ce48852',
+ 'wiki.pdc.vec': '401e24d0fb9b0ae9e06a5c700684361f58727fcf',
+ 'wiki.fa.vec': '09b6cc685c895c66b853af9617787d3ab0891e2c',
+ 'wiki.pcd.vec': 'd2e8e7321b6f1bce94c563cb8ef8af2b45cc3e48',
+ 'wiki.pms.vec': 'e30bda8d33d61db43243c157b9ac2feeaff316c8',
+ 'wiki.pl.vec': 'd031adb6f83eda0364a861dcbf5ef779b5951c0b',
+ 'wiki.pnt.vec': 'a9efbf962a895e1d08dde5fd56797dd03abb421e',
+ 'wiki.pt.vec': '7f11ebdb0cbf5929b38319f1e977d2c13bcd741b',
+ 'wiki.qu.vec': '58de8c8290e8bc8f2a6a677312e28457113437b2',
+ 'wiki.ksh.vec': '4c3bb4f12073532b6fb7cc6c2be5e53319ef5b65',
+ 'wiki.rmy.vec': '309fb92222b03f3bd4f2260c02bbd1e3f3d3aba7',
+ 'wiki.ro.vec': 'c088ea2752d5ec8b42e32410c191a14839ae8a1f',
+ 'wiki.rm.vec': '5d3144b47a0dd98648a6df0636384ab2a010ad7b',
+ 'wiki.ru.vec': '7514a2c60ee4118abb451ed32a0d61cb52dec384',
+ 'wiki.rue.vec': 'fe539e0ea0bbbfd3ee06bd0c5521a035c7361ec5',
+ 'wiki.sah.vec': '202470467194a1cbdcd571b14ef68371a29b38d9',
+ 'wiki.sm.vec': '88c2c57ca483626b052403418cb4372d72352bc9',
+ 'wiki.bat_smg.vec': 'cb3aef58da2011183b39fca64cabf3d9d7a62f4b',
+ 'wiki.sg.vec': '7b9c8294c060bd10839650afd1f247b950aa819d',
+ 'wiki.sa.vec': '7fed78d1d7674453b9876ee99aeeeba85ea46699',
+ 'wiki.sc.vec': 'dba8dc7754ef04b1ba0cd702d94eea9575cde91c',
+ 'wiki.stq.vec': '1bf88af29f1d86cac16042a5bea6b1651c96a8c1',
+ 'wiki.sco.vec': '4625a5ad90a57f994be9b3aa4f8f3ecda941a821',
+ 'wiki.gd.vec': 'f4b513598a1bf0f0d5b6521ea8ce363e9596cb97',
+ 'wiki.sr.vec': '3cf09f476f55a92fdd2880f7ba336656ab232736',
+ 'wiki.sh.vec': '016691ecb26ace442731d92b1265e5c6c3d8ca5f',
+ 'wiki.st.vec': '963646055d12873b1c83b0eef8649ecaf473d42e',
+ 'wiki.sn.vec': '8dbb1019dcc8f842a8c0f550295ae697f8e1b7e0',
+ 'wiki.scn.vec': 'bde043a235551e1643506774c5d9b61ecf2fc424',
+ 'wiki.szl.vec': '0573cf888ec70b459b0596d34814fe60fd69f190',
+ 'wiki.simple.vec': '55267c50fbdf4e4ae0fbbda5c73830a379d68795',
+ 'wiki.sd.vec': '36852d1253496e598fbd9b9009f07f454a6bea5b',
+ 'wiki.si.vec': 'd05ed6a0bc1ee56e5d2e5f881d47372095f6eb0c',
+ 'wiki.sk.vec': '98759aacf7352d49a51390fae02030776510ae13',
+ 'wiki.sl.vec': 'b26997c0ed1de26a47b11efdc26ac1e7f189fa54',
+ 'wiki.so.vec': '294756b60b03fe57cb08abd8d677d6a717b40bc8',
+ 'wiki.azb.vec': 'e23af0a436b97434813c3cb14ed114cc5b352faa',
+ 'wiki.es.vec': '2f41401aa0925167176bcd7a6770423d891dfef5',
+ 'wiki.srn.vec': 'faee05e550f5b08809a9ae5586ac4b08c9a1c359',
+ 'wiki.su.vec': '25e864495acb6d280bab0e62480f68550c9ceed4',
+ 'wiki.sw.vec': '8e70d207dbbd14e60a48e260a23fbf284a8e9f06',
+ 'wiki.ss.vec': '488546a3b2f88f549c50ae9f32f1997cc441b039',
+ 'wiki.sv.vec': 'eab83ae36701139696477b91b6e8d292ef175053',
+ 'wiki.tl.vec': 'd508e229ced7201510999e76d583de3ff2339d8b',
+ 'wiki.ty.vec': 'b881f60b8c75a71864d9847a17961d368f3058fc',
+ 'wiki.tg.vec': '6a5cd5bfe571ca0359b66d21bf6950553213f42d',
+ 'wiki.ta.vec': 'b66b5358527b1f3a6a421ab26464a3c1e75e18af',
+ 'wiki.roa_tara.vec': 'b3fcb01ff0bac53a0ba08c5c0c411f26ee83a95a',
+ 'wiki.tt.vec': '913bb3a11da6f8142b3bbec3ef065162d9350f1d',
+ 'wiki.te.vec': 'e71dcf3cc45da1bcdae5e431324025bd2026d0c8',
+ 'wiki.tet.vec': 'f38fe0e76b9b08ff652689eeee42c4fdadd9a47e',
+ 'wiki.th.vec': '1d6e0d525392a1042d017534f6c320c5a0afd345',
+ 'wiki.bo.vec': '2e9358e03dcfa09da23d2e1499d84b10348fd8a9',
+ 'wiki.ti.vec': 'c769fbc99bbb4138a40231e573685c7948d4a4c4',
+ 'wiki.tpi.vec': '407b96d235f54f3e0be9dc23a3bab89c6593a621',
+ 'wiki.to.vec': '64d512665b55e9ef9a3915e8167347be79310fa0',
+ 'wiki.ts.vec': '00f8229e2f230afd388221c0f823a1de9fc0e443',
+ 'wiki.tn.vec': '39f45f3fa86645bb25c54150204abcd51cc1048c',
+ 'wiki.tcy.vec': '388b1d89642fcc790b688e9643b3d19e14d66f40',
+ 'wiki.tum.vec': 'bfbe43364724af882a520d2edcc2ce049c7357cd',
+ 'wiki.tr.vec': '13234aa1bf5f99e81d933482b3b83c3e4bf6c85e',
+ 'wiki.tk.vec': '33ae577f77d339ab7a0dff88855b8d5c974d0aef',
+ 'wiki.tyv.vec': 'e8f9a36dc58e4108c553f96e247a877a099ab5ba',
+ 'wiki.tw.vec': 'f329b667d70d9f0b753e55e1b1579b5a5191d3bd',
+ 'wiki.udm.vec': '336a8526f22e177faac69573661dc9c3ce36591f',
+ 'wiki.uk.vec': '77f7737b9f88eac2b3e130ea8abb8886336fd0c6',
+ 'wiki.hsb.vec': '3dc7830544c58535bed308c552d609e13b973502',
+ 'wiki.ur.vec': 'cb8132102152a958df72bd3e25f1a72abb4c9c76',
+ 'wiki.ug.vec': '586d2febafaf17c9187c599ffd7b96e559103c34',
+ 'wiki.uz.vec': '11c3a76dae12b454f693811e33ae2e60015743e2',
+ 'wiki.ve.vec': 'b7d2947501de1c30a9f8496d5efae20c051104e1',
+ 'wiki.vec.vec': 'ae4b055fba21974e56beecab3a95f9dc24a62fd0',
+ 'wiki.vep.vec': 'a38a781fde24f4d7b52aa8bc450b9949dd4e1808',
+ 'wiki.vi.vec': 'bc84245b52b2e212e28dc6856c0693ce9845a9c5',
+ 'wiki.vo.vec': 'c830988b6965bfce2f932b1be193f7d1f755f411',
+ 'wiki.fiu_vro.vec': '168a71a2b1c478e6810fa5dce9612d8bf8a273dc',
+ 'wiki.wa.vec': '18f9ca1a585e1d18c3630029141a2e19d7d34a8e',
+ 'wiki.war.vec': '1f5d443d6f612b59a53820dd6f39fd886a6ad30f',
+ 'wiki.cy.vec': '32d976a9bfc4dd6e39328c906eead0f597bd9e25',
+ 'wiki.vls.vec': '07e8636908c057b9870ce4b98c7130d460cf882a',
+ 'wiki.fy.vec': 'd4beef537b7ff142a3986513879ff51a9ec14a7b',
+ 'wiki.pnb.vec': '35f38862d3d83012d6db7baa8a4105e3e0a416e7',
+ 'wiki.wo.vec': '2ad96a7a9e640bc0dbcf316b1f414b92802dcb8e',
+ 'wiki.wuu.vec': 'e1cbae1d3ad52329d0f36ada764016fbacf07049',
+ 'wiki.xh.vec': 'bf37f741b0b75953281d11df2b4d80100df9e666',
+ 'wiki.yi.vec': '299d61958b7dcc38774768f1489121384726d860',
+ 'wiki.yo.vec': 'e35c8aff2924ba07936be9d0d94bd298f09702a4',
+ 'wiki.diq.vec': '77f3c370d1d77806fafe368cf788af550ff607dd',
+ 'wiki.zea.vec': 'ee12db26aab3f2b3b2745a298ef414e7aeb5a058',
+ 'wiki.za.vec': 'e3a0e58bd2e5b1891c71f1f7e37ff71997a20361',
+ 'wiki.zu.vec': '4b244b9697a8280e6646842c5fc81bb3a6bc8ec7'}
diff --git a/python/mxnet/contrib/text/embedding.py b/python/mxnet/contrib/text/embedding.py
new file mode 100644
index 0000000..2996f1e
--- /dev/null
+++ b/python/mxnet/contrib/text/embedding.py
@@ -0,0 +1,669 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+# pylint: disable=consider-iterating-dictionary
+
+"""Text token embeddings."""
+from __future__ import absolute_import
+from __future__ import print_function
+
+import io
+import logging
+import os
+import tarfile
+import warnings
+import zipfile
+
+from . import constants as C
+from .indexer import TokenIndexer
+from ... import ndarray as nd
+from ... import registry
+
+
+class TokenEmbedding(TokenIndexer):
+ """Token embedding base class.
+
+
+ To load token embeddings from an externally hosted pre-trained
+ token embedding file, such as those of GloVe and FastText, use
+ `TokenEmbedding.create(embedding_name, pretrained_file_name)`. To get all
+ the available `embedding_name` and `pretrained_file_name`, use
+ `TokenEmbedding.get_embedding_and_pretrained_file_names()`.
+
+ Alternatively, to load embedding vectors from a custom pre-trained token
+ embedding file, use :class:`~mxnet.text.embedding.CustomEmbedding`.
+
+ For every unknown token, if its representation `self.unknown_token` is
+ encountered in the pre-trained token embedding file, index 0 of
+ `self.idx_to_vec` maps to the pre-trained token embedding vector loaded from
+ the file; otherwise, index 0 of `self.idx_to_vec` maps to the token
+ embedding vector initialized by `init_unknown_vec`.
+
+ If a token is encountered multiple times in the pre-trained token embedding
+ file, only the first-encountered token embedding vector will be loaded and
+ the rest will be skipped.
+
+ For the same token, its index and embedding vector may vary across different
+ instances of :class:`~mxnet.text.embedding.TokenEmbedding`.
+
+
+ Properties
+ ----------
+ token_to_idx : dict mapping str to int
+ A dict mapping each token to its index integer.
+ idx_to_token : list of strs
+ A list of indexed tokens where the list indices and the token indices
+ are aligned.
+ unknown_token : hashable object
+ The representation for any unknown token. In other words, any
+ unknown token will be indexed as the same representation.
+ reserved_tokens : list of strs or None
+ A list of reserved tokens that will always be indexed.
+ vec_len : int
+ The length of the embedding vector for each token.
+ idx_to_vec : mxnet.ndarray.NDArray
+ For all the indexed tokens in this embedding, this NDArray maps each
+ token's index to an embedding vector. The largest valid index maps
+ to the initialized embedding vector for every reserved token, such as an
+ unknown_token token and a padding token.
+ """
+
+ def __init__(self, **kwargs):
+ super(TokenEmbedding, self).__init__(**kwargs)
+
+ @classmethod
+ def _get_download_file_name(cls, pretrained_file_name):
+ return pretrained_file_name
+
+ @classmethod
+ def _get_pretrained_file_url(cls, pretrained_file_name):
+ repo_url = os.environ.get('MXNET_GLUON_REPO', C.APACHE_REPO_URL)
+ embedding_cls = cls.__name__.lower()
+
+ url_format = '{repo_url}gluon/embeddings/{cls}/{file_name}'
+ return url_format.format(repo_url=repo_url,
+ cls=embedding_cls,
+ file_name=cls._get_download_file_name(pretrained_file_name))
+
+ @classmethod
+ def _get_pretrained_file(cls, embedding_root, pretrained_file_name):
+ from ...gluon.utils import check_sha1, download
+ embedding_cls = cls.__name__.lower()
+ embedding_root = os.path.expanduser(embedding_root)
+ url = cls._get_pretrained_file_url(pretrained_file_name)
+
+ embedding_dir = os.path.join(embedding_root, embedding_cls)
+ pretrained_file_path = os.path.join(embedding_dir, pretrained_file_name)
+ downloaded_file = os.path.basename(url)
+ downloaded_file_path = os.path.join(embedding_dir, downloaded_file)
+
+ expected_file_hash = cls.pretrained_file_name_sha1[pretrained_file_name]
+
+ if hasattr(cls, 'pretrained_archive_name_sha1'):
+ expected_downloaded_hash = \
+ cls.pretrained_archive_name_sha1[downloaded_file]
+ else:
+ expected_downloaded_hash = expected_file_hash
+
+ if not os.path.exists(pretrained_file_path) \
+ or not check_sha1(pretrained_file_path, expected_file_hash):
+ download(url, downloaded_file_path, sha1_hash=expected_downloaded_hash)
+
+ ext = os.path.splitext(downloaded_file)[1]
+ if ext == '.zip':
+ with zipfile.ZipFile(downloaded_file_path, 'r') as zf:
+ zf.extractall(embedding_dir)
+ elif ext == '.gz':
+ with tarfile.open(downloaded_file_path, 'r:gz') as tar:
+ tar.extractall(path=embedding_dir)
+ return pretrained_file_path
+
+ def _load_embedding(self, pretrained_file_path, elem_delim,
+ init_unknown_vec, encoding='utf8'):
+ """Load embedding vectors from the pre-trained token embedding file.
+
+
+ For every unknown token, if its representation `self.unknown_token` is
+ encountered in the pre-trained token embedding file, index 0 of
+ `self.idx_to_vec` maps to the pre-trained token embedding vector loaded
+ from the file; otherwise, index 0 of `self.idx_to_vec` maps to the text
+ embedding vector initialized by `init_unknown_vec`.
+
+ If a token is encountered multiple times in the pre-trained text
+ embedding file, only the first-encountered token embedding vector will
+ be loaded and the rest will be skipped.
+ """
+
+ pretrained_file_path = os.path.expanduser(pretrained_file_path)
+
+ if not os.path.isfile(pretrained_file_path):
+ raise ValueError('`pretrained_file_path` must be a valid path to '
+ 'the pre-trained token embedding file.')
+
+ logging.info('Loading pre-trained token embedding vectors from %s',
+ pretrained_file_path)
+ vec_len = None
+ all_elems = []
+ tokens = set()
+ loaded_unknown_vec = None
+ line_num = 0
+ with io.open(pretrained_file_path, 'r', encoding=encoding) as f:
+ for line in f:
+ line_num += 1
+ elems = line.rstrip().split(elem_delim)
+
+ assert len(elems) > 1, 'At line %d of the pre-trained text ' \
+ 'embedding file: the data format of the ' \
+ 'pre-trained token embedding file %s is ' \
+ 'unexpected.' \
+ % (line_num, pretrained_file_path)
+
+ token, elems = elems[0], [float(i) for i in elems[1:]]
+
+ if token == self.unknown_token and loaded_unknown_vec is None:
+ loaded_unknown_vec = elems
+ tokens.add(self.unknown_token)
+ elif token in tokens:
+ warnings.warn('At line %d of the pre-trained token embedding '
+ 'file: the embedding vector for token %s has '
+ 'been loaded and a duplicate embedding for the '
+ 'same token is seen and skipped.'
+ % (line_num, token))
+ elif len(elems) == 1:
+ warnings.warn('At line %d of the pre-trained text '
+ 'embedding file: token %s with 1-dimensional '
+ 'vector %s is likely a header and is '
+ 'skipped.' % (line_num, token, elems))
+ else:
+ if vec_len is None:
+ vec_len = len(elems)
+                        # Reserve a vector slot for the unknown token at the
+                        # very beginning because the unknown index is 0.
+ all_elems.extend([0] * vec_len)
+ else:
+ assert len(elems) == vec_len, \
+ 'At line %d of the pre-trained token embedding ' \
+ 'file: the dimension of token %s is %d but the ' \
+ 'dimension of previous tokens is %d. Dimensions ' \
+ 'of all the tokens must be the same.' \
+ % (line_num, token, len(elems), vec_len)
+ all_elems.extend(elems)
+ self._idx_to_token.append(token)
+ self._token_to_idx[token] = len(self._idx_to_token) - 1
+ tokens.add(token)
+
+ self._vec_len = vec_len
+ self._idx_to_vec = nd.array(all_elems).reshape((-1, self.vec_len))
+
+ if loaded_unknown_vec is None:
+ self._idx_to_vec[C.UNKNOWN_IDX] = init_unknown_vec(
+ shape=self.vec_len)
+ else:
+ self._idx_to_vec[C.UNKNOWN_IDX] = nd.array(loaded_unknown_vec)
+
+ @property
+ def vec_len(self):
+ return self._vec_len
+
+ @property
+ def idx_to_vec(self):
+ return self._idx_to_vec
+
+ def get_vecs_by_tokens(self, tokens, lower_case_backup=False):
+ """Look up embedding vectors of tokens.
+
+
+ Parameters
+ ----------
+ tokens : str or list of strs
+ A token or a list of tokens.
+ lower_case_backup : bool, default False
+ If False, each token in the original case will be looked up; if
+ True, each token in the original case will be looked up first, if
+ not found in the keys of the property `token_to_idx`, the token
+ in the lower case will be looked up.
+
+
+ Returns
+ -------
+ mxnet.ndarray.NDArray:
+ The embedding vector(s) of the token(s). According to numpy
+ conventions, if `tokens` is a string, returns a 1-D NDArray of shape
+ `self.vec_len`; if `tokens` is a list of strings, returns a 2-D
+ NDArray of shape=(len(tokens), self.vec_len).
+ """
+
+ to_reduce = False
+ if not isinstance(tokens, list):
+ tokens = [tokens]
+ to_reduce = True
+
+ if not lower_case_backup:
+ indices = [self.token_to_idx.get(token, C.UNKNOWN_IDX)
+ for token in tokens]
+ else:
+ indices = [self.token_to_idx[token] if token in self.token_to_idx
+ else self.token_to_idx.get(token.lower(), C.UNKNOWN_IDX)
+ for token in tokens]
+
+ vecs = nd.Embedding(nd.array(indices), self.idx_to_vec,
+ self.idx_to_vec.shape[0], self.idx_to_vec.shape[1])
+
+ return vecs[0] if to_reduce else vecs
+
    def update_token_vectors(self, tokens, new_vectors):
        """Updates embedding vectors for tokens.


        Parameters
        ----------
        tokens : str or a list of strs
            A token or a list of tokens whose embedding vector are to be
            updated.
        new_vectors : mxnet.ndarray.NDArray
            An NDArray to be assigned to the embedding vectors of `tokens`.
            Its length must be equal to the number of `tokens` and its width
            must be equal to the dimension of embeddings of the glossary. If
            `tokens` is a singleton, it must be 1-D or 2-D. If `tokens` is a
            list of multiple strings, it must be 2-D.


        Raises
        ------
        ValueError
            If any token in `tokens` is not indexed. Unknown tokens must be
            named explicitly (as the `unknown_token` representation) to avoid
            unintended updates.
        """

        assert self.idx_to_vec is not None, \
            'The property `idx_to_vec` has not been properly set.'

        # Normalize the singleton case: a single token (or one-element list)
        # may come with a 1-D vector, which is expanded to 2-D here so the
        # shape check and assignment below handle both cases uniformly.
        if not isinstance(tokens, list) or len(tokens) == 1:
            assert isinstance(new_vectors, nd.NDArray) and \
                len(new_vectors.shape) in [1, 2], \
                '`new_vectors` must be a 1-D or 2-D NDArray if `tokens` is a ' \
                'singleton.'
            if not isinstance(tokens, list):
                tokens = [tokens]
            if len(new_vectors.shape) == 1:
                new_vectors = new_vectors.expand_dims(0)

        else:
            assert isinstance(new_vectors, nd.NDArray) and \
                len(new_vectors.shape) == 2, \
                '`new_vectors` must be a 2-D NDArray if `tokens` is a list ' \
                'of multiple strings.'
        assert new_vectors.shape == (len(tokens), self.vec_len), \
            'The length of new_vectors must be equal to the number of tokens ' \
            'and the width of new_vectors must be equal to the dimension of ' \
            'embeddings of the glossary.'

        # Reject unindexed tokens instead of silently mapping them to the
        # unknown index, which would clobber the unknown token's vector.
        indices = []
        for token in tokens:
            if token in self.token_to_idx:
                indices.append(self.token_to_idx[token])
            else:
                raise ValueError('Token %s is unknown. To update the embedding '
                                 'vector for an unknown token, please specify '
                                 'it explicitly as the `unknown_token` %s in '
                                 '`tokens`. This is to avoid unintended '
                                 'updates.' %
                                 (token, self.idx_to_token[C.UNKNOWN_IDX]))

        self._idx_to_vec[nd.array(indices)] = new_vectors
+
    @staticmethod
    def register(embedding_cls):
        """Registers a new token embedding.


        Once an embedding is registered, we can create an instance of this
        embedding with :func:`~mxnet.text.embedding.TokenEmbedding.create`.


        Examples
        --------
        >>> @mxnet.text.embedding.TokenEmbedding.register
        ... class MyTextEmbed(mxnet.text.embedding.TokenEmbedding):
        ...     def __init__(self, pretrained_file_name='my_pretrain_file'):
        ...         pass
        >>> embed = mxnet.text.embedding.TokenEmbedding.create('MyTextEmbed')
        >>> print(type(embed))
        <class '__main__.MyTextEmbed'>
        """

        register_text_embedding = registry.get_register_func(
            TokenEmbedding, 'token embedding')
        return register_text_embedding(embedding_cls)
+
    @staticmethod
    def create(embedding_name, **kwargs):
        """Creates an instance of :class:`~mxnet.text.embedding.TokenEmbedding`.


        Creates a token embedding instance by loading embedding vectors from an
        externally hosted pre-trained token embedding file, such as those
        of GloVe and FastText. To get all the valid `embedding_name` and
        `pretrained_file_name`, use `mxnet.text.embedding.TokenEmbedding.
        get_embedding_and_pretrained_file_names()`.


        Parameters
        ----------
        embedding_name : str
            The token embedding name (case-insensitive).
        **kwargs : dict
            Keyword arguments forwarded to the registered embedding class's
            constructor, e.g. `pretrained_file_name` or `init_unknown_vec`.


        Returns
        -------
        :class:`~mxnet.text.glossary.TokenEmbedding`:
            A token embedding instance that loads embedding vectors from an
            externally hosted pre-trained token embedding file.
        """

        create_text_embedding = registry.get_create_func(
            TokenEmbedding, 'token embedding')
        return create_text_embedding(embedding_name, **kwargs)
+
+ @classmethod
+ def _check_pretrained_file_names(cls, pretrained_file_name):
+ """Checks if a pre-trained token embedding file name is valid.
+
+
+ Parameters
+ ----------
+ pretrained_file_name : str
+ The pre-trained token embedding file.
+ """
+
+ embedding_name = cls.__name__.lower()
+ if pretrained_file_name not in cls.pretrained_file_name_sha1:
+ raise KeyError('Cannot find pretrained file %s for token embedding '
+ '%s. Valid pretrained files for embedding %s: %s' %
+ (pretrained_file_name, embedding_name,
+ embedding_name,
+ ', '.join(cls.pretrained_file_name_sha1.keys())))
+
    @staticmethod
    def get_embedding_and_pretrained_file_names(embedding_name=None):
        """Get valid token embedding names and their pre-trained file names.


        To load token embedding vectors from an externally hosted pre-trained
        token embedding file, such as those of GloVe and FastText, one should
        use `mxnet.text.embedding.TokenEmbedding.create(embedding_name,
        pretrained_file_name)`. This method returns all the valid names of
        `pretrained_file_name` for the specified `embedding_name`. If
        `embedding_name` is set to None, this method returns all the valid names
        of `embedding_name` with associated `pretrained_file_name`.


        Parameters
        ----------
        embedding_name : str or None, default None
            The pre-trained token embedding name.


        Returns
        -------
        dict or list:
            A list of all the valid pre-trained token embedding file names
            (`pretrained_file_name`) for the specified token embedding name
            (`embedding_name`). If the text embedding name is set to None,
            returns a dict mapping each valid token embedding name to a list
            of valid pre-trained files (`pretrained_file_name`). They can be
            plugged into `mxnet.text.embedding.TokenEmbedding.create(
            embedding_name, pretrained_file_name)`.


        Raises
        ------
        KeyError
            If `embedding_name` is not None and is not a registered token
            embedding name.
        """

        text_embedding_reg = registry.get_registry(TokenEmbedding)

        if embedding_name is not None:
            if embedding_name not in text_embedding_reg:
                raise KeyError('Cannot find `embedding_name` %s. Use '
                               '`get_embedding_and_pretrained_file_names('
                               'embedding_name=None).keys()` to get all the '
                               'valid embedding names.' % embedding_name)
            return list(text_embedding_reg[
                embedding_name].pretrained_file_name_sha1.keys())
        else:
            # `embedding_name` is None on this branch, so reusing the name as
            # the comprehension variable does not shadow a caller-provided
            # value.
            return {embedding_name: list(
                embedding_cls.pretrained_file_name_sha1.keys())
                    for embedding_name, embedding_cls in
                    registry.get_registry(TokenEmbedding).items()}
+
+
@TokenEmbedding.register
class GloVe(TokenEmbedding):
    """The GloVe word embedding.


    GloVe is an unsupervised learning algorithm for obtaining vector
    representations for words. Training is performed on aggregated global
    word-word co-occurrence statistics from a corpus, and the resulting
    representations showcase interesting linear substructures of the word vector
    space. (Source from https://nlp.stanford.edu/projects/glove/)

    Reference:

    GloVe: Global Vectors for Word Representation.
    Jeffrey Pennington, Richard Socher, and Christopher D. Manning.
    https://nlp.stanford.edu/pubs/glove.pdf

    Website:

    https://nlp.stanford.edu/projects/glove/

    To get the updated URLs to the externally hosted pre-trained token embedding
    files, visit https://nlp.stanford.edu/projects/glove/

    License for pre-trained embeddings:

    https://opendatacommons.org/licenses/pddl/


    Parameters
    ----------
    pretrained_file_name : str, default 'glove.840B.300d.txt'
        The name of the pre-trained token embedding file.
    embedding_root : str, default os.path.join('~', '.mxnet', 'embeddings')
        The root directory for storing embedding-related files.
    init_unknown_vec : callback
        The callback used to initialize the embedding vector for the unknown
        token.


    Properties
    ----------
    token_to_idx : dict mapping str to int
        A dict mapping each token to its index integer.
    idx_to_token : list of strs
        A list of indexed tokens where the list indices and the token indices
        are aligned.
    unknown_token : hashable object
        The representation for any unknown token. In other words, any
        unknown token will be indexed as the same representation.
    reserved_tokens : list of strs or None
        A list of reserved tokens that will always be indexed.
    vec_len : int
        The length of the embedding vector for each token.
    idx_to_vec : mxnet.ndarray.NDArray
        For all the indexed tokens in this embedding, this NDArray maps each
        token's index to an embedding vector. The largest valid index maps
        to the initialized embedding vector for every reserved token, such as an
        unknown_token token and a padding token.
    """

    # Map a pre-trained token embedding archive file and its SHA-1 hash.
    # NOTE(review): the attribute/constant pairing looks swapped here —
    # `archive` maps to C.GLOVE_PRETRAINED_FILE_SHA1 while `file` (below)
    # maps to C.GLOVE_PRETRAINED_ARCHIVE_SHA1. Verify against constants.py
    # that the constant contents actually match these attribute names.
    pretrained_archive_name_sha1 = C.GLOVE_PRETRAINED_FILE_SHA1

    # Map a pre-trained token embedding file and its SHA-1 hash.
    pretrained_file_name_sha1 = C.GLOVE_PRETRAINED_ARCHIVE_SHA1

    @classmethod
    def _get_download_file_name(cls, pretrained_file_name):
        # Map a pretrained embedding file to its archive to download, keyed
        # by the second dotted field, e.g. 'glove.840B.300d.txt' -> '840B'.
        src_archive = {archive.split('.')[1]: archive for archive in
                       GloVe.pretrained_archive_name_sha1.keys()}
        archive = src_archive[pretrained_file_name.split('.')[1]]
        return archive

    def __init__(self, pretrained_file_name='glove.840B.300d.txt',
                 embedding_root=os.path.join('~', '.mxnet', 'embeddings'),
                 init_unknown_vec=nd.zeros, **kwargs):
        # Validate before doing any download/indexing work.
        GloVe._check_pretrained_file_names(pretrained_file_name)

        super(GloVe, self).__init__(**kwargs)
        pretrained_file_path = GloVe._get_pretrained_file(embedding_root,
                                                          pretrained_file_name)

        # GloVe files separate the token and vector elements with spaces.
        self._load_embedding(pretrained_file_path, ' ', init_unknown_vec)
+
+
@TokenEmbedding.register
class FastText(TokenEmbedding):
    """The fastText word embedding.


    FastText is an open-source, free, lightweight library that allows users to
    learn text representations and text classifiers. It works on standard,
    generic hardware. Models can later be reduced in size to even fit on mobile
    devices. (Source from https://fasttext.cc/)

    References:

    Enriching Word Vectors with Subword Information.
    Piotr Bojanowski, Edouard Grave, Armand Joulin, and Tomas Mikolov.
    https://arxiv.org/abs/1607.04606

    Bag of Tricks for Efficient Text Classification.
    Armand Joulin, Edouard Grave, Piotr Bojanowski, and Tomas Mikolov.
    https://arxiv.org/abs/1607.01759

    FastText.zip: Compressing text classification models.
    Armand Joulin, Edouard Grave, Piotr Bojanowski, Matthijs Douze, Herve Jegou,
    and Tomas Mikolov.
    https://arxiv.org/abs/1612.03651

    Website:

    https://fasttext.cc/

    To get the updated URLs to the externally hosted pre-trained token embedding
    files, visit
    https://github.com/facebookresearch/fastText/blob/master/pretrained-vectors.md

    License for pre-trained embeddings:

    https://creativecommons.org/licenses/by-sa/3.0/


    Parameters
    ----------
    pretrained_file_name : str, default 'wiki.simple.vec'
        The name of the pre-trained token embedding file.
    embedding_root : str, default os.path.join('~', '.mxnet', 'embeddings')
        The root directory for storing embedding-related files.
    init_unknown_vec : callback
        The callback used to initialize the embedding vector for the unknown
        token.


    Properties
    ----------
    token_to_idx : dict mapping str to int
        A dict mapping each token to its index integer.
    idx_to_token : list of strs
        A list of indexed tokens where the list indices and the token indices
        are aligned.
    unknown_token : hashable object
        The representation for any unknown token. In other words, any
        unknown token will be indexed as the same representation.
    reserved_tokens : list of strs or None
        A list of reserved tokens that will always be indexed.
    vec_len : int
        The length of the embedding vector for each token.
    idx_to_vec : mxnet.ndarray.NDArray
        For all the indexed tokens in this embedding, this NDArray maps each
        token's index to an embedding vector. The largest valid index maps
        to the initialized embedding vector for every reserved token, such as an
        unknown_token token and a padding token.
    """

    # Map a pre-trained token embedding file and its SHA-1 hash.
    pretrained_file_name_sha1 = C.FAST_TEXT_FILE_SHA1

    def __init__(self, pretrained_file_name='wiki.simple.vec',
                 embedding_root=os.path.join('~', '.mxnet', 'embeddings'),
                 init_unknown_vec=nd.zeros, **kwargs):
        # Validate before doing any download/indexing work.
        FastText._check_pretrained_file_names(pretrained_file_name)

        super(FastText, self).__init__(**kwargs)
        pretrained_file_path = FastText._get_pretrained_file(embedding_root,
                                                             pretrained_file_name)

        # fastText .vec files separate the token and vector elements with
        # spaces.
        self._load_embedding(pretrained_file_path, ' ', init_unknown_vec)
+
+
class CustomEmbedding(TokenEmbedding):
    """User-defined token embedding.

    This is to load embedding vectors from a user-defined pre-trained text
    embedding file.

    Denote by '<ed>' the argument `elem_delim`. Denote by <v_ij> the j-th
    element of the token embedding vector for <token_i>, the expected format of
    a custom pre-trained token embedding file is:

    '<token_1><ed><v_11><ed><v_12><ed>...<ed><v_1k>\\\\n<token_2><ed><v_21><ed>
    <v_22><ed>...<ed><v_2k>\\\\n...'

    where k is the length of the embedding vector `vec_len`.


    Parameters
    ----------
    pretrained_file_path : str
        The path to the custom pre-trained token embedding file.
    elem_delim : str, default ' '
        The delimiter for splitting a token and every embedding vector element
        value on the same line of the custom pre-trained token embedding file.
    encoding : str, default 'utf8'
        The encoding used to read the custom pre-trained token embedding file.
    init_unknown_vec : callback
        The callback used to initialize the embedding vector for the unknown
        token.


    Properties
    ----------
    token_to_idx : dict mapping str to int
        A dict mapping each token to its index integer.
    idx_to_token : list of strs
        A list of indexed tokens where the list indices and the token indices
        are aligned.
    unknown_token : hashable object
        The representation for any unknown token. In other words, any
        unknown token will be indexed as the same representation.
    reserved_tokens : list of strs or None
        A list of reserved tokens that will always be indexed.
    vec_len : int
        The length of the embedding vector for each token.
    idx_to_vec : mxnet.ndarray.NDArray
        For all the indexed tokens in this embedding, this NDArray maps each
        token's index to an embedding vector. The largest valid index maps
        to the initialized embedding vector for every reserved token, such as an
        unknown_token token and a padding token.
    """

    def __init__(self, pretrained_file_path, elem_delim=' ', encoding='utf8',
                 init_unknown_vec=nd.zeros, **kwargs):
        super(CustomEmbedding, self).__init__(**kwargs)
        self._load_embedding(pretrained_file_path, elem_delim, init_unknown_vec,
                             encoding)
diff --git a/python/mxnet/contrib/text/glossary.py b/python/mxnet/contrib/text/glossary.py
new file mode 100644
index 0000000..4de082b
--- /dev/null
+++ b/python/mxnet/contrib/text/glossary.py
@@ -0,0 +1,142 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+
+"""Index text tokens and load their embeddings."""
+from __future__ import absolute_import
+from __future__ import print_function
+
+from ... import ndarray as nd
+from .embedding import TokenEmbedding
+
+
class Glossary(TokenEmbedding):
    """Indexing and embedding for text tokens in a glossary.


    For each indexed token in a glossary, an embedding vector will be associated
    with it. Such embedding vectors can be loaded from externally hosted or
    custom pre-trained token embedding files, such as via instances of
    :class:`~mxnet.text.embedding.TokenEmbedding`.


    Parameters
    ----------
    counter : collections.Counter or None, default None
        Counts text token frequencies in the text data. Its keys will be indexed
        according to frequency thresholds such as `most_freq_count` and
        `min_freq`. Keys of `counter`, `unknown_token`, and values of
        `reserved_tokens` must be of the same hashable type. Examples: str, int,
        and tuple.
    token_embeddings : instance or list of :class:`~TokenEmbedding`
        One or multiple pre-trained token embeddings to load. If it is a list of
        multiple embeddings, these embedding vectors will be concatenated for
        each token.
    most_freq_count : None or int, default None
        The maximum possible number of the most frequent tokens in the keys of
        `counter` that can be indexed. Note that this argument does not count
        any token from `reserved_tokens`. If this argument is None or larger
        than its largest possible value restricted by `counter` and
        `reserved_tokens`, this argument becomes positive infinity.
    min_freq : int, default 1
        The minimum frequency required for a token in the keys of `counter` to
        be indexed.
    unknown_token : hashable object, default '<unk>'
        The representation for any unknown token. In other words, any unknown
        token will be indexed as the same representation. Keys of `counter`,
        `unknown_token`, and values of `reserved_tokens` must be of the same
        hashable type. Examples: str, int, and tuple.
    reserved_tokens : list of hashable objects or None, default None
        A list of reserved tokens that will always be indexed, such as special
        symbols representing padding, beginning of sentence, and end of
        sentence. It cannot contain `unknown_token`, or duplicate reserved
        tokens. Keys of `counter`, `unknown_token`, and values of
        `reserved_tokens` must be of the same hashable type. Examples: str, int,
        and tuple.


    Properties
    ----------
    token_to_idx : dict mapping str to int
        A dict mapping each token to its index integer.
    idx_to_token : list of strs
        A list of indexed tokens where the list indices and the token indices
        are aligned.
    unknown_token : hashable object
        The representation for any unknown token. In other words, any
        unknown token will be indexed as the same representation.
    reserved_tokens : list of strs or None
        A list of reserved tokens that will always be indexed.
    vec_len : int
        The length of the embedding vector for each token.
    idx_to_vec : mxnet.ndarray.NDArray
        For all the indexed tokens in this embedding, this NDArray maps each
        token's index to an embedding vector. The largest valid index maps
        to the initialized embedding vector for every reserved token, such as an
        unknown_token token and a padding token.
    """
    def __init__(self, counter, token_embeddings, most_freq_count=None,
                 min_freq=1, unknown_token='<unk>', reserved_tokens=None):

        if not isinstance(token_embeddings, list):
            token_embeddings = [token_embeddings]

        # Sanity checks.
        for embed in token_embeddings:
            assert isinstance(embed, TokenEmbedding), \
                'The parameter `token_embeddings` must be an instance or a ' \
                'list of instances of `mxnet.text.embedding.TextEmbed` ' \
                'whose embedding vectors will be loaded or ' \
                'concatenated-then-loaded to map to the indexed tokens.'

        # Index tokens from keys of `counter` and reserved tokens.
        super(Glossary, self).__init__(counter=counter,
                                       most_freq_count=most_freq_count,
                                       min_freq=min_freq,
                                       unknown_token=unknown_token,
                                       reserved_tokens=reserved_tokens)

        # Set _idx_to_vec so that indices of tokens from keys of `counter` are
        # associated with token embedding vectors from `token_embeddings`.
        self._set_idx_to_vec_by_embeds(token_embeddings)

    def _set_idx_to_vec_by_embeds(self, token_embeddings):
        """Sets the mapping between token indices and token embedding vectors.


        Parameters
        ----------
        token_embeddings : an instance or a list of instances of
            :class:`~mxnet.text.embedding.TokenEmbedding`
            One or multiple pre-trained token embeddings to load. If it is a
            list of multiple embeddings, these embedding vectors will be
            concatenated for each token.
        """

        # The glossary's vector is the concatenation of one vector per source
        # embedding, so its width is the sum of the source widths.
        self._vec_len = sum(embed.vec_len for embed in token_embeddings)
        self._idx_to_vec = nd.zeros(shape=(len(self), self.vec_len))

        col_start = 0
        # Concatenate all the embedding vectors in token_embeddings.
        for embed in token_embeddings:
            col_end = col_start + embed.vec_len
            # Concatenate vectors of the unknown token, which sits at index 0
            # by construction (see TokenIndexer).
            self._idx_to_vec[0, col_start:col_end] = embed.idx_to_vec[0]
            self._idx_to_vec[1:, col_start:col_end] = embed.get_vecs_by_tokens(
                self.idx_to_token[1:])
            col_start = col_end
diff --git a/python/mxnet/contrib/text/indexer.py b/python/mxnet/contrib/text/indexer.py
new file mode 100644
index 0000000..bed2794
--- /dev/null
+++ b/python/mxnet/contrib/text/indexer.py
@@ -0,0 +1,231 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+# pylint: disable=consider-iterating-dictionary
+
+"""Text token indexer."""
+from __future__ import absolute_import
+from __future__ import print_function
+
+from collections import Counter
+
+from . import constants as C
+
+
class TokenIndexer(object):
    """Indexing for text tokens.


    Build indices for the unknown token, reserved tokens, and input counter
    keys. Indexed tokens can be used by instances of
    :class:`~mxnet.text.embedding.TokenEmbedding`, such as instances of
    :class:`~mxnet.text.glossary.Glossary`.


    Parameters
    ----------
    counter : collections.Counter or None, default None
        Counts text token frequencies in the text data. Its keys will be indexed
        according to frequency thresholds such as `most_freq_count` and
        `min_freq`. Keys of `counter`, `unknown_token`, and values of
        `reserved_tokens` must be of the same hashable type. Examples: str, int,
        and tuple.
    most_freq_count : None or int, default None
        The maximum possible number of the most frequent tokens in the keys of
        `counter` that can be indexed. Note that this argument does not count
        any token from `reserved_tokens`. Suppose that there are different
        keys of `counter` whose frequency are the same, if indexing all of them
        will exceed this argument value, such keys will be indexed one by one
        according to their __cmp__() order until the frequency threshold is
        met. If this argument is None or larger than its largest possible value
        restricted by `counter` and `reserved_tokens`, this argument has no
        effect.
    min_freq : int, default 1
        The minimum frequency required for a token in the keys of `counter` to
        be indexed.
    unknown_token : hashable object, default '<unk>'
        The representation for any unknown token. In other words, any unknown
        token will be indexed as the same representation. Keys of `counter`,
        `unknown_token`, and values of `reserved_tokens` must be of the same
        hashable type. Examples: str, int, and tuple.
    reserved_tokens : list of hashable objects or None, default None
        A list of reserved tokens that will always be indexed, such as special
        symbols representing padding, beginning of sentence, and end of
        sentence. It cannot contain `unknown_token`, or duplicate reserved
        tokens. Keys of `counter`, `unknown_token`, and values of
        `reserved_tokens` must be of the same hashable type. Examples: str, int,
        and tuple.


    Properties
    ----------
    token_to_idx : dict mapping str to int
        A dict mapping each token to its index integer.
    idx_to_token : list of strs
        A list of indexed tokens where the list indices and the token indices
        are aligned.
    unknown_token : hashable object
        The representation for any unknown token. In other words, any
        unknown token will be indexed as the same representation.
    reserved_tokens : list of strs or None
        A list of reserved tokens that will always be indexed.
    """

    def __init__(self, counter=None, most_freq_count=None, min_freq=1,
                 unknown_token='<unk>', reserved_tokens=None):

        # Sanity checks.
        assert min_freq > 0, '`min_freq` must be set to a positive value.'

        if reserved_tokens is not None:
            reserved_token_set = set(reserved_tokens)
            assert unknown_token not in reserved_token_set, \
                '`reserved_token` cannot contain `unknown_token`.'
            assert len(reserved_token_set) == len(reserved_tokens), \
                '`reserved_tokens` cannot contain duplicate reserved tokens.'

        self._index_unknown_and_reserved_tokens(unknown_token, reserved_tokens)

        if counter is not None:
            self._index_counter_keys(counter, unknown_token, reserved_tokens,
                                     most_freq_count, min_freq)

    def _index_unknown_and_reserved_tokens(self, unknown_token,
                                           reserved_tokens):
        """Indexes unknown and reserved tokens."""

        self._unknown_token = unknown_token
        # The unknown token is always indexed first.
        # Thus, constants.UNKNOWN_IDX must be 0.
        self._idx_to_token = [unknown_token]

        if reserved_tokens is None:
            self._reserved_tokens = None
        else:
            # Copy so later caller-side mutation cannot corrupt the index.
            self._reserved_tokens = reserved_tokens[:]
            self._idx_to_token.extend(reserved_tokens)

        self._token_to_idx = {token: idx for idx, token in
                              enumerate(self._idx_to_token)}

    def _index_counter_keys(self, counter, unknown_token, reserved_tokens,
                            most_freq_count, min_freq):
        """Indexes keys of `counter`.


        Indexes keys of `counter` according to frequency thresholds such as
        `most_freq_count` and `min_freq`.
        """

        assert isinstance(counter, Counter), \
            '`counter` must be an instance of collections.Counter.'

        unknown_and_reserved_tokens = set(reserved_tokens) \
            if reserved_tokens is not None else set()
        unknown_and_reserved_tokens.add(unknown_token)

        # Sort by token first so that tokens with the same frequency are
        # indexed in a deterministic order; the stable sort by frequency
        # (descending) preserves that tie-break.
        token_freqs = sorted(counter.items(), key=lambda x: x[0])
        token_freqs.sort(key=lambda x: x[1], reverse=True)

        token_cap = len(unknown_and_reserved_tokens) + (
            len(counter) if most_freq_count is None else most_freq_count)

        for token, freq in token_freqs:
            if freq < min_freq or len(self._idx_to_token) == token_cap:
                break
            if token not in unknown_and_reserved_tokens:
                self._idx_to_token.append(token)
                self._token_to_idx[token] = len(self._idx_to_token) - 1

    def __len__(self):
        return len(self.idx_to_token)

    @property
    def token_to_idx(self):
        return self._token_to_idx

    @property
    def idx_to_token(self):
        return self._idx_to_token

    @property
    def unknown_token(self):
        return self._unknown_token

    @property
    def reserved_tokens(self):
        return self._reserved_tokens

    def to_indices(self, tokens):
        """Converts tokens to indices according to the text indexer.


        Parameters
        ----------
        tokens : str or list of strs
            A source token or tokens to be converted.


        Returns
        -------
        int or list of ints
            A token index or a list of token indices according to the text
            indexer.
        """

        to_reduce = False
        if not isinstance(tokens, list):
            tokens = [tokens]
            to_reduce = True

        indices = [self.token_to_idx[token] if token in self.token_to_idx
                   else C.UNKNOWN_IDX for token in tokens]

        return indices[0] if to_reduce else indices

    def to_tokens(self, indices):
        """Converts token indices to tokens according to the text indexer.


        Parameters
        ----------
        indices : int or list of ints
            A source token index or token indices to be converted.


        Returns
        -------
        str or list of strs
            A token or a list of tokens according to the text indexer.


        Raises
        ------
        ValueError
            If any index is not an int or is out of the range
            [0, len(self) - 1].
        """

        to_reduce = False
        if not isinstance(indices, list):
            indices = [indices]
            to_reduce = True

        max_idx = len(self.idx_to_token) - 1

        tokens = []
        for idx in indices:
            # Negative indices are rejected explicitly: without the `idx < 0`
            # check they would silently index from the end of `idx_to_token`
            # and return a wrong token. `%s` (not `%d`) is used so that a
            # non-int index raises this ValueError instead of a confusing
            # TypeError from string formatting.
            if not isinstance(idx, int) or idx < 0 or idx > max_idx:
                raise ValueError('Token index %s in the provided `indices` is '
                                 'invalid.' % idx)
            else:
                tokens.append(self.idx_to_token[idx])

        return tokens[0] if to_reduce else tokens
diff --git a/python/mxnet/contrib/text/utils.py b/python/mxnet/contrib/text/utils.py
new file mode 100644
index 0000000..91e1b62
--- /dev/null
+++ b/python/mxnet/contrib/text/utils.py
@@ -0,0 +1,79 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+
+"""Provide utilities for text data processing."""
+from __future__ import absolute_import
+from __future__ import print_function
+
+from collections import Counter
+import re
+
+
def count_tokens_from_str(source_str, token_delim=' ', seq_delim='\n',
                          to_lower=False, counter_to_update=None):
    """Counts tokens in the specified string.

    Tokens are the non-empty pieces left after splitting `source_str` on
    either delimiter. For token_delim='<td>' and seq_delim='<sd>', a specified
    string of two sequences of tokens may look like::

        <td>token1<td>token2<td>token3<td><sd><td>token4<td>token5<td><sd>


    Parameters
    ----------
    source_str : str
        A source string of tokens.
    token_delim : str, default ' '
        A token delimiter.
    seq_delim : str, default '\\\\n'
        A sequence delimiter.
    to_lower : bool, default False
        Whether to convert tokens to lower case before counting.
    counter_to_update : collections.Counter or None, default None
        The collections.Counter instance to be updated with the token counts
        of `source_str`. If None, return a new collections.Counter instance
        counting tokens from `source_str`.


    Returns
    -------
    collections.Counter
        The `counter_to_update` collections.Counter instance after being
        updated with the token counts of `source_str`. If `counter_to_update`
        is None, return a new collections.Counter instance counting tokens
        from `source_str`.


    Examples
    --------
    >>> source_str = ' Life is great ! \\n life is good . \\n'
    >>> count_tokens_from_str(source_str, ' ', '\\n', True)
    Counter({'is': 2, 'life': 2, '!': 1, '.': 1, 'good': 1, 'great': 1})
    """

    # Both delimiters form one alternation pattern; empty pieces produced by
    # adjacent delimiters are dropped.
    pieces = (token for token in
              re.split(token_delim + '|' + seq_delim, source_str) if token)
    if to_lower:
        pieces = (token.lower() for token in pieces)

    counter = Counter() if counter_to_update is None else counter_to_update
    counter.update(pieces)
    return counter
diff --git a/python/mxnet/registry.py b/python/mxnet/registry.py
index 4a4f22f..4c131a1 100644
--- a/python/mxnet/registry.py
+++ b/python/mxnet/registry.py
@@ -29,6 +29,23 @@ from .base import string_types
_REGISTRY = {}
def get_registry(base_class):
    """Get a copy of the registry.

    Parameters
    ----------
    base_class : type
        base class for classes that will be registered.

    Returns
    -------
    dict
        A shallow copy of the mapping from registered nickname to class for
        `base_class`. Mutating the copy does not affect the registry itself.
    """
    # Ensure an (empty) entry exists so callers always get a dict back.
    if base_class not in _REGISTRY:
        _REGISTRY[base_class] = {}
    return _REGISTRY[base_class].copy()
+
+
def get_register_func(base_class, nickname):
"""Get registrator function.
diff --git a/tests/python/unittest/test_contrib_text.py b/tests/python/unittest/test_contrib_text.py
new file mode 100644
index 0000000..f666888
--- /dev/null
+++ b/tests/python/unittest/test_contrib_text.py
@@ -0,0 +1,727 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# 'License'); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+
+from __future__ import absolute_import
+from __future__ import print_function
+
+from collections import Counter
+import unittest
+
+from common import assertRaises
+from mxnet import ndarray as nd
+from mxnet.test_utils import *
+from mxnet.contrib.text import utils
+from mxnet.contrib.text.glossary import Glossary
+from mxnet.contrib.text.indexer import TokenIndexer
+from mxnet.contrib.text.embedding import TokenEmbedding, CustomEmbedding
+
+
+def _get_test_str_of_tokens(token_delim, seq_delim):
+    """Build a test string of three token sequences joined by the given
+    token and sequence delimiters (delimiters also pad each sequence)."""
+    seq1 = token_delim + token_delim.join(['Life', 'is', 'great', '!']) \
+           + token_delim + seq_delim
+    seq2 = token_delim + token_delim.join(['life', 'is', 'good', '.']) \
+           + token_delim + seq_delim
+    seq3 = token_delim + token_delim.join(['life', "isn't", 'bad', '.']) \
+           + token_delim + seq_delim
+    seqs = seq1 + seq2 + seq3
+    return seqs
+
+
+def _test_count_tokens_from_str_with_delims(token_delim, seq_delim):
+    """Check count_tokens_from_str with/without lower-casing and with an
+    existing counter to update, for the given pair of delimiters."""
+    source_str = _get_test_str_of_tokens(token_delim, seq_delim)
+
+    # Case-sensitive counting: 'Life' and 'life' are distinct tokens.
+    cnt1 = utils.count_tokens_from_str(source_str, token_delim, seq_delim,
+                                       to_lower=False)
+    assert cnt1 == Counter(
+        {'is': 2, 'life': 2, '.': 2, 'Life': 1, 'great': 1, '!': 1, 'good': 1,
+         "isn't": 1, 'bad': 1})
+
+    # Lower-casing merges 'Life' into 'life'.
+    cnt2 = utils.count_tokens_from_str(source_str, token_delim, seq_delim,
+                                       to_lower=True)
+    assert cnt2 == Counter(
+        {'life': 3, 'is': 2, '.': 2, 'great': 1, '!': 1, 'good': 1,
+         "isn't": 1, 'bad': 1})
+
+    counter_to_update = Counter({'life': 2})
+
+    # Updating a pre-seeded counter adds to the existing counts.
+    cnt3 = utils.count_tokens_from_str(
+        source_str, token_delim, seq_delim, to_lower=False,
+        counter_to_update=counter_to_update.copy())
+    assert cnt3 == Counter(
+        {'is': 2, 'life': 4, '.': 2, 'Life': 1, 'great': 1, '!': 1, 'good': 1,
+         "isn't": 1, 'bad': 1})
+
+    cnt4 = utils.count_tokens_from_str(
+        source_str, token_delim, seq_delim, to_lower=True,
+        counter_to_update=counter_to_update.copy())
+    assert cnt4 == Counter(
+        {'life': 5, 'is': 2, '.': 2, 'great': 1, '!': 1, 'good': 1,
+         "isn't": 1, 'bad': 1})
+
+
+def test_count_tokens_from_str():
+    """Run the delimiter-parameterized checks with single-char and
+    multi-char delimiters."""
+    _test_count_tokens_from_str_with_delims(' ', '\n')
+    _test_count_tokens_from_str_with_delims('IS', 'LIFE')
+
+
+def test_tokens_to_indices():
+    """TokenIndexer.to_indices: scalar in -> scalar out, list in -> list
+    out, and unknown tokens map to index 0."""
+    counter = Counter(['a', 'b', 'b', 'c', 'c', 'c', 'some_word$'])
+
+    indexer = TokenIndexer(counter, most_freq_count=None, min_freq=1,
+                           unknown_token='<unk>', reserved_tokens=None)
+
+    # Indices follow descending frequency after the unknown token at 0.
+    i1 = indexer.to_indices('c')
+    assert i1 == 1
+
+    i2 = indexer.to_indices(['c'])
+    assert i2 == [1]
+
+    i3 = indexer.to_indices(['<unk>', 'non-exist'])
+    assert i3 == [0, 0]
+
+    i4 = indexer.to_indices(['a', 'non-exist', 'a', 'b'])
+    assert i4 == [3, 0, 3, 2]
+
+
+def test_indices_to_tokens():
+    """TokenIndexer.to_tokens: inverse of to_indices; out-of-range index
+    raises ValueError."""
+    counter = Counter(['a', 'b', 'b', 'c', 'c', 'c', 'some_word$'])
+
+    indexer = TokenIndexer(counter, most_freq_count=None, min_freq=1,
+                           unknown_token='<unknown>', reserved_tokens=None)
+
+    i1 = indexer.to_tokens(1)
+    assert i1 == 'c'
+
+    i2 = indexer.to_tokens([1])
+    assert i2 == ['c']
+
+    # Index 0 is always the unknown token.
+    i3 = indexer.to_tokens([0, 0])
+    assert i3 == ['<unknown>', '<unknown>']
+
+    i4 = indexer.to_tokens([3, 0, 3, 2])
+    assert i4 == ['a', '<unknown>', 'a', 'b']
+
+    # An index beyond the vocabulary size must raise.
+    assertRaises(ValueError, indexer.to_tokens, 100)
+
+def test_download_embed():
+    """Register a custom TokenEmbedding subclass, create it via the
+    registry, and verify the downloaded test vectors are loaded.
+
+    NOTE(review): this test downloads a small pretrained file (sha1-pinned,
+    33 bytes) — it requires network access when the file is not cached.
+    """
+    @TokenEmbedding.register
+    class Test(TokenEmbedding):
+        # sha1 checksum pins the expected remote file content.
+        pretrained_file_name_sha1 = \
+            {'embedding_test.vec': '29b9a6511cf4b5aae293c44a9ec1365b74f2a2f8'} # 33 bytes
+        namespace = 'test'
+
+        def __init__(self, embedding_root='embeddings',
+                     init_unknown_vec=nd.zeros, **kwargs):
+            pretrained_file_name = 'embedding_test.vec'
+            Test._check_pretrained_file_names(pretrained_file_name)
+
+            super(Test, self).__init__(**kwargs)
+
+            pretrained_file_path = Test._get_pretrained_file(embedding_root,
+                                                             pretrained_file_name)
+
+            self._load_embedding(pretrained_file_path, ' ', init_unknown_vec)
+
+    # Instantiate through the registry by nickname rather than directly.
+    test_embed = TokenEmbedding.create('test')
+    assert test_embed.token_to_idx['hello'] == 1
+    assert test_embed.token_to_idx['world'] == 2
+    assert_almost_equal(test_embed.idx_to_vec[1].asnumpy(), (nd.arange(5) + 1).asnumpy())
+    assert_almost_equal(test_embed.idx_to_vec[2].asnumpy(), (nd.arange(5) + 6).asnumpy())
+    # Index 0 (unknown token) uses init_unknown_vec=nd.zeros.
+    assert_almost_equal(test_embed.idx_to_vec[0].asnumpy(), nd.zeros((5,)).asnumpy())
+
+
+
+def _mk_my_pretrain_file(path, token_delim, pretrain_file):
+    """Write a pretrained-embedding file with tokens 'a' and 'b', each
+    followed by a 5-dim vector, elements joined by `token_delim`."""
+    path = os.path.expanduser(path)
+    if not os.path.exists(path):
+        os.makedirs(path)
+    seq1 = token_delim.join(['a', '0.1', '0.2', '0.3', '0.4', '0.5']) + '\n'
+    seq2 = token_delim.join(['b', '0.6', '0.7', '0.8', '0.9', '1.0']) + '\n'
+    seqs = seq1 + seq2
+    with open(os.path.join(path, pretrain_file), 'w') as fout:
+        fout.write(seqs)
+
+
+def _mk_my_pretrain_file2(path, token_delim, pretrain_file):
+    """Write a second pretrained file with tokens 'a' and 'c' and distinct
+    5-dim vectors (partially overlapping vocabulary with file 1)."""
+    path = os.path.expanduser(path)
+    if not os.path.exists(path):
+        os.makedirs(path)
+    seq1 = token_delim.join(['a', '0.01', '0.02', '0.03', '0.04',
+                             '0.05']) + '\n'
+    seq2 = token_delim.join(['c', '0.06', '0.07', '0.08', '0.09', '0.1']) + '\n'
+    seqs = seq1 + seq2
+    with open(os.path.join(path, pretrain_file), 'w') as fout:
+        fout.write(seqs)
+
+
+def _mk_my_pretrain_file3(path, token_delim, pretrain_file):
+    """Like _mk_my_pretrain_file but with an extra '<unk1>' row, used to
+    test loading an unknown-token vector from the file itself."""
+    path = os.path.expanduser(path)
+    if not os.path.exists(path):
+        os.makedirs(path)
+    seq1 = token_delim.join(['a', '0.1', '0.2', '0.3', '0.4', '0.5']) + '\n'
+    seq2 = token_delim.join(['b', '0.6', '0.7', '0.8', '0.9', '1.0']) + '\n'
+    seq3 = token_delim.join(['<unk1>', '1.1', '1.2', '1.3', '1.4',
+                             '1.5']) + '\n'
+    seqs = seq1 + seq2 + seq3
+    with open(os.path.join(path, pretrain_file), 'w') as fout:
+        fout.write(seqs)
+
+
+def _mk_my_pretrain_file4(path, token_delim, pretrain_file):
+    """Like _mk_my_pretrain_file2 but with an extra '<unk2>' row, used to
+    test loading an unknown-token vector from the file itself."""
+    path = os.path.expanduser(path)
+    if not os.path.exists(path):
+        os.makedirs(path)
+    seq1 = token_delim.join(['a', '0.01', '0.02', '0.03', '0.04',
+                             '0.05']) + '\n'
+    seq2 = token_delim.join(['c', '0.06', '0.07', '0.08', '0.09',
+                             '0.1']) + '\n'
+    seq3 = token_delim.join(['<unk2>', '0.11', '0.12', '0.13', '0.14',
+                             '0.15']) + '\n'
+    seqs = seq1 + seq2 + seq3
+    with open(os.path.join(path, pretrain_file), 'w') as fout:
+        fout.write(seqs)
+
+
+def _mk_my_invalid_pretrain_file(path, token_delim, pretrain_file):
+    """Write an invalid pretrained file whose last row ('c') has a token
+    but no vector elements at all."""
+    path = os.path.expanduser(path)
+    if not os.path.exists(path):
+        os.makedirs(path)
+    seq1 = token_delim.join(['a', '0.1', '0.2', '0.3', '0.4', '0.5']) + '\n'
+    seq2 = token_delim.join(['b', '0.6', '0.7', '0.8', '0.9', '1.0']) + '\n'
+    seq3 = token_delim.join(['c']) + '\n'
+    seqs = seq1 + seq2 + seq3
+    with open(os.path.join(path, pretrain_file), 'w') as fout:
+        fout.write(seqs)
+
+
+def _mk_my_invalid_pretrain_file2(path, token_delim, pretrain_file):
+    """Write an invalid pretrained file whose last row ('c') has fewer
+    vector elements (3) than the other rows (5)."""
+    path = os.path.expanduser(path)
+    if not os.path.exists(path):
+        os.makedirs(path)
+    seq1 = token_delim.join(['a', '0.1', '0.2', '0.3', '0.4', '0.5']) + '\n'
+    seq2 = token_delim.join(['b', '0.6', '0.7', '0.8', '0.9', '1.0']) + '\n'
+    seq3 = token_delim.join(['c', '0.6', '0.7', '0.8']) + '\n'
+    seqs = seq1 + seq2 + seq3
+    with open(os.path.join(path, pretrain_file), 'w') as fout:
+        fout.write(seqs)
+
+
+def test_custom_embed():
+    """CustomEmbedding: loads vectors from a user file, falls back to the
+    unknown vector (with optional lower-case backup), supports a loaded
+    unknown token, and rejects malformed files.
+
+    NOTE(review): the '<un...@unk>' literals below look mangled by the
+    mailing-list archiver — confirm the original token against the repo.
+    """
+    embed_root = 'embeddings'
+    embed_name = 'my_embed'
+    elem_delim = '\t'
+    pretrain_file = 'my_pretrain_file.txt'
+
+    _mk_my_pretrain_file(os.path.join(embed_root, embed_name), elem_delim,
+                         pretrain_file)
+
+    pretrain_file_path = os.path.join(embed_root, embed_name, pretrain_file)
+
+    my_embed = CustomEmbedding(pretrain_file_path, elem_delim)
+
+    # 2 file tokens + the implicit unknown token at index 0.
+    assert len(my_embed) == 3
+    assert my_embed.vec_len == 5
+    assert my_embed.token_to_idx['a'] == 1
+    assert my_embed.idx_to_token[1] == 'a'
+
+    # Default init_unknown_vec is zeros.
+    first_vec = my_embed.idx_to_vec[0]
+    assert_almost_equal(first_vec.asnumpy(), np.array([0, 0, 0, 0, 0]))
+
+    unk_vec = my_embed.get_vecs_by_tokens('A')
+    assert_almost_equal(unk_vec.asnumpy(), np.array([0, 0, 0, 0, 0]))
+
+    # lower_case_backup retries with 'a' when 'A' is missing.
+    a_vec = my_embed.get_vecs_by_tokens('A', lower_case_backup=True)
+    assert_almost_equal(a_vec.asnumpy(), np.array([0.1, 0.2, 0.3, 0.4, 0.5]))
+
+    unk_vecs = my_embed.get_vecs_by_tokens(['<un...@unk>', '<un...@unk>'])
+    assert_almost_equal(unk_vecs.asnumpy(),
+                        np.array([[0, 0, 0, 0, 0],
+                                  [0, 0, 0, 0, 0]]))
+
+    # Test loaded unknown vectors.
+    pretrain_file2 = 'my_pretrain_file2.txt'
+    _mk_my_pretrain_file3(os.path.join(embed_root, embed_name), elem_delim,
+                          pretrain_file2)
+    pretrain_file_path = os.path.join(embed_root, embed_name, pretrain_file2)
+    my_embed2 = CustomEmbedding(pretrain_file_path, elem_delim,
+                                init_unknown_vec=nd.ones,
+                                unknown_token='<unk>')
+    # '<unk>' is not in the file, so init_unknown_vec (ones) is used.
+    unk_vec2 = my_embed2.get_vecs_by_tokens('<unk>')
+    assert_almost_equal(unk_vec2.asnumpy(), np.array([1, 1, 1, 1, 1]))
+    unk_vec2 = my_embed2.get_vecs_by_tokens('<un...@unk>')
+    assert_almost_equal(unk_vec2.asnumpy(), np.array([1, 1, 1, 1, 1]))
+
+    my_embed3 = CustomEmbedding(pretrain_file_path, elem_delim,
+                                init_unknown_vec=nd.ones,
+                                unknown_token='<unk1>')
+    # '<unk1>' IS in the file, so its loaded vector wins over ones.
+    unk_vec3 = my_embed3.get_vecs_by_tokens('<unk1>')
+    assert_almost_equal(unk_vec3.asnumpy(), np.array([1.1, 1.2, 1.3, 1.4, 1.5]))
+    unk_vec3 = my_embed3.get_vecs_by_tokens('<un...@unk>')
+    assert_almost_equal(unk_vec3.asnumpy(), np.array([1.1, 1.2, 1.3, 1.4, 1.5]))
+
+    # Test error handling.
+    invalid_pretrain_file = 'invalid_pretrain_file.txt'
+    _mk_my_invalid_pretrain_file(os.path.join(embed_root, embed_name),
+                                 elem_delim, invalid_pretrain_file)
+    pretrain_file_path = os.path.join(embed_root, embed_name,
+                                      invalid_pretrain_file)
+    assertRaises(AssertionError, CustomEmbedding, pretrain_file_path,
+                 elem_delim)
+
+    invalid_pretrain_file2 = 'invalid_pretrain_file2.txt'
+    _mk_my_invalid_pretrain_file2(os.path.join(embed_root, embed_name),
+                                  elem_delim, invalid_pretrain_file2)
+    pretrain_file_path = os.path.join(embed_root, embed_name,
+                                      invalid_pretrain_file2)
+    assertRaises(AssertionError, CustomEmbedding, pretrain_file_path,
+                 elem_delim)
+
+
+def test_token_indexer():
+    """TokenIndexer construction: most_freq_count / min_freq pruning,
+    reserved tokens (which always survive pruning and come right after the
+    unknown token), invalid-argument checks, and hashable non-str tokens."""
+    counter = Counter(['a', 'b', 'b', 'c', 'c', 'c', 'some_word$'])
+
+    g1 = TokenIndexer(counter, most_freq_count=None, min_freq=1,
+                      unknown_token='<unk>', reserved_tokens=None)
+    assert len(g1) == 5
+    assert g1.token_to_idx == {'<unk>': 0, 'c': 1, 'b': 2, 'a': 3,
+                               'some_word$': 4}
+    assert g1.idx_to_token[1] == 'c'
+    assert g1.unknown_token == '<unk>'
+    assert g1.reserved_tokens is None
+
+    # min_freq=2 drops 'a' and 'some_word$'.
+    g2 = TokenIndexer(counter, most_freq_count=None, min_freq=2,
+                      unknown_token='<unk>', reserved_tokens=None)
+    assert len(g2) == 3
+    assert g2.token_to_idx == {'<unk>': 0, 'c': 1, 'b': 2}
+    assert g2.idx_to_token[1] == 'c'
+    assert g2.unknown_token == '<unk>'
+    assert g2.reserved_tokens is None
+
+    # min_freq too high: only the unknown token remains.
+    g3 = TokenIndexer(counter, most_freq_count=None, min_freq=100,
+                      unknown_token='<unk>', reserved_tokens=None)
+    assert len(g3) == 1
+    assert g3.token_to_idx == {'<unk>': 0}
+    assert g3.idx_to_token[0] == '<unk>'
+    assert g3.unknown_token == '<unk>'
+    assert g3.reserved_tokens is None
+
+    # most_freq_count caps the indexed tokens (unknown token not counted).
+    g4 = TokenIndexer(counter, most_freq_count=2, min_freq=1,
+                      unknown_token='<unk>', reserved_tokens=None)
+    assert len(g4) == 3
+    assert g4.token_to_idx == {'<unk>': 0, 'c': 1, 'b': 2}
+    assert g4.idx_to_token[1] == 'c'
+    assert g4.unknown_token == '<unk>'
+    assert g4.reserved_tokens is None
+
+    g5 = TokenIndexer(counter, most_freq_count=3, min_freq=1,
+                      unknown_token='<unk>', reserved_tokens=None)
+    assert len(g5) == 4
+    assert g5.token_to_idx == {'<unk>': 0, 'c': 1, 'b': 2, 'a': 3}
+    assert g5.idx_to_token[1] == 'c'
+    assert g5.unknown_token == '<unk>'
+    assert g5.reserved_tokens is None
+
+    g6 = TokenIndexer(counter, most_freq_count=100, min_freq=1,
+                      unknown_token='<unk>', reserved_tokens=None)
+    assert len(g6) == 5
+    assert g6.token_to_idx == {'<unk>': 0, 'c': 1, 'b': 2, 'a': 3,
+                               'some_word$': 4}
+    assert g6.idx_to_token[1] == 'c'
+    assert g6.unknown_token == '<unk>'
+    assert g6.reserved_tokens is None
+
+    # Both filters apply together.
+    g7 = TokenIndexer(counter, most_freq_count=1, min_freq=2,
+                      unknown_token='<unk>', reserved_tokens=None)
+    assert len(g7) == 2
+    assert g7.token_to_idx == {'<unk>': 0, 'c': 1}
+    assert g7.idx_to_token[1] == 'c'
+    assert g7.unknown_token == '<unk>'
+    assert g7.reserved_tokens is None
+
+    # min_freq must be positive.
+    assertRaises(AssertionError, TokenIndexer, counter, most_freq_count=None,
+                 min_freq=0, unknown_token='<unknown>',
+                 reserved_tokens=['b'])
+
+    # Reserved tokens must be unique.
+    assertRaises(AssertionError, TokenIndexer, counter, most_freq_count=None,
+                 min_freq=1, unknown_token='<unknown>',
+                 reserved_tokens=['b', 'b'])
+
+    # The unknown token must not also be reserved.
+    assertRaises(AssertionError, TokenIndexer, counter, most_freq_count=None,
+                 min_freq=1, unknown_token='<unknown>',
+                 reserved_tokens=['b', '<unknown>'])
+
+    # Reserved tokens are indexed right after the unknown token.
+    g8 = TokenIndexer(counter, most_freq_count=None, min_freq=1,
+                      unknown_token='<unknown>', reserved_tokens=['b'])
+    assert len(g8) == 5
+    assert g8.token_to_idx == {'<unknown>': 0, 'b': 1, 'c': 2, 'a': 3,
+                               'some_word$': 4}
+    assert g8.idx_to_token[1] == 'b'
+    assert g8.unknown_token == '<unknown>'
+    assert g8.reserved_tokens == ['b']
+
+    # Reserved 'a' is kept even though its count is below min_freq.
+    g9 = TokenIndexer(counter, most_freq_count=None, min_freq=2,
+                      unknown_token='<unk>', reserved_tokens=['b', 'a'])
+    assert len(g9) == 4
+    assert g9.token_to_idx == {'<unk>': 0, 'b': 1, 'a': 2, 'c': 3}
+    assert g9.idx_to_token[1] == 'b'
+    assert g9.unknown_token == '<unk>'
+    assert g9.reserved_tokens == ['b', 'a']
+
+    g10 = TokenIndexer(counter, most_freq_count=None, min_freq=100,
+                       unknown_token='<unk>', reserved_tokens=['b', 'c'])
+    assert len(g10) == 3
+    assert g10.token_to_idx == {'<unk>': 0, 'b': 1, 'c': 2}
+    assert g10.idx_to_token[1] == 'b'
+    assert g10.unknown_token == '<unk>'
+    assert g10.reserved_tokens == ['b', 'c']
+
+    g11 = TokenIndexer(counter, most_freq_count=1, min_freq=2,
+                       unknown_token='<unk>', reserved_tokens=['<pad>', 'b'])
+    assert len(g11) == 4
+    assert g11.token_to_idx == {'<unk>': 0, '<pad>': 1, 'b': 2, 'c': 3}
+    assert g11.idx_to_token[1] == '<pad>'
+    assert g11.unknown_token == '<unk>'
+    assert g11.reserved_tokens == ['<pad>', 'b']
+
+    # A counted token may itself serve as the unknown token.
+    g12 = TokenIndexer(counter, most_freq_count=None, min_freq=2,
+                       unknown_token='b', reserved_tokens=['<pad>'])
+    assert len(g12) == 3
+    assert g12.token_to_idx == {'b': 0, '<pad>': 1, 'c': 2}
+    assert g12.idx_to_token[1] == '<pad>'
+    assert g12.unknown_token == 'b'
+    assert g12.reserved_tokens == ['<pad>']
+
+    g13 = TokenIndexer(counter, most_freq_count=None, min_freq=2,
+                       unknown_token='a', reserved_tokens=['<pad>'])
+    assert len(g13) == 4
+    assert g13.token_to_idx == {'a': 0, '<pad>': 1, 'c': 2, 'b': 3}
+    assert g13.idx_to_token[1] == '<pad>'
+    assert g13.unknown_token == 'a'
+    assert g13.reserved_tokens == ['<pad>']
+
+    # Tokens need only be hashable, not str — tuples work too.
+    counter_tuple = Counter([('a', 'a'), ('b', 'b'), ('b', 'b'),
+                             ('c', 'c'), ('c', 'c'), ('c', 'c'),
+                             ('some_word$', 'some_word$')])
+
+    g14 = TokenIndexer(counter_tuple, most_freq_count=None, min_freq=1,
+                       unknown_token=('<unk>', '<unk>'), reserved_tokens=None)
+    assert len(g14) == 5
+    assert g14.token_to_idx == {('<unk>', '<unk>'): 0, ('c', 'c'): 1,
+                                ('b', 'b'): 2, ('a', 'a'): 3,
+                                ('some_word$', 'some_word$'): 4}
+    assert g14.idx_to_token[1] == ('c', 'c')
+    assert g14.unknown_token == ('<unk>', '<unk>')
+    assert g14.reserved_tokens is None
+
+
+def test_glossary_with_one_embed():
+    """Glossary backed by a single CustomEmbedding: vector lookup (with
+    lower-case backup), in-place vector updates via update_token_vectors,
+    and validation of bad update arguments."""
+    embed_root = 'embeddings'
+    embed_name = 'my_embed'
+    elem_delim = '\t'
+    pretrain_file = 'my_pretrain_file1.txt'
+
+    _mk_my_pretrain_file(os.path.join(embed_root, embed_name), elem_delim,
+                         pretrain_file)
+
+    pretrain_file_path = os.path.join(embed_root, embed_name, pretrain_file)
+
+    my_embed = CustomEmbedding(pretrain_file_path, elem_delim,
+                               init_unknown_vec=nd.ones)
+
+    counter = Counter(['a', 'b', 'b', 'c', 'c', 'c', 'some_word$'])
+
+    g1 = Glossary(counter, my_embed, most_freq_count=None, min_freq=1,
+                  unknown_token='<unk>', reserved_tokens=['<pad>'])
+
+    assert g1.token_to_idx == {'<unk>': 0, '<pad>': 1, 'c': 2, 'b': 3, 'a': 4,
+                               'some_word$': 5}
+    assert g1.idx_to_token == ['<unk>', '<pad>', 'c', 'b', 'a', 'some_word$']
+
+    # Tokens absent from the file get the unknown vector (ones here).
+    assert_almost_equal(g1.idx_to_vec.asnumpy(),
+                        np.array([[1, 1, 1, 1, 1],
+                                  [1, 1, 1, 1, 1],
+                                  [1, 1, 1, 1, 1],
+                                  [0.6, 0.7, 0.8, 0.9, 1],
+                                  [0.1, 0.2, 0.3, 0.4, 0.5],
+                                  [1, 1, 1, 1, 1]])
+                        )
+
+    assert g1.vec_len == 5
+    assert g1.reserved_tokens == ['<pad>']
+
+    assert_almost_equal(g1.get_vecs_by_tokens('c').asnumpy(),
+                        np.array([1, 1, 1, 1, 1])
+                        )
+
+    assert_almost_equal(g1.get_vecs_by_tokens(['c']).asnumpy(),
+                        np.array([[1, 1, 1, 1, 1]])
+                        )
+
+    assert_almost_equal(g1.get_vecs_by_tokens(['a', 'not_exist']).asnumpy(),
+                        np.array([[0.1, 0.2, 0.3, 0.4, 0.5],
+                                  [1, 1, 1, 1, 1]])
+                        )
+
+    assert_almost_equal(g1.get_vecs_by_tokens(['a', 'b']).asnumpy(),
+                        np.array([[0.1, 0.2, 0.3, 0.4, 0.5],
+                                  [0.6, 0.7, 0.8, 0.9, 1]])
+                        )
+
+    # Without lower_case_backup, 'A' resolves to the unknown vector.
+    assert_almost_equal(g1.get_vecs_by_tokens(['A', 'b']).asnumpy(),
+                        np.array([[1, 1, 1, 1, 1],
+                                  [0.6, 0.7, 0.8, 0.9, 1]])
+                        )
+
+    assert_almost_equal(g1.get_vecs_by_tokens(['A', 'b'],
+                                              lower_case_backup=True).asnumpy(),
+                        np.array([[0.1, 0.2, 0.3, 0.4, 0.5],
+                                  [0.6, 0.7, 0.8, 0.9, 1]])
+                        )
+
+    g1.update_token_vectors(['a', 'b'],
+                            nd.array([[2, 2, 2, 2, 2],
+                                      [3, 3, 3, 3, 3]])
+                            )
+
+    assert_almost_equal(g1.idx_to_vec.asnumpy(),
+                        np.array([[1, 1, 1, 1, 1],
+                                  [1, 1, 1, 1, 1],
+                                  [1, 1, 1, 1, 1],
+                                  [3, 3, 3, 3, 3],
+                                  [2, 2, 2, 2, 2],
+                                  [1, 1, 1, 1, 1]])
+                        )
+
+    # Updating an out-of-vocabulary token must raise.
+    assertRaises(ValueError, g1.update_token_vectors, 'unknown$$$',
+                 nd.array([0, 0, 0, 0, 0]))
+
+    # Shape mismatches between tokens and new vectors must raise.
+    assertRaises(AssertionError, g1.update_token_vectors, '<unk>',
+                 nd.array([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]))
+
+    assertRaises(AssertionError, g1.update_token_vectors, '<unk>',
+                 nd.array([0]))
+
+    # Single token, 1-D and 2-D new vectors both accepted; same for a
+    # bare (non-list) token argument below.
+    g1.update_token_vectors(['<unk>'],
+                            nd.array([0, 0, 0, 0, 0])
+                            )
+    assert_almost_equal(g1.idx_to_vec.asnumpy(),
+                        np.array([[0, 0, 0, 0, 0],
+                                  [1, 1, 1, 1, 1],
+                                  [1, 1, 1, 1, 1],
+                                  [3, 3, 3, 3, 3],
+                                  [2, 2, 2, 2, 2],
+                                  [1, 1, 1, 1, 1]])
+                        )
+    g1.update_token_vectors(['<unk>'],
+                            nd.array([[10, 10, 10, 10, 10]])
+                            )
+    assert_almost_equal(g1.idx_to_vec.asnumpy(),
+                        np.array([[10, 10, 10, 10, 10],
+                                  [1, 1, 1, 1, 1],
+                                  [1, 1, 1, 1, 1],
+                                  [3, 3, 3, 3, 3],
+                                  [2, 2, 2, 2, 2],
+                                  [1, 1, 1, 1, 1]])
+                        )
+    g1.update_token_vectors('<unk>',
+                            nd.array([0, 0, 0, 0, 0])
+                            )
+    assert_almost_equal(g1.idx_to_vec.asnumpy(),
+                        np.array([[0, 0, 0, 0, 0],
+                                  [1, 1, 1, 1, 1],
+                                  [1, 1, 1, 1, 1],
+                                  [3, 3, 3, 3, 3],
+                                  [2, 2, 2, 2, 2],
+                                  [1, 1, 1, 1, 1]])
+                        )
+    g1.update_token_vectors('<unk>',
+                            nd.array([[10, 10, 10, 10, 10]])
+                            )
+    assert_almost_equal(g1.idx_to_vec.asnumpy(),
+                        np.array([[10, 10, 10, 10, 10],
+                                  [1, 1, 1, 1, 1],
+                                  [1, 1, 1, 1, 1],
+                                  [3, 3, 3, 3, 3],
+                                  [2, 2, 2, 2, 2],
+                                  [1, 1, 1, 1, 1]])
+                        )
+
+
+def test_glossary_with_two_embeds():
+    """Glossary backed by two CustomEmbeddings: per-token vectors are the
+    concatenation of both embeddings' vectors (vec_len 5 + 5 = 10), and
+    loaded unknown-token vectors fill in for missing tokens."""
+    embed_root = '.'
+    embed_name = 'my_embed'
+    elem_delim = '\t'
+    pretrain_file1 = 'my_pretrain_file1.txt'
+    pretrain_file2 = 'my_pretrain_file2.txt'
+
+    _mk_my_pretrain_file(os.path.join(embed_root, embed_name), elem_delim,
+                         pretrain_file1)
+    _mk_my_pretrain_file2(os.path.join(embed_root, embed_name), elem_delim,
+                          pretrain_file2)
+
+    pretrain_file_path1 = os.path.join(embed_root, embed_name, pretrain_file1)
+    pretrain_file_path2 = os.path.join(embed_root, embed_name, pretrain_file2)
+
+    my_embed1 = CustomEmbedding(pretrain_file_path1, elem_delim,
+                                init_unknown_vec=nd.ones)
+    my_embed2 = CustomEmbedding(pretrain_file_path2, elem_delim)
+
+    counter = Counter(['a', 'b', 'b', 'c', 'c', 'c', 'some_word$'])
+
+    g1 = Glossary(counter, [my_embed1, my_embed2], most_freq_count=None,
+                  min_freq=1, unknown_token='<unk>', reserved_tokens=None)
+
+    assert g1.token_to_idx == {'<unk>': 0, 'c': 1, 'b': 2, 'a': 3,
+                               'some_word$': 4}
+    assert g1.idx_to_token == ['<unk>', 'c', 'b', 'a', 'some_word$']
+
+    # First 5 dims come from my_embed1 (unknown = ones), last 5 from
+    # my_embed2 (unknown = zeros, the default).
+    assert_almost_equal(g1.idx_to_vec.asnumpy(),
+                        np.array([[1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
+                                  [1, 1, 1, 1, 1, 0.06, 0.07, 0.08, 0.09, 0.1],
+                                  [0.6, 0.7, 0.8, 0.9, 1, 0, 0, 0, 0, 0],
+                                  [0.1, 0.2, 0.3, 0.4, 0.5,
+                                   0.01, 0.02, 0.03, 0.04, 0.05],
+                                  [1, 1, 1, 1, 1, 0, 0, 0, 0, 0]])
+                        )
+
+    assert g1.vec_len == 10
+    assert g1.reserved_tokens is None
+    assert_almost_equal(g1.get_vecs_by_tokens('c').asnumpy(),
+                        np.array([1, 1, 1, 1, 1, 0.06, 0.07, 0.08, 0.09, 0.1])
+                        )
+
+    assert_almost_equal(g1.get_vecs_by_tokens(['b', 'not_exist']).asnumpy(),
+                        np.array([[0.6, 0.7, 0.8, 0.9, 1, 0, 0, 0, 0, 0],
+                                  [1, 1, 1, 1, 1, 0, 0, 0, 0, 0]])
+                        )
+
+    # Updates span the full concatenated 10-dim vectors.
+    g1.update_token_vectors(['a', 'b'],
+                            nd.array([[2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
+                                      [3, 3, 3, 3, 3, 3, 3, 3, 3, 3]])
+                            )
+    assert_almost_equal(g1.idx_to_vec.asnumpy(),
+                        np.array([[1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
+                                  [1, 1, 1, 1, 1, 0.06, 0.07, 0.08, 0.09, 0.1],
+                                  [3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
+                                  [2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
+                                  [1, 1, 1, 1, 1, 0, 0, 0, 0, 0]])
+                        )
+
+    # Test loaded unknown tokens
+    pretrain_file3 = 'my_pretrain_file3.txt'
+    pretrain_file4 = 'my_pretrain_file4.txt'
+
+    _mk_my_pretrain_file3(os.path.join(embed_root, embed_name), elem_delim,
+                          pretrain_file3)
+    _mk_my_pretrain_file4(os.path.join(embed_root, embed_name), elem_delim,
+                          pretrain_file4)
+
+    pretrain_file_path3 = os.path.join(embed_root, embed_name, pretrain_file3)
+    pretrain_file_path4 = os.path.join(embed_root, embed_name, pretrain_file4)
+
+    my_embed3 = CustomEmbedding(pretrain_file_path3, elem_delim,
+                                init_unknown_vec=nd.ones,
+                                unknown_token='<unk1>')
+    my_embed4 = CustomEmbedding(pretrain_file_path4, elem_delim,
+                                unknown_token='<unk2>')
+
+    # Whatever the glossary's unknown_token is named, missing tokens get
+    # each embedding's loaded unknown vector — g2/g3/g4 are identical.
+    g2 = Glossary(counter, [my_embed3, my_embed4], most_freq_count=None,
+                  min_freq=1, unknown_token='<unk>', reserved_tokens=None)
+    assert_almost_equal(g2.idx_to_vec.asnumpy(),
+                        np.array([[1.1, 1.2, 1.3, 1.4, 1.5,
+                                   0.11, 0.12, 0.13, 0.14, 0.15],
+                                  [1.1, 1.2, 1.3, 1.4, 1.5,
+                                   0.06, 0.07, 0.08, 0.09, 0.1],
+                                  [0.6, 0.7, 0.8, 0.9, 1,
+                                   0.11, 0.12, 0.13, 0.14, 0.15],
+                                  [0.1, 0.2, 0.3, 0.4, 0.5,
+                                   0.01, 0.02, 0.03, 0.04, 0.05],
+                                  [1.1, 1.2, 1.3, 1.4, 1.5,
+                                   0.11, 0.12, 0.13, 0.14, 0.15]])
+                        )
+
+    g3 = Glossary(counter, [my_embed3, my_embed4], most_freq_count=None,
+                  min_freq=1, unknown_token='<unk1>', reserved_tokens=None)
+    assert_almost_equal(g3.idx_to_vec.asnumpy(),
+                        np.array([[1.1, 1.2, 1.3, 1.4, 1.5,
+                                   0.11, 0.12, 0.13, 0.14, 0.15],
+                                  [1.1, 1.2, 1.3, 1.4, 1.5,
+                                   0.06, 0.07, 0.08, 0.09, 0.1],
+                                  [0.6, 0.7, 0.8, 0.9, 1,
+                                   0.11, 0.12, 0.13, 0.14, 0.15],
+                                  [0.1, 0.2, 0.3, 0.4, 0.5,
+                                   0.01, 0.02, 0.03, 0.04, 0.05],
+                                  [1.1, 1.2, 1.3, 1.4, 1.5,
+                                   0.11, 0.12, 0.13, 0.14, 0.15]])
+                        )
+
+    g4 = Glossary(counter, [my_embed3, my_embed4], most_freq_count=None,
+                  min_freq=1, unknown_token='<unk2>', reserved_tokens=None)
+    assert_almost_equal(g4.idx_to_vec.asnumpy(),
+                        np.array([[1.1, 1.2, 1.3, 1.4, 1.5,
+                                   0.11, 0.12, 0.13, 0.14, 0.15],
+                                  [1.1, 1.2, 1.3, 1.4, 1.5,
+                                   0.06, 0.07, 0.08, 0.09, 0.1],
+                                  [0.6, 0.7, 0.8, 0.9, 1,
+                                   0.11, 0.12, 0.13, 0.14, 0.15],
+                                  [0.1, 0.2, 0.3, 0.4, 0.5,
+                                   0.01, 0.02, 0.03, 0.04, 0.05],
+                                  [1.1, 1.2, 1.3, 1.4, 1.5,
+                                   0.11, 0.12, 0.13, 0.14, 0.15]])
+                        )
+
+    # Using a counted token ('a') as the unknown token.
+    counter2 = Counter(['b', 'b', 'c', 'c', 'c', 'some_word$'])
+
+    g5 = Glossary(counter2, [my_embed3, my_embed4], most_freq_count=None,
+                  min_freq=1, unknown_token='a', reserved_tokens=None)
+    assert g5.token_to_idx == {'a': 0, 'c': 1, 'b': 2, 'some_word$': 3}
+    assert g5.idx_to_token == ['a', 'c', 'b', 'some_word$']
+    assert_almost_equal(g5.idx_to_vec.asnumpy(),
+                        np.array([[1.1, 1.2, 1.3, 1.4, 1.5,
+                                   0.11, 0.12, 0.13, 0.14, 0.15],
+                                  [1.1, 1.2, 1.3, 1.4, 1.5,
+                                   0.06, 0.07, 0.08, 0.09, 0.1],
+                                  [0.6, 0.7, 0.8, 0.9, 1,
+                                   0.11, 0.12, 0.13, 0.14, 0.15],
+                                  [1.1, 1.2, 1.3, 1.4, 1.5,
+                                   0.11, 0.12, 0.13, 0.14, 0.15]])
+                        )
+
+
+def test_get_embedding_names_and_pretrain_files():
+    """Registry queries: pretrained-file counts for known embedding names,
+    the full registry dict when no name is given, and KeyError for an
+    unknown name."""
+    assert len(TokenEmbedding.get_embedding_and_pretrained_file_names(
+        embedding_name='fasttext')) == 294
+
+    assert len(TokenEmbedding.get_embedding_and_pretrained_file_names(
+        embedding_name='glove')) == 10
+
+    # embedding_name=None returns a dict keyed by embedding name.
+    reg = TokenEmbedding.get_embedding_and_pretrained_file_names(
+        embedding_name=None)
+
+    assert len(reg['glove']) == 10
+    assert len(reg['fasttext']) == 294
+
+    assertRaises(KeyError,
+                 TokenEmbedding.get_embedding_and_pretrained_file_names,
+                 'unknown$$')
+
+
+# Run the whole module under the nose test runner when executed directly.
+if __name__ == '__main__':
+    import nose
+    nose.runmodule()
--
To stop receiving notification emails like this one, please contact
['"commits@mxnet.apache.org" <co...@mxnet.apache.org>'].