Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2017/12/11 04:29:26 UTC

[GitHub] eric-haibin-lin closed pull request #8180: Add wide and deep model into sparse example

eric-haibin-lin closed pull request #8180: Add wide and deep model into sparse example
URL: https://github.com/apache/incubator-mxnet/pull/8180
 
 
   

This is a PR merged from a forked repository. As GitHub hides the
original diff on merge, it is reproduced below for the sake of provenance:

diff --git a/example/sparse/wide_deep/README.md b/example/sparse/wide_deep/README.md
new file mode 100644
index 0000000000..a538106216
--- /dev/null
+++ b/example/sparse/wide_deep/README.md
@@ -0,0 +1,7 @@
+## Wide and Deep Learning
+
+This example demonstrates how to train a [wide and deep model](https://arxiv.org/abs/1606.07792). The [Census Income Data Set](https://archive.ics.uci.edu/ml/datasets/Census+Income) used for training is hosted by the [UC Irvine Machine Learning Repository](https://archive.ics.uci.edu/ml/datasets/). The feature engineering tricks are adapted from TensorFlow's [wide and deep tutorial](https://github.com/tensorflow/models/tree/master/official/wide_deep).
+
+The final accuracy should be around 85%.
+
+- `python train.py`
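+
+Training options such as the number of epochs, batch size, learning rate, optimizer and GPU usage are exposed as command-line flags (see the `argparse` setup in `train.py`), for example:
+
+- `python train.py --num-epoch 20 --batch-size 128 --cuda`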
diff --git a/example/sparse/wide_deep/data.py b/example/sparse/wide_deep/data.py
new file mode 100644
index 0000000000..ffac1eb422
--- /dev/null
+++ b/example/sparse/wide_deep/data.py
@@ -0,0 +1,139 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pylint: skip-file
+from csv import DictReader
+import os
+import mxnet as mx
+import numpy as np
+
+
+def get_uci_adult(data_dir, data_name, url):
+    if not os.path.isdir(data_dir):
+        os.mkdir(data_dir)
+    os.chdir(data_dir)
+    if not os.path.exists(data_name):
+        print("Dataset " + data_name + " not present. Downloading now ...")
+        os.system("wget %r" % (url + data_name))
+        if "test" in data_name:
+            # the first line of the test file is not a data row; strip it
+            os.system("sed -i '1d' %r" % data_name)
+        print("Dataset " + data_name + " is now present.")
+    csr, dns, label = preprocess_uci_adult(data_name)
+    os.chdir("..")
+    return csr, dns, label
+
+
+def preprocess_uci_adult(data_name):
+    """Some tricks of feature engineering are adapted
+    from tensorflow's wide and deep tutorial.
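+
+    Returns a CSR matrix of hashed crossed-column (wide) features, a dense
+    matrix holding embedding hash indices, indicator one-hots and continuous
+    values (deep features), and the label vector.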
+    """
+    csv_columns = [
+        "age", "workclass", "fnlwgt", "education", "education_num",
+        "marital_status", "occupation", "relationship", "race", "gender",
+        "capital_gain", "capital_loss", "hours_per_week", "native_country",
+        "income_bracket"
+    ]
+
+    vocabulary_dict = {
+        "gender": [
+            "Female", "Male"
+        ],
+        "education": [
+            "Bachelors", "HS-grad", "11th", "Masters", "9th",
+            "Some-college", "Assoc-acdm", "Assoc-voc", "7th-8th",
+            "Doctorate", "Prof-school", "5th-6th", "10th", "1st-4th",
+            "Preschool", "12th"
+        ],
+        "marital_status": [
+            "Married-civ-spouse", "Divorced", "Married-spouse-absent",
+            "Never-married", "Separated", "Married-AF-spouse", "Widowed"
+        ],
+        "relationship": [
+            "Husband", "Not-in-family", "Wife", "Own-child", "Unmarried",
+            "Other-relative"
+        ],
+        "workclass": [
+            "Self-emp-not-inc", "Private", "State-gov", "Federal-gov",
+            "Local-gov", "?", "Self-emp-inc", "Without-pay", "Never-worked"
+        ]
+    }
+    # wide columns
+    crossed_columns = [
+        ["education", "occupation"],
+        ["native_country", "occupation"],
+        ["age_buckets", "education", "occupation"],
+    ]
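+    # each crossed column gets its own block of `hash_bucket_size` indices in the
+    # sparse (wide) feature vector; the joined string value is hashed into that block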
+    age_boundaries = [18, 25, 30, 35, 40, 45, 50, 55, 60, 65]
+    # deep columns
+    indicator_columns = ['workclass', 'education', 'gender', 'relationship']
+    
+    embedding_columns = ['native_country', 'occupation']
+
+    continuous_columns = ['age', 'education_num', 'capital_gain', 'capital_loss', 'hours_per_week']
+    # income_bracket column is the label
+    labels = ["<", ">"]
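+    # income_bracket values begin with '<' ("<=50K") or '>' (">50K"), so the
+    # first character is enough to recover the binary label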
+
+    hash_bucket_size = 1000
+    
+    csr_ncols = len(crossed_columns) * hash_bucket_size
+    dns_ncols = len(continuous_columns) + len(embedding_columns)
+    for col in indicator_columns:
+        dns_ncols += len(vocabulary_dict[col])
+
+    label_list = []
+    csr_list = []
+    dns_list = []
+
+    with open(data_name) as f:
+        for row in DictReader(f, fieldnames=csv_columns):
+            label_list.append(labels.index(row['income_bracket'].strip()[0]))
+
+            for i, cols in enumerate(crossed_columns):
+                if cols[0] == "age_buckets":
+                    age_bucket = np.digitize(float(row["age"]), age_boundaries)
+                    s = '_'.join([row[col].strip() for col in cols[1:]])
+                    s += '_' + str(age_bucket)
+                    csr_list.append((i * hash_bucket_size + hash(s) % hash_bucket_size, 1.0))
+                else:
+                    s = '_'.join([row[col].strip() for col in cols])
+                    csr_list.append((i * hash_bucket_size + hash(s) % hash_bucket_size, 1.0))
+            
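+            # dense (deep) row layout: [embedding hash indices | indicator one-hots | continuous values]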
+            dns_row = [0] * dns_ncols
+            dns_dim = 0
+            for col in embedding_columns:
+                dns_row[dns_dim] = hash(row[col].strip()) % hash_bucket_size
+                dns_dim += 1
+
+            for col in indicator_columns:
+                dns_row[dns_dim + vocabulary_dict[col].index(row[col].strip())] = 1.0
+                dns_dim += len(vocabulary_dict[col])
+
+            for col in continuous_columns:
+                dns_row[dns_dim] = float(row[col].strip())
+                dns_dim += 1
+
+            dns_list.append(dns_row)
+
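+    # every row contributes exactly len(crossed_columns) sparse entries,
+    # so the CSR indptr is a uniform arithmetic range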
+    data_list = [item[1] for item in csr_list]
+    indices_list = [item[0] for item in csr_list]
+    indptr_list = range(0, len(indices_list) + 1, len(crossed_columns))
+    # convert to ndarrays
+    csr = mx.nd.sparse.csr_matrix((data_list, indices_list, indptr_list),
+                                  shape=(len(label_list), hash_bucket_size * len(crossed_columns)))
+    dns = np.array(dns_list)
+    label = np.array(label_list)
+    return csr, dns, label
diff --git a/example/sparse/wide_deep/model.py b/example/sparse/wide_deep/model.py
new file mode 100644
index 0000000000..e8ba5318b5
--- /dev/null
+++ b/example/sparse/wide_deep/model.py
@@ -0,0 +1,58 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import mxnet as mx
+
+
+def wide_deep_model(num_linear_features, num_embed_features, num_cont_features, 
+                    input_dims, hidden_units):
+    # wide model
+    csr_data = mx.symbol.Variable("csr_data", stype='csr')
+    label = mx.symbol.Variable("softmax_label")
+
+    norm_init = mx.initializer.Normal(sigma=0.01)
+    # weight with row_sparse storage type to enable sparse gradient updates
+    weight = mx.symbol.Variable("linear_weight", shape=(num_linear_features, 2),
+                                init=norm_init, stype='row_sparse')
+    bias = mx.symbol.Variable("linear_bias", shape=(2,))
+    dot = mx.symbol.sparse.dot(csr_data, weight)
+    linear_out = mx.symbol.broadcast_add(dot, bias)
+    # deep model
+    dns_data = mx.symbol.Variable("dns_data")
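+    # dns_data columns follow the layout built in data.py:
+    # [embedding hash indices | indicator one-hots and continuous values]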
+    # embedding features
+    x = mx.symbol.slice(data=dns_data, begin=(0, 0),
+                        end=(None, num_embed_features))
+    embeds = mx.symbol.split(data=x, num_outputs=num_embed_features, squeeze_axis=1)
+    # indicator one-hot and continuous features
+    x = mx.symbol.slice(data=dns_data, begin=(0, num_embed_features),
+                        end=(None, num_embed_features + num_cont_features))
+    features = [x]
+
+    for i, embed in enumerate(embeds):
+        embed_weight = mx.symbol.Variable('embed_%d_weight' % i, stype='row_sparse')
+        features.append(mx.symbol.contrib.SparseEmbedding(data=embed, weight=embed_weight,
+                        input_dim=input_dims[i], output_dim=hidden_units[0]))
+
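+    # concatenate the dense slice with all embedding outputs as the MLP input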
+    hidden = mx.symbol.concat(*features, dim=1)
+    hidden = mx.symbol.FullyConnected(data=hidden, num_hidden=hidden_units[1])
+    hidden = mx.symbol.Activation(data=hidden, act_type='relu')
+    hidden = mx.symbol.FullyConnected(data=hidden, num_hidden=hidden_units[2])
+    hidden = mx.symbol.Activation(data=hidden, act_type='relu')
+    deep_out = mx.symbol.FullyConnected(data=hidden, num_hidden=2)
+
+    out = mx.symbol.SoftmaxOutput(linear_out + deep_out, label, name='model')
+    return out
diff --git a/example/sparse/wide_deep/train.py b/example/sparse/wide_deep/train.py
new file mode 100644
index 0000000000..89befb5aa8
--- /dev/null
+++ b/example/sparse/wide_deep/train.py
@@ -0,0 +1,126 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import mxnet as mx
+from mxnet.test_utils import *
+from data import *
+from model import *
+import argparse
+import os
+
+
+parser = argparse.ArgumentParser(description="Run sparse wide and deep classification ",
+                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+parser.add_argument('--num-epoch', type=int, default=10,
+                    help='number of epochs to train')
+parser.add_argument('--batch-size', type=int, default=100,
+                    help='number of examples per batch')
+parser.add_argument('--lr', type=float, default=0.001,
+                    help='learning rate')
+parser.add_argument('--cuda', action='store_true', default=False,
+                    help='Train on GPU with CUDA')
+parser.add_argument('--optimizer', type=str, default='adam',
+                    help='what optimizer to use',
+                    choices=["ftrl", "sgd", "adam"])
+parser.add_argument('--log-interval', type=int, default=100,
+                    help='number of batches to wait before logging training status')
+
+
+# Related to feature engineering, please see preprocess in data.py
+ADULT = {
+    'train': 'adult.data',
+    'test': 'adult.test',
+    'url': 'https://archive.ics.uci.edu/ml/machine-learning-databases/adult/',
+    'num_linear_features': 3000,
+    'num_embed_features': 2,
+    'num_cont_features': 38,
+    'embed_input_dims': [1000, 1000],
+    'hidden_units': [8, 50, 100],
+}
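+# num_linear_features = len(crossed_columns) * hash_bucket_size = 3 * 1000 wide buckets;
+# num_cont_features = 38 covers the 33 indicator one-hot columns plus the 5 continuous columns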
+
+
+if __name__ == '__main__':
+    import logging
+    head = '%(asctime)-15s %(message)s'
+    logging.basicConfig(level=logging.INFO, format=head)
+
+    # arg parser
+    args = parser.parse_args()
+    logging.info(args)
+    num_epoch = args.num_epoch
+    batch_size = args.batch_size
+    optimizer = args.optimizer
+    log_interval = args.log_interval
+    lr = args.lr
+    ctx = mx.gpu(0) if args.cuda else mx.cpu()
+
+    # dataset    
+    data_dir = os.path.join(os.getcwd(), 'data')
+    train_data = os.path.join(data_dir, ADULT['train'])
+    val_data = os.path.join(data_dir, ADULT['test'])
+    train_csr, train_dns, train_label = get_uci_adult(data_dir, ADULT['train'], ADULT['url'])
+    val_csr, val_dns, val_label = get_uci_adult(data_dir, ADULT['test'], ADULT['url'])
+
+    model = wide_deep_model(ADULT['num_linear_features'], ADULT['num_embed_features'],
+                            ADULT['num_cont_features'], ADULT['embed_input_dims'],
+                            ADULT['hidden_units'])
+
+    # data iterator
+    train_data = mx.io.NDArrayIter({'csr_data': train_csr, 'dns_data': train_dns},
+                                   {'softmax_label': train_label}, batch_size,
+                                   shuffle=True, last_batch_handle='discard')
+    eval_data = mx.io.NDArrayIter({'csr_data': val_csr, 'dns_data': val_dns},
+                                  {'softmax_label': val_label}, batch_size,
+                                  shuffle=True, last_batch_handle='discard')
+    
+    # module
+    mod = mx.mod.Module(symbol=model, context=ctx, data_names=['csr_data', 'dns_data'],
+                        label_names=['softmax_label'])
+    mod.bind(data_shapes=train_data.provide_data, label_shapes=train_data.provide_label)
+    mod.init_params()
+    optim = mx.optimizer.create(optimizer, learning_rate=lr, rescale_grad=1.0/batch_size)
+    mod.init_optimizer(optimizer=optim)
+    # use accuracy as the metric
+    metric = mx.metric.create(['acc'])
+    # callback that logs training speed and metric every `log_interval` batches
+    speedometer = mx.callback.Speedometer(batch_size, log_interval)
+
+    logging.info('Training started ...')
+    
+    data_iter = iter(train_data)
+    for epoch in range(num_epoch):
+        nbatch = 0
+        metric.reset()
+        for batch in data_iter:
+            nbatch += 1
+            mod.forward_backward(batch)
+            # update all parameters (including the weight parameter)
+            mod.update()
+            # update training metric
+            mod.update_metric(metric, batch.label)
+            speedometer_param = mx.model.BatchEndParam(epoch=epoch, nbatch=nbatch,
+                                                       eval_metric=metric, locals=locals())
+            speedometer(speedometer_param)
+        # evaluate metric on validation dataset
+        score = mod.score(eval_data, ['acc'])
+        logging.info('epoch %d, accuracy = %s' % (epoch, score[0][1]))
+        
+        mod.save_checkpoint("checkpoint", epoch, save_optimizer_states=True)
+        # reset the iterator for next pass of data
+        data_iter.reset()
+    
+    logging.info('Training completed.')


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services