Posted to commits@mxnet.apache.org by ha...@apache.org on 2017/12/11 04:29:27 UTC

[incubator-mxnet] branch master updated: Add wide and deep model into sparse example (#8180)

This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 20666c3  Add wide and deep model into sparse example (#8180)
20666c3 is described below

commit 20666c3186b3d7166cae625b48c81713c6811d67
Author: Ziyue Huang <zy...@gmail.com>
AuthorDate: Mon Dec 11 12:29:22 2017 +0800

    Add wide and deep model into sparse example (#8180)
    
    * get uci data
    
    * wide and deep symbol
    
    * add wide and deep model and adult dataset in classification
    
    * typo
    
    * Add license header
    
    * separate linear and wide-deep classification, fix for row_ids
    
    * add sparse embedding and remove kvstore
    
    * minor fix
    
    * replace libsvm iter with ndarray iter
    
    * update
    
    * more comments and remove weighted ce loss
    
    * update
    
    * Update wide_deep_model.py
    
    * Update wide_deep_classification.py
    
    * add a cross column feature
    
    * update
    
    * trigger
    
    * add gpu support
    
    * move to wide_deep folder
    
    * update
    
    * update
---
 example/sparse/wide_deep/README.md |   7 ++
 example/sparse/wide_deep/data.py   | 139 +++++++++++++++++++++++++++++++++++++
 example/sparse/wide_deep/model.py  |  58 ++++++++++++++++
 example/sparse/wide_deep/train.py  | 126 +++++++++++++++++++++++++++++++++
 4 files changed, 330 insertions(+)

diff --git a/example/sparse/wide_deep/README.md b/example/sparse/wide_deep/README.md
new file mode 100644
index 0000000..a538106
--- /dev/null
+++ b/example/sparse/wide_deep/README.md
@@ -0,0 +1,7 @@
+## Wide and Deep Learning
+
+This example demonstrates how to train a [wide and deep model](https://arxiv.org/abs/1606.07792). The [Census Income Data Set](https://archive.ics.uci.edu/ml/datasets/Census+Income) it uses for training is hosted by the [UC Irvine Machine Learning Repository](https://archive.ics.uci.edu/ml/datasets/). The feature engineering tricks are adapted from TensorFlow's [wide and deep tutorial](https://github.com/tensorflow/models/tree/master/official/wide_deep).
+
+The final accuracy should be around 85%.
+
+- `python train.py`
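A typical invocation of the training script added below, with illustrative flag values (all flags are defined in train.py's argument parser):

    python train.py --cuda --num-epoch 20 --batch-size 100 --optimizer adam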
diff --git a/example/sparse/wide_deep/data.py b/example/sparse/wide_deep/data.py
new file mode 100644
index 0000000..ffac1eb
--- /dev/null
+++ b/example/sparse/wide_deep/data.py
@@ -0,0 +1,139 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pylint: skip-file
+from csv import DictReader
+import os
+import mxnet as mx
+import numpy as np
+
+
+def get_uci_adult(data_dir, data_name, url):
+    if not os.path.isdir(data_dir):
+        os.mkdir(data_dir)
+    os.chdir(data_dir)
+    if not os.path.exists(data_name):
+        print("Dataset " + data_name + " not present. Downloading now ...")
+        os.system("wget %r" % (url + data_name))
+        if "test" in data_name:
+            os.system("sed -i '1d' %r" % data_name)  # the first line of adult.test is not a record
+        print("Dataset " + data_name + " is now present.")
+    csr, dns, label = preprocess_uci_adult(data_name)
+    os.chdir("..")
+    return csr, dns, label
+
+
+def preprocess_uci_adult(data_name):
+    """Some tricks of feature engineering are adapted
+    from tensorflow's wide and deep tutorial.
+    """
+    csv_columns = [
+        "age", "workclass", "fnlwgt", "education", "education_num",
+        "marital_status", "occupation", "relationship", "race", "gender",
+        "capital_gain", "capital_loss", "hours_per_week", "native_country",
+        "income_bracket"
+    ]
+
+    vocabulary_dict = {
+        "gender": [
+            "Female", "Male"
+        ],
+        "education": [
+            "Bachelors", "HS-grad", "11th", "Masters", "9th",
+            "Some-college", "Assoc-acdm", "Assoc-voc", "7th-8th",
+            "Doctorate", "Prof-school", "5th-6th", "10th", "1st-4th",
+            "Preschool", "12th"
+        ],
+        "marital_status": [
+            "Married-civ-spouse", "Divorced", "Married-spouse-absent",
+            "Never-married", "Separated", "Married-AF-spouse", "Widowed"
+        ],
+        "relationship": [
+            "Husband", "Not-in-family", "Wife", "Own-child", "Unmarried",
+            "Other-relative"
+        ],
+        "workclass": [
+            "Self-emp-not-inc", "Private", "State-gov", "Federal-gov",
+            "Local-gov", "?", "Self-emp-inc", "Without-pay", "Never-worked"
+        ]
+    }
+    # wide columns
+    crossed_columns = [
+        ["education", "occupation"],
+        ["native_country", "occupation"],
+        ["age_buckets", "education", "occupation"],
+    ]
+    age_boundaries = [18, 25, 30, 35, 40, 45, 50, 55, 60, 65]
+    # deep columns
+    indicator_columns = ['workclass', 'education', 'gender', 'relationship']
+
+    embedding_columns = ['native_country', 'occupation']
+
+    continuous_columns = ['age', 'education_num', 'capital_gain', 'capital_loss', 'hours_per_week']
+    # income_bracket is the label: "<=50K" -> 0, ">50K" -> 1; only the first char is compared
+    labels = ["<", ">"]
+
+    hash_bucket_size = 1000
+
+    csr_ncols = len(crossed_columns) * hash_bucket_size
+    dns_ncols = len(continuous_columns) + len(embedding_columns)
+    for col in indicator_columns:
+        dns_ncols += len(vocabulary_dict[col])
+
+    label_list = []
+    csr_list = []
+    dns_list = []
+
+    with open(data_name) as f:
+        for row in DictReader(f, fieldnames=csv_columns):
+            label_list.append(labels.index(row['income_bracket'].strip()[0]))
+
+            for i, cols in enumerate(crossed_columns):
+                if cols[0] == "age_buckets":
+                    age_bucket = np.digitize(float(row["age"]), age_boundaries)
+                    s = '_'.join([row[col].strip() for col in cols[1:]])
+                    s += '_' + str(age_bucket)
+                    csr_list.append((i * hash_bucket_size + hash(s) % hash_bucket_size, 1.0))
+                else:
+                    s = '_'.join([row[col].strip() for col in cols])
+                    csr_list.append((i * hash_bucket_size + hash(s) % hash_bucket_size, 1.0))
+
+            dns_row = [0] * dns_ncols
+            dns_dim = 0
+            for col in embedding_columns:
+                dns_row[dns_dim] = hash(row[col].strip()) % hash_bucket_size
+                dns_dim += 1
+
+            for col in indicator_columns:
+                dns_row[dns_dim + vocabulary_dict[col].index(row[col].strip())] = 1.0
+                dns_dim += len(vocabulary_dict[col])
+
+            for col in continuous_columns:
+                dns_row[dns_dim] = float(row[col].strip())
+                dns_dim += 1
+
+            dns_list.append(dns_row)
+
+    data_list = [item[1] for item in csr_list]
+    indices_list = [item[0] for item in csr_list]
+    indptr_list = range(0, len(indices_list) + 1, len(crossed_columns))  # len(crossed_columns) non-zeros per row
+    # convert to ndarrays
+    csr = mx.nd.sparse.csr_matrix((data_list, indices_list, indptr_list),
+                                  shape=(len(label_list), hash_bucket_size * len(crossed_columns)))
+    dns = np.array(dns_list)
+    label = np.array(label_list)
+    return csr, dns, label
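To make the hashed cross-column encoding above concrete, here is a minimal standalone sketch, with a toy row and a single crossed column chosen for illustration, of how one record's crossed feature lands in the CSR matrix:

    import mxnet as mx

    hash_bucket_size = 1000
    crossed_columns = [["education", "occupation"]]
    row = {"education": "Bachelors", "occupation": "Tech-support"}

    # join the column values and hash into this column's bucket range
    s = "_".join(row[col] for col in crossed_columns[0])
    idx = 0 * hash_bucket_size + hash(s) % hash_bucket_size
    # one non-zero per crossed column per row, hence indptr = [0, 1]
    csr = mx.nd.sparse.csr_matrix(([1.0], [idx], [0, 1]),
                                  shape=(1, hash_bucket_size * len(crossed_columns)))
    print(csr.indices.asnumpy())  # the single active bucket for this row

Note that Python's built-in hash is salted per process for strings, so bucket assignments are only stable within a single run unless PYTHONHASHSEED is fixed.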
diff --git a/example/sparse/wide_deep/model.py b/example/sparse/wide_deep/model.py
new file mode 100644
index 0000000..e8ba531
--- /dev/null
+++ b/example/sparse/wide_deep/model.py
@@ -0,0 +1,58 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import mxnet as mx
+
+
+def wide_deep_model(num_linear_features, num_embed_features, num_cont_features,
+                    input_dims, hidden_units):
+    # wide model
+    csr_data = mx.symbol.Variable("csr_data", stype='csr')
+    label = mx.symbol.Variable("softmax_label")
+
+    norm_init = mx.initializer.Normal(sigma=0.01)
+    # weight with row_sparse storage type to enable sparse gradient updates
+    weight = mx.symbol.Variable("linear_weight", shape=(num_linear_features, 2),
+                                init=norm_init, stype='row_sparse')
+    bias = mx.symbol.Variable("linear_bias", shape=(2,))
+    dot = mx.symbol.sparse.dot(csr_data, weight)
+    linear_out = mx.symbol.broadcast_add(dot, bias)
+    # deep model
+    dns_data = mx.symbol.Variable("dns_data")
+    # embedding features
+    x = mx.symbol.slice(data=dns_data, begin=(0, 0),
+                        end=(None, num_embed_features))
+    embeds = mx.symbol.split(data=x, num_outputs=num_embed_features, squeeze_axis=1)
+    # continuous features
+    x = mx.symbol.slice(data=dns_data, begin=(0, num_embed_features),
+                        end=(None, num_embed_features + num_cont_features))
+    features = [x]
+
+    for i, embed in enumerate(embeds):
+        embed_weight = mx.symbol.Variable('embed_%d_weight' % i, stype='row_sparse')
+        features.append(mx.symbol.contrib.SparseEmbedding(data=embed, weight=embed_weight,
+                        input_dim=input_dims[i], output_dim=hidden_units[0]))
+
+    hidden = mx.symbol.concat(*features, dim=1)
+    hidden = mx.symbol.FullyConnected(data=hidden, num_hidden=hidden_units[1])
+    hidden = mx.symbol.Activation(data=hidden, act_type='relu')
+    hidden = mx.symbol.FullyConnected(data=hidden, num_hidden=hidden_units[2])
+    hidden = mx.symbol.Activation(data=hidden, act_type='relu')
+    deep_out = mx.symbol.FullyConnected(data=hidden, num_hidden=2)
+
+    out = mx.symbol.SoftmaxOutput(linear_out + deep_out, label, name='model')
+    return out
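As a quick sanity check of the symbol above, the network can be built and its learnable arguments listed — a sketch using the ADULT dimensions that train.py (below) passes in:

    import mxnet as mx
    from model import wide_deep_model

    # dimensions taken from the ADULT dict in train.py
    net = wide_deep_model(num_linear_features=3000, num_embed_features=2,
                          num_cont_features=38, input_dims=[1000, 1000],
                          hidden_units=[8, 50, 100])
    # expect linear_weight, embed_0_weight, embed_1_weight, the FC weights and biases
    print(net.list_arguments())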
diff --git a/example/sparse/wide_deep/train.py b/example/sparse/wide_deep/train.py
new file mode 100644
index 0000000..89befb5
--- /dev/null
+++ b/example/sparse/wide_deep/train.py
@@ -0,0 +1,126 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import mxnet as mx
+from mxnet.test_utils import *
+from data import *
+from model import *
+import argparse
+import os
+
+
+parser = argparse.ArgumentParser(description="Run sparse wide and deep classification",
+                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+parser.add_argument('--num-epoch', type=int, default=10,
+                    help='number of epochs to train')
+parser.add_argument('--batch-size', type=int, default=100,
+                    help='number of examples per batch')
+parser.add_argument('--lr', type=float, default=0.001,
+                    help='learning rate')
+parser.add_argument('--cuda', action='store_true', default=False,
+                    help='Train on GPU with CUDA')
+parser.add_argument('--optimizer', type=str, default='adam',
+                    help='what optimizer to use',
+                    choices=["ftrl", "sgd", "adam"])
+parser.add_argument('--log-interval', type=int, default=100,
+                    help='number of batches to wait before logging training status')
+
+
+# Feature-engineering constants; see preprocess_uci_adult in data.py
+ADULT = {
+    'train': 'adult.data',
+    'test': 'adult.test',
+    'url': 'https://archive.ics.uci.edu/ml/machine-learning-databases/adult/',
+    'num_linear_features': 3000,  # len(crossed_columns) * hash_bucket_size
+    'num_embed_features': 2,
+    'num_cont_features': 38,  # 5 continuous columns + 33 one-hot indicator columns
+    'embed_input_dims': [1000, 1000],
+    'hidden_units': [8, 50, 100],
+}
+
+
+if __name__ == '__main__':
+    import logging
+    head = '%(asctime)-15s %(message)s'
+    logging.basicConfig(level=logging.INFO, format=head)
+
+    # arg parser
+    args = parser.parse_args()
+    logging.info(args)
+    num_epoch = args.num_epoch
+    batch_size = args.batch_size
+    optimizer = args.optimizer
+    log_interval = args.log_interval
+    lr = args.lr
+    ctx = mx.gpu(0) if args.cuda else mx.cpu()
+
+    # dataset
+    data_dir = os.path.join(os.getcwd(), 'data')
+    train_data = os.path.join(data_dir, ADULT['train'])
+    val_data = os.path.join(data_dir, ADULT['test'])
+    train_csr, train_dns, train_label = get_uci_adult(data_dir, ADULT['train'], ADULT['url'])
+    val_csr, val_dns, val_label = get_uci_adult(data_dir, ADULT['test'], ADULT['url'])
+
+    model = wide_deep_model(ADULT['num_linear_features'], ADULT['num_embed_features'],
+                            ADULT['num_cont_features'], ADULT['embed_input_dims'],
+                            ADULT['hidden_units'])
+
+    # data iterator
+    train_data = mx.io.NDArrayIter({'csr_data': train_csr, 'dns_data': train_dns},
+                                   {'softmax_label': train_label}, batch_size,
+                                   shuffle=True, last_batch_handle='discard')
+    eval_data = mx.io.NDArrayIter({'csr_data': val_csr, 'dns_data': val_dns},
+                                  {'softmax_label': val_label}, batch_size,
+                                  shuffle=True, last_batch_handle='discard')
+
+    # module
+    mod = mx.mod.Module(symbol=model, context=ctx, data_names=['csr_data', 'dns_data'],
+                        label_names=['softmax_label'])
+    mod.bind(data_shapes=train_data.provide_data, label_shapes=train_data.provide_label)
+    mod.init_params()
+    optim = mx.optimizer.create(optimizer, learning_rate=lr, rescale_grad=1.0/batch_size)
+    mod.init_optimizer(optimizer=optim)
+    # use accuracy as the metric
+    metric = mx.metric.create(['acc'])
+    # get the sparse weight parameter
+    speedometer = mx.callback.Speedometer(batch_size, log_interval)
+
+    logging.info('Training started ...')
+
+    data_iter = iter(train_data)
+    for epoch in range(num_epoch):
+        nbatch = 0
+        metric.reset()
+        for batch in data_iter:
+            nbatch += 1
+            mod.forward_backward(batch)
+            # update all parameters (including the weight parameter)
+            mod.update()
+            # update training metric
+            mod.update_metric(metric, batch.label)
+            speedometer_param = mx.model.BatchEndParam(epoch=epoch, nbatch=nbatch,
+                                                       eval_metric=metric, locals=locals())
+            speedometer(speedometer_param)
+        # evaluate metric on validation dataset
+        score = mod.score(eval_data, ['acc'])
+        logging.info('epoch %d, accuracy = %s' % (epoch, score[0][1]))
+
+        mod.save_checkpoint("checkpoint", epoch, save_optimizer_states=True)
+        # reset the iterator for next pass of data
+        data_iter.reset()
+
+    logging.info('Training completed.')
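Because the loop above saves a checkpoint after every epoch, a trained model can later be restored for scoring. A minimal sketch, assuming the eval_data iterator from the script and the final epoch of a default 10-epoch run:

    import mxnet as mx

    # restore the module saved by mod.save_checkpoint("checkpoint", epoch)
    mod = mx.mod.Module.load("checkpoint", 9, data_names=['csr_data', 'dns_data'],
                             label_names=['softmax_label'])
    mod.bind(data_shapes=eval_data.provide_data,
             label_shapes=eval_data.provide_label, for_training=False)
    print(mod.score(eval_data, ['acc']))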
