You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/06/17 22:24:39 UTC

[GitHub] szha closed pull request #10461: allow user to define unknown token symbol

szha closed pull request #10461: allow user to define unknown token symbol
URL: https://github.com/apache/incubator-mxnet/pull/10461
 
 
   

This PR was merged from a forked repository. Because GitHub hides the
original diff of a foreign (fork) pull request once it is merged, the
diff is reproduced below for the sake of provenance:

diff --git a/python/mxnet/rnn/io.py b/python/mxnet/rnn/io.py
index 8eba9d21e39..a8890c9e7e7 100644
--- a/python/mxnet/rnn/io.py
+++ b/python/mxnet/rnn/io.py
@@ -27,7 +27,8 @@
 from ..io import DataIter, DataBatch, DataDesc
 from .. import ndarray
 
-def encode_sentences(sentences, vocab=None, invalid_label=-1, invalid_key='\n', start_label=0):
+def encode_sentences(sentences, vocab=None, invalid_label=-1, invalid_key='\n',
+                     start_label=0, unknown_token=None):
     """Encode sentences and (optionally) build a mapping
     from string tokens to integer indices. Unknown keys
     will be added to vocabulary.
@@ -46,6 +47,9 @@ def encode_sentences(sentences, vocab=None, invalid_label=-1, invalid_key='\n',
         of sentence by default.
     start_label : int
         lowest index.
+    unknown_token: str
+        Symbol to represent unknown token.
+        If not specified, unknown token will be skipped.
 
     Returns
     -------
@@ -65,9 +69,11 @@ def encode_sentences(sentences, vocab=None, invalid_label=-1, invalid_key='\n',
         coded = []
         for word in sent:
             if word not in vocab:
-                assert new_vocab, "Unknown token %s"%word
+                assert (new_vocab or unknown_token), "Unknown token %s"%word
                 if idx == invalid_label:
                     idx += 1
+                if unknown_token:
+                    word = unknown_token
                 vocab[word] = idx
                 idx += 1
             coded.append(vocab[word])
diff --git a/tests/python/unittest/test_rnn.py b/tests/python/unittest/test_rnn.py
index 9fe22ae72df..2311625d3ca 100644
--- a/tests/python/unittest/test_rnn.py
+++ b/tests/python/unittest/test_rnn.py
@@ -296,7 +296,15 @@ def test_convgru():
     args, outs, auxs = outputs.infer_shape(rnn_t0_data=(1, 3, 16, 10), rnn_t1_data=(1, 3, 16, 10), rnn_t2_data=(1, 3, 16, 10))
     assert outs == [(1, 10, 16, 10), (1, 10, 16, 10), (1, 10, 16, 10)]
 
+def test_encode_sentences():
+    sentences = [['a','b','c'],['b','c','d']]
+    dict = {'a':1, 'b':2, 'c':3}
+    result, vocab = mx.rnn.io.encode_sentences(sentences, vocab=dict, invalid_label=-1, invalid_key='\n',
+                         start_label=0, unknown_token='UNK')
+    print(result, vocab)
+    assert vocab == {'a': 1, 'b': 2, 'c': 3, 'UNK': 0}
+    assert result == [[1,2,3],[2,3,0]]
+    
 if __name__ == '__main__':
     import nose
     nose.runmodule()
-


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services