You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2018/06/17 22:24:44 UTC

[incubator-mxnet] branch master updated: allow user to define unknown token symbol (#10461)

This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 0910450  allow user to define unknown token symbol (#10461)
0910450 is described below

commit 0910450110c37da9f052f3b29c40c6d051f46a6a
Author: Cong <18...@users.noreply.github.com>
AuthorDate: Sun Jun 17 23:24:38 2018 +0100

    allow user to define unknown token symbol (#10461)
    
    test case added
---
 python/mxnet/rnn/io.py            | 10 ++++++++--
 tests/python/unittest/test_rnn.py | 10 +++++++++-
 2 files changed, 17 insertions(+), 3 deletions(-)

diff --git a/python/mxnet/rnn/io.py b/python/mxnet/rnn/io.py
index 8eba9d2..a8890c9 100644
--- a/python/mxnet/rnn/io.py
+++ b/python/mxnet/rnn/io.py
@@ -27,7 +27,8 @@ import numpy as np
 from ..io import DataIter, DataBatch, DataDesc
 from .. import ndarray
 
-def encode_sentences(sentences, vocab=None, invalid_label=-1, invalid_key='\n', start_label=0):
+def encode_sentences(sentences, vocab=None, invalid_label=-1, invalid_key='\n',
+                     start_label=0, unknown_token=None):
     """Encode sentences and (optionally) build a mapping
     from string tokens to integer indices. Unknown keys
     will be added to vocabulary.
@@ -46,6 +47,9 @@ def encode_sentences(sentences, vocab=None, invalid_label=-1, invalid_key='\n',
         of sentence by default.
     start_label : int
         lowest index.
+    unknown_token: str
+        Symbol to represent unknown token.
+        If not specified, unknown token will be skipped.
 
     Returns
     -------
@@ -65,9 +69,11 @@ def encode_sentences(sentences, vocab=None, invalid_label=-1, invalid_key='\n',
         coded = []
         for word in sent:
             if word not in vocab:
-                assert new_vocab, "Unknown token %s"%word
+                assert (new_vocab or unknown_token), "Unknown token %s"%word
                 if idx == invalid_label:
                     idx += 1
+                if unknown_token:
+                    word = unknown_token
                 vocab[word] = idx
                 idx += 1
             coded.append(vocab[word])
diff --git a/tests/python/unittest/test_rnn.py b/tests/python/unittest/test_rnn.py
index 52a3dcf..a558825 100644
--- a/tests/python/unittest/test_rnn.py
+++ b/tests/python/unittest/test_rnn.py
@@ -300,7 +300,15 @@ def test_convgru():
     args, outs, auxs = outputs.infer_shape(rnn_t0_data=(1, 3, 16, 10), rnn_t1_data=(1, 3, 16, 10), rnn_t2_data=(1, 3, 16, 10))
     assert outs == [(1, 10, 16, 10), (1, 10, 16, 10), (1, 10, 16, 10)]
 
+def test_encode_sentences():
+    sentences = [['a','b','c'],['b','c','d']]
+    dict = {'a':1, 'b':2, 'c':3}
+    result, vocab = mx.rnn.io.encode_sentences(sentences, vocab=dict, invalid_label=-1, invalid_key='\n',
+                         start_label=0, unknown_token='UNK')
+    print(result, vocab)
+    assert vocab == {'a': 1, 'b': 2, 'c': 3, 'UNK': 0}
+    assert result == [[1,2,3],[2,3,0]]
+    
 if __name__ == '__main__':
     import nose
     nose.runmodule()
-

-- 
To stop receiving notification emails like this one, please contact
zhasheng@apache.org.