You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/06/17 22:24:39 UTC
[GitHub] szha closed pull request #10461: allow user to define unknown token symbol
szha closed pull request #10461: allow user to define unknown token symbol
URL: https://github.com/apache/incubator-mxnet/pull/10461
This is a PR merged from a forked repository.
Because GitHub hides the original diff once a pull request is merged, it is
reproduced below for the sake of provenance:
Because this is a foreign pull request (from a fork), the diff is supplied
below; GitHub would not otherwise display it:
diff --git a/python/mxnet/rnn/io.py b/python/mxnet/rnn/io.py
index 8eba9d21e39..a8890c9e7e7 100644
--- a/python/mxnet/rnn/io.py
+++ b/python/mxnet/rnn/io.py
@@ -27,7 +27,8 @@
from ..io import DataIter, DataBatch, DataDesc
from .. import ndarray
-def encode_sentences(sentences, vocab=None, invalid_label=-1, invalid_key='\n', start_label=0):
+def encode_sentences(sentences, vocab=None, invalid_label=-1, invalid_key='\n',
+ start_label=0, unknown_token=None):
"""Encode sentences and (optionally) build a mapping
from string tokens to integer indices. Unknown keys
will be added to vocabulary.
@@ -46,6 +47,9 @@ def encode_sentences(sentences, vocab=None, invalid_label=-1, invalid_key='\n',
of sentence by default.
start_label : int
lowest index.
+ unknown_token: str
+ Symbol to represent unknown token.
+ If not specified, unknown token will be skipped.
Returns
-------
@@ -65,9 +69,11 @@ def encode_sentences(sentences, vocab=None, invalid_label=-1, invalid_key='\n',
coded = []
for word in sent:
if word not in vocab:
- assert new_vocab, "Unknown token %s"%word
+ assert (new_vocab or unknown_token), "Unknown token %s"%word
if idx == invalid_label:
idx += 1
+ if unknown_token:
+ word = unknown_token
vocab[word] = idx
idx += 1
coded.append(vocab[word])
diff --git a/tests/python/unittest/test_rnn.py b/tests/python/unittest/test_rnn.py
index 9fe22ae72df..2311625d3ca 100644
--- a/tests/python/unittest/test_rnn.py
+++ b/tests/python/unittest/test_rnn.py
@@ -296,7 +296,15 @@ def test_convgru():
args, outs, auxs = outputs.infer_shape(rnn_t0_data=(1, 3, 16, 10), rnn_t1_data=(1, 3, 16, 10), rnn_t2_data=(1, 3, 16, 10))
assert outs == [(1, 10, 16, 10), (1, 10, 16, 10), (1, 10, 16, 10)]
+def test_encode_sentences():
+ sentences = [['a','b','c'],['b','c','d']]
+ dict = {'a':1, 'b':2, 'c':3}
+ result, vocab = mx.rnn.io.encode_sentences(sentences, vocab=dict, invalid_label=-1, invalid_key='\n',
+ start_label=0, unknown_token='UNK')
+ print(result, vocab)
+ assert vocab == {'a': 1, 'b': 2, 'c': 3, 'UNK': 0}
+ assert result == [[1,2,3],[2,3,0]]
+
if __name__ == '__main__':
import nose
nose.runmodule()
-
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to this message, please log in to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
With regards,
Apache Git Services