You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/09/18 07:52:20 UTC

[GitHub] mengjiexu commented on a change in pull request #12585: add speech recognition using gluon

mengjiexu commented on a change in pull request #12585: add speech recognition using gluon
URL: https://github.com/apache/incubator-mxnet/pull/12585#discussion_r218329925
 
 

 ##########
 File path: example/gluon/gluon_speech_recognition/HybridSequential_CNN_CTC.py
 ##########
 @@ -0,0 +1,235 @@
+import mxnet as mx
+from mxnet import gluon, nd, autograd, init
+from mxnet.gluon import loss
+from mxnet.gluon import nn
+import numpy as  np
+import os
+import sys
+import gluonbook as gb
+
+
+class Resnet1D(nn.HybridBlock):
+    """Residual unit made of two 1-D convolutions with batch normalisation.
+
+    Both convolutions produce ``num_channels`` channels and use
+    ``kernel_size=3`` with ``padding=1``.  The block input is added to the
+    output of the second batch-norm before the final ReLU, so the caller must
+    supply an input for which that addition is valid.
+
+    Parameters
+    ----------
+    num_channels : int
+        Output channel count of each convolution.
+    **kwargs
+        Passed through unchanged to ``nn.HybridBlock.__init__``.
+    """
+
+    def __init__(self, num_channels, **kwargs):
+        super(Resnet1D, self).__init__(**kwargs)
+        conv_options = dict(kernel_size=3, padding=1)
+        # Children are created in the order conv1, conv2, bn1, bn2.
+        self.conv1 = nn.Conv1D(num_channels, **conv_options)
+        self.conv2 = nn.Conv1D(num_channels, **conv_options)
+        self.bn1 = nn.BatchNorm()
+        self.bn2 = nn.BatchNorm()
+
+    def hybrid_forward(self, F, x, *args, **kwargs):
+        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + x)``."""
+        hidden = self.conv1(x)
+        hidden = self.bn1(hidden)
+        hidden = F.relu(hidden)
+        branch = self.bn2(self.conv2(hidden))
+        # Skip connection: the untouched input is added before the last ReLU.
+        return F.relu(branch + x)
+
+
+class CBR(nn.HybridBlock):
+    """Conv1D followed by BatchNorm and ReLU.
+
+    Parameters
+    ----------
+    num_channels : int
+        Output channel count of the convolution.
+    kernel_size : int, optional
+        Convolution kernel size (default 3).
+    padding : int, optional
+        Convolution padding (default 1).
+    **kwargs
+        Passed through unchanged to ``nn.HybridBlock.__init__``.
+    """
+
+    def __init__(self, num_channels, kernel_size=3, padding=1, **kwargs):
+        super(CBR, self).__init__(**kwargs)
+        self.conv = nn.Conv1D(
+            num_channels,
+            kernel_size=kernel_size,
+            padding=padding,
+        )
+        self.bn = nn.BatchNorm()
+        # Registered as an attribute, but hybrid_forward below applies F.relu
+        # directly rather than calling this block.
+        self.relu = nn.Activation('relu')
+
+    def hybrid_forward(self, F, x, *args, **kwargs):
+        """Return ``relu(bn(conv(x)))``."""
+        normalised = self.bn(self.conv(x))
+        return F.relu(normalised)
+
+
+class SwapAxes(nn.HybridBlock):
+    """Block that exchanges two axes of its input via ``F.swapaxes``.
+
+    Parameters
+    ----------
+    dim1, dim2 : int
+        The two axis indices handed to ``F.swapaxes``.
+    """
+
+    def __init__(self, dim1, dim2):
+        super(SwapAxes, self).__init__()
+        self.dim1, self.dim2 = dim1, dim2
+
+    def hybrid_forward(self, F, x, *args, **kwargs):
+        """Return ``x`` with axes ``dim1`` and ``dim2`` swapped."""
+        first_axis, second_axis = self.dim1, self.dim2
+        return F.swapaxes(x, first_axis, second_axis)
+
+
+# Build the network inside a CPU(0) context.  Layers are instantiated in
+# exactly the order in which they are appended to the sequential container.
+with mx.Context(mx.cpu(0)):
+    model = nn.HybridSequential()
+
+    # Axes 1 and 2 of the input are swapped before the first convolution.
+    model.add(SwapAxes(1, 2))
+
+    # 40-channel stage followed by pooling.
+    model.add(CBR(40, 1), CBR(40), CBR(40))
+    model.add(nn.MaxPool1D(2))
+
+    # 80-channel stage followed by pooling.
+    model.add(CBR(80, 1), CBR(80), CBR(80))
+    model.add(nn.MaxPool1D(2))
+
+    # 160-channel stage with dropout after its first layer, then pooling.
+    model.add(CBR(160, 1), nn.Dropout(0.3))
+    model.add(CBR(160), CBR(160), CBR(160))
+    model.add(nn.MaxPool1D(2))
+
+    # Widen to 240 channels ahead of the residual stack.
+    model.add(CBR(240, 1), nn.Dropout(0.3))
+
+    # Stack of 34 residual units at 240 channels.
+    for i in range(34):
+        model.add(Resnet1D(240))
+
+    # Output head.  The original author annotated the layout as NCW before
+    # the final swap and NWC after it.
+    model.add(nn.Dropout(0.3))
+    model.add(nn.Conv1D(3000, 1, 1))
+    model.add(SwapAxes(1, 2))
+
+
+def ctc_loss(net, train_features, train_labels):
+    """Run ``net`` on ``train_features`` and score the output with CTC loss.
+
+    A fresh ``loss.CTCLoss`` with default arguments is created on every call
+    and applied to the network output and ``train_labels``; its result is
+    returned unchanged.
+    """
+    predictions = net(train_features)
+    criterion = loss.CTCLoss()
+    return criterion(predictions, train_labels)
+
+
+def _read_transcript(path):
+    """Return the characters of the first line of ``path``, spaces removed."""
+    # The context manager guarantees the handle is closed even on error.
+    with open(path) as handle:
+        first_line = handle.readline()
+    return list(first_line.split('\n')[0].replace(' ', ''))
+
+
+def _pad_batch(features, labels, pooling_step, feature_dim):
+    """Zero-pad variable-length features and labels into float32 arrays."""
+    longest = max(len(feature) for feature in features)
+    # Round down to a multiple of pooling_step, then add two extra steps.
+    padded_len = longest // pooling_step * pooling_step + pooling_step * 2
+    label_width = max(len(label) for label in labels)
+    feature_batch = np.zeros(
+        [len(features), padded_len, feature_dim], dtype=np.float32)
+    # Label padding value is 0; real labels are offset by +1 by the caller.
+    label_batch = np.zeros([len(labels), label_width], dtype=np.float32)
+    for row, (feature, label) in enumerate(zip(features, labels)):
+        feature_batch[row, 0:len(feature), :] = np.array(
+            feature, dtype=np.float32)
+        label_batch[row, :len(label)] = np.array(label, dtype=np.float32)
+    return feature_batch, label_batch
+
+
+def get_data_gen(data_dir, str2idx, batch_size=2, pooling_step=8,
+                 feature_dim=39):
+    """Yield zero-padded ``(features, labels)`` batches for one pass over a directory.
+
+    Every directory entry whose name contains ``'.txt'`` contributes an
+    utterance name (the part before the first ``'.'``).  For each name the
+    features are read from ``<name>.txt`` with ``np.loadtxt`` and offset by
+    +1, and the label characters are read from the first line of
+    ``<name>.wav.trn`` with spaces removed.
+
+    Parameters
+    ----------
+    data_dir : str
+        Directory holding the ``.txt`` feature files and ``.wav.trn``
+        transcripts.
+    str2idx : mapping
+        Indexed with each transcript character; the stored label is the
+        looked-up value plus one.
+    batch_size : int, optional
+        Number of utterances per yielded batch (default 2).
+    pooling_step : int, optional
+        Granularity used when rounding the padded feature length
+        (default 8).
+    feature_dim : int, optional
+        Size of the last axis of the feature batch (default 39).
+
+    Yields
+    ------
+    tuple of numpy.ndarray
+        ``(features, labels)``, both float32, of shapes
+        ``(batch_size, padded_len, feature_dim)`` and
+        ``(batch_size, max_label_len)``.
+
+    Notes
+    -----
+    * Utterances that cannot be loaded (missing or unreadable file,
+      unparsable numbers, character absent from ``str2idx``) are skipped.
+    * A trailing group smaller than ``batch_size`` is never yielded.
+    * Names are sorted before ``np.random.shuffle`` so that a seeded global
+      NumPy RNG gives a reproducible order.
+    """
+    names = sorted(
+        {entry.split('.')[0]
+         for entry in os.listdir(data_dir) if '.txt' in entry})
+    np.random.shuffle(names)
+    print('start one epoch')
+
+    features = []
+    labels = []
+    for name in names:
+        base = os.path.join(data_dir, name)
+        try:
+            feature = np.loadtxt(base + '.txt') + 1
+            label = np.array(
+                [str2idx[char] + 1
+                 for char in _read_transcript(base + '.wav.trn')])
+        except (OSError, ValueError, KeyError):
+            # Best-effort loading: skip utterances with bad or missing data,
+            # but let genuine programming errors propagate.
+            continue
+        features.append(feature)
+        labels.append(label)
+        if len(features) == batch_size:
+            yield _pad_batch(features, labels, pooling_step, feature_dim)
+            features = []
+            labels = []
+
+
+def get_str2idx(data_dir):
+    """Build the character lookup tables from the transcripts in ``data_dir``.
+
+    Every directory entry whose name contains ``'trn'`` is opened; the
+    characters of its first line (spaces removed) are collected.  The
+    distinct characters are sorted so the assigned indices are the same on
+    every run, instead of following set iteration order.
+
+    Parameters
+    ----------
+    data_dir : str
+        Directory containing the transcript files.
+
+    Returns
+    -------
+    tuple of dict
+        ``(str2idx, idx2str)``.  NOTE: the contents are the reverse of what
+        the names suggest and are kept that way for existing callers:
+        the first dict maps ``int index -> character`` and the second maps
+        ``character -> int index``.  ``get_iter`` passes element ``[1]`` to
+        ``get_data_gen``, which indexes it by character.
+    """
+    chars = set()
+    for name in os.listdir(data_dir):
+        if 'trn' not in name:
+            continue
+        # Context manager so the handle is closed after reading one line.
+        with open(os.path.join(data_dir, name)) as handle:
+            first_line = handle.readline()
+        chars.update(first_line.split('\n')[0].replace(' ', ''))
+
+    str2idx = {}
+    idx2str = {}
+    for position, char in enumerate(sorted(chars)):
+        str2idx[position] = char
+        idx2str[str(char)] = position
+    return str2idx, idx2str
+
+
+def get_iter(batch_size):
+    """Yield ``(features, labels)`` batches from ``./data`` as ``nd.array`` pairs.
+
+    The lookup tables are rebuilt with ``get_str2idx`` on every call, and the
+    second element of its result is handed to ``get_data_gen`` as the
+    character lookup.  Each NumPy batch the generator produces is converted
+    with ``nd.array`` before being yielded.
+    """
+    data_dir = './data'
+    tables = get_str2idx(data_dir)
+    batches = get_data_gen(data_dir, tables[1], batch_size)
+    for feature_batch, label_batch in batches:
+        yield nd.array(feature_batch), nd.array(label_batch)
+
+
+class ShowProcess():
+    """
+    显示处理进度的类
 
 Review comment:
   ok

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services