Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2017/12/08 23:32:40 UTC

[GitHub] szha closed pull request #8767: Factorization machine example & sparse example folder re-org

szha closed pull request #8767: Factorization machine example & sparse example folder re-org
URL: https://github.com/apache/incubator-mxnet/pull/8767
 
 
   

This is a PR merged from a forked repository. Because GitHub hides the original diff
once a pull request from a fork is merged, the diff is reproduced below for the sake
of provenance:

diff --git a/example/sparse/factorization_machine/README.md b/example/sparse/factorization_machine/README.md
new file mode 100644
index 0000000000..7609f31d5c
--- /dev/null
+++ b/example/sparse/factorization_machine/README.md
@@ -0,0 +1,16 @@
+Factorization Machine
+===========
+This example trains a factorization machine model using the Criteo dataset.
+
+## Download the Dataset
+
+The Criteo dataset is available at https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary.html#criteo
+The data was used in a click-through rate prediction competition jointly hosted by Criteo and Kaggle in 2014.
+The dataset contains 1,000,000 features, 45,840,617 training examples, and 6,042,135 test examples.
+Downloading and extracting the dataset requires more than 30 GB of disk space.
+
+## Train the Model
+
+- `python train.py --data-train /path/to/criteo.kaggle2014.test.svm --data-test /path/to/criteo.kaggle2014.test.svm`
+
+[Rendle, Steffen. "Factorization machines." In Data Mining (ICDM), 2010 IEEE 10th International Conference on, pp. 995-1000. IEEE, 2010. ](https://www.csie.ntu.edu.tw/~b97053/paper/Rendle2010FM.pdf)
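
For reference, the score assembled by model.py below (and fed into a logistic output by
train.py) is the standard second-order factorization machine from the Rendle paper cited
above; for a feature vector x with k latent factors per feature it is

    \hat{y}(x) = w_0 + \sum_i w_i x_i
                 + \frac{1}{2} \sum_{f=1}^{k} \Big[ \big(\sum_i v_{i,f} x_i\big)^2 - \sum_i v_{i,f}^2 x_i^2 \Big]

The bracketed term equals the sum of <v_i, v_j> x_i x_j over all pairs i < j, but can be
computed in time linear in the number of non-zero features rather than quadratic.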
diff --git a/example/sparse/factorization_machine/metric.py b/example/sparse/factorization_machine/metric.py
new file mode 100644
index 0000000000..07a7e01e02
--- /dev/null
+++ b/example/sparse/factorization_machine/metric.py
@@ -0,0 +1,88 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import mxnet as mx
+import numpy as np
+
+@mx.metric.register
+@mx.metric.alias('log_loss')
+class LogLossMetric(mx.metric.EvalMetric):
+    """Computes the negative log-likelihood loss.
+
+    The negative log-likelihood loss over a batch of :math:`N` samples is given by
+
+    .. math::
+       -\\sum_{n=1}^{N}\\sum_{k=1}^{K}t_{nk}\\log (y_{nk}),
+
+    where :math:`K` is the number of classes and :math:`y_{nk}` is the predicted probability of the
+    :math:`k`-th class for the :math:`n`-th sample. :math:`t_{nk}=1` if and only if sample
+    :math:`n` belongs to class :math:`k`.
+
+    Parameters
+    ----------
+    eps : float
+        The negative log-likelihood loss is undefined when a predicted probability is 0,
+        so a small constant is added to the predicted values.
+    name : str
+        Name of this metric instance for display.
+    output_names : list of str, or None
+        Names of the predictions that should be used when updating with update_dict.
+        By default, all predictions are included.
+    label_names : list of str, or None
+        Names of the labels that should be used when updating with update_dict.
+        By default, all labels are included.
+
+    Examples
+    --------
+    >>> predicts = [mx.nd.array([[0.7], [1.], [0.6]])]
+    >>> labels   = [mx.nd.array([0, 1, 1])]
+    >>> log_loss = mx.metric.create('log_loss')
+    >>> log_loss.update(labels, predicts)
+    >>> print(log_loss.get())
+    ('log-loss', 0.57159948348999023)
+    """
+    def __init__(self, eps=1e-12, name='log-loss',
+                 output_names=None, label_names=None):
+        super(LogLossMetric, self).__init__(
+            name, eps=eps,
+            output_names=output_names, label_names=label_names)
+        self.eps = eps
+
+    def update(self, labels, preds):
+        """Updates the internal evaluation result.
+
+        Parameters
+        ----------
+        labels : list of `NDArray`
+            The labels of the data.
+
+        preds : list of `NDArray`
+            Predicted values.
+        """
+        mx.metric.check_label_shapes(labels, preds)
+
+        for label, pred in zip(labels, preds):
+            label = label.asnumpy()
+            pred = pred.asnumpy()
+            pred = np.column_stack((1 - pred, pred))
+
+            label = label.ravel()
+            num_examples = pred.shape[0]
+            assert label.shape[0] == num_examples, (label.shape[0], num_examples)
+            prob = pred[np.arange(num_examples, dtype=np.int64), np.int64(label)]
+            self.sum_metric += (-np.log(prob + self.eps)).sum()
+            self.num_inst += num_examples
diff --git a/example/sparse/factorization_machine/model.py b/example/sparse/factorization_machine/model.py
new file mode 100644
index 0000000000..f0af2e650d
--- /dev/null
+++ b/example/sparse/factorization_machine/model.py
@@ -0,0 +1,54 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import mxnet as mx
+
+def factorization_machine_model(factor_size, num_features,
+                                lr_mult_config, wd_mult_config, init_config):
+    """ builds factorization machine network with proper formulation:
+    y = w_0 \sum(x_i w_i) + 0.5(\sum\sum<v_i,v_j>x_ix_j - \sum<v_iv_i>x_i^2)
+    """
+    x = mx.symbol.Variable("data", stype='csr')
+    # factor, linear and bias terms
+    v = mx.symbol.Variable("v", shape=(num_features, factor_size), stype='row_sparse',
+                           init=init_config['v'], lr_mult=lr_mult_config['v'],
+                           wd_mult=wd_mult_config['v'])
+    w = mx.symbol.var('w', shape=(num_features, 1), stype='row_sparse',
+                      init=init_config['w'], lr_mult=lr_mult_config['w'],
+                      wd_mult=wd_mult_config['w'])
+    w0 = mx.symbol.var('w0', shape=(1,), init=init_config['w0'],
+                       lr_mult=lr_mult_config['w0'], wd_mult=wd_mult_config['w0'])
+    w1 = mx.symbol.broadcast_add(mx.symbol.dot(x, w), w0)
+
+    # squared terms for subtracting self interactions
+    v_s = mx.symbol._internal._square_sum(data=v, axis=1, keepdims=True)
+    x_s = x.square()
+    bd_sum = mx.sym.dot(x_s, v_s)
+
+    # interactions
+    w2 = mx.symbol.dot(x, v)
+    w2_squared = 0.5 * mx.symbol.square(data=w2)
+
+    # putting everything together
+    w_all = mx.symbol.Concat(w1, w2_squared, dim=1)
+    sum1 = w_all.sum(axis=1, keepdims=True)
+    sum2 = -0.5 * bd_sum
+    model = sum1 + sum2
+
+    y = mx.symbol.Variable("softmax_label")
+    model = mx.symbol.LogisticRegressionOutput(data=model, label=y)
+    return model
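
As a cross-check on the symbol graph above, here is a minimal dense NumPy sketch of the
same computation (the shapes and random data are made up for illustration; the real
example keeps x sparse and v/w row_sparse):

    import numpy as np

    batch, num_features, factor_size = 4, 10, 3
    rng = np.random.RandomState(0)
    x = rng.binomial(1, 0.3, size=(batch, num_features)).astype(np.float64)
    w0 = rng.randn(1)
    w = rng.randn(num_features, 1)
    v = rng.randn(num_features, factor_size)

    w1 = x.dot(w) + w0                                   # bias + linear terms
    w2_squared = 0.5 * np.square(x.dot(v))               # 0.5 * (x v)_f^2, one column per factor
    bd_sum = np.square(x).dot(np.square(v).sum(axis=1, keepdims=True))
    score = w1 + w2_squared.sum(axis=1, keepdims=True) - 0.5 * bd_sum
    prob = 1.0 / (1.0 + np.exp(-score))                  # what LogisticRegressionOutput predicts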
diff --git a/example/sparse/factorization_machine/train.py b/example/sparse/factorization_machine/train.py
new file mode 100644
index 0000000000..741cf958db
--- /dev/null
+++ b/example/sparse/factorization_machine/train.py
@@ -0,0 +1,143 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import mxnet as mx
+from metric import *
+from mxnet.test_utils import *
+from model import *
+import argparse, os
+
+parser = argparse.ArgumentParser(description="Run factorization machine with criteo dataset",
+                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+parser.add_argument('--data-train', type=str, default=None,
+                    help='training dataset in LibSVM format.')
+parser.add_argument('--data-test', type=str, default=None,
+                    help='test dataset in LibSVM format.')
+parser.add_argument('--num-epoch', type=int, default=1,
+                    help='number of epochs to train')
+parser.add_argument('--batch-size', type=int, default=1000,
+                    help='number of examples per batch')
+parser.add_argument('--input-size', type=int, default=1000000,
+                    help='number of features in the input')
+parser.add_argument('--factor-size', type=int, default=16,
+                    help='number of latent variables')
+parser.add_argument('--factor-lr', type=float, default=0.0001,
+                    help='learning rate for factor terms')
+parser.add_argument('--linear-lr', type=float, default=0.001,
+                    help='learning rate for linear terms')
+parser.add_argument('--bias-lr', type=float, default=0.1,
+                    help='learning rate for bias terms')
+parser.add_argument('--factor-wd', type=float, default=0.00001,
+                    help='weight decay rate for factor terms')
+parser.add_argument('--linear-wd', type=float, default=0.001,
+                    help='weight decay rate for linear terms')
+parser.add_argument('--bias-wd', type=float, default=0.01,
+                    help='weight decay rate for bias terms')
+parser.add_argument('--factor-sigma', type=float, default=0.001,
+                    help='standard deviation for initialization of factor terms')
+parser.add_argument('--linear-sigma', type=float, default=0.01,
+                    help='standard deviation for initialization of linear terms')
+parser.add_argument('--bias-sigma', type=float, default=0.01,
+                    help='standard deviation for initialization of bias terms')
+parser.add_argument('--log-interval', type=int, default=100,
+                    help='number of batches between logging messages')
+parser.add_argument('--kvstore', type=str, default='local',
+                    help='what kvstore to use', choices=["dist_async", "local"])
+
+if __name__ == '__main__':
+    import logging
+    head = '%(asctime)-15s %(message)s'
+    logging.basicConfig(level=logging.INFO, format=head)
+
+    # arg parser
+    args = parser.parse_args()
+    logging.info(args)
+    num_epoch = args.num_epoch
+    batch_size = args.batch_size
+    kvstore = args.kvstore
+    factor_size = args.factor_size
+    num_features = args.input_size
+    log_interval = args.log_interval
+    assert(args.data_train is not None and args.data_test is not None), \
+          "dataset for training or test is missing"
+
+    # create kvstore
+    kv = mx.kvstore.create(kvstore)
+    # data iterator
+    train_data = mx.io.LibSVMIter(data_libsvm=args.data_train, data_shape=(num_features,),
+                                  batch_size=batch_size)
+    eval_data = mx.io.LibSVMIter(data_libsvm=args.data_test, data_shape=(num_features,),
+                                 batch_size=batch_size)
+    # model
+    lr_config = {'v': args.factor_lr, 'w': args.linear_lr, 'w0': args.bias_lr}
+    wd_config = {'v': args.factor_wd, 'w': args.linear_wd, 'w0': args.bias_wd}
+    init_config = {'v': mx.initializer.Normal(args.factor_sigma),
+                   'w': mx.initializer.Normal(args.linear_sigma),
+                   'w0': mx.initializer.Normal(args.bias_sigma)}
+    model = factorization_machine_model(factor_size, num_features, lr_config, wd_config, init_config)
+
+    # module
+    mod = mx.mod.Module(symbol=model)
+    mod.bind(data_shapes=train_data.provide_data, label_shapes=train_data.provide_label)
+    mod.init_params()
+    optimizer_params=(('learning_rate', 1), ('wd', 1), ('beta1', 0.9),
+                      ('beta2', 0.999), ('epsilon', 1e-8))
+    mod.init_optimizer(optimizer='adam', kvstore=kv, optimizer_params=optimizer_params)
+
+    # metrics
+    metric = mx.metric.create(['log_loss'])
+    speedometer = mx.callback.Speedometer(batch_size, log_interval)
+
+    # get handles to the sparse weight parameters 'w' and 'v' from the executor group
+    w_index = mod._exec_group.param_names.index('w')
+    w_param = mod._exec_group.param_arrays[w_index]
+    v_index = mod._exec_group.param_names.index('v')
+    v_param = mod._exec_group.param_arrays[v_index]
+
+    logging.info('Training started ...')
+    train_iter = iter(train_data)
+    eval_iter = iter(eval_data)
+    for epoch in range(num_epoch):
+        nbatch = 0
+        metric.reset()
+        for batch in train_iter:
+            nbatch += 1
+            # manually pull sparse weights from kvstore so that _square_sum
+            # only computes the rows necessary
+            row_ids = batch.data[0].indices
+            kv.row_sparse_pull('w', w_param, row_ids=[row_ids], priority=-w_index)
+            kv.row_sparse_pull('v', v_param, row_ids=[row_ids], priority=-v_index)
+            mod.forward_backward(batch)
+            # update all parameters (including the weight parameter)
+            mod.update()
+            # update training metric
+            mod.update_metric(metric, batch.label)
+            speedometer_param = mx.model.BatchEndParam(epoch=epoch, nbatch=nbatch,
+                                                       eval_metric=metric, locals=locals())
+            speedometer(speedometer_param)
+
+        # pull all rows before validation, not just those touched by the last batch
+        all_row_ids = mx.nd.arange(0, num_features, dtype='int64')
+        kv.row_sparse_pull('w', w_param, row_ids=[all_row_ids], priority=-w_index)
+        kv.row_sparse_pull('v', v_param, row_ids=[all_row_ids], priority=-v_index)
+        # evaluate metric on validation dataset
+        score = mod.score(eval_iter, ['log_loss'])
+        logging.info("epoch %d, eval log loss = %s" % (epoch, score[0][1]))
+        # reset the iterator for next pass of data
+        train_iter.reset()
+        eval_iter.reset()
+    logging.info('Training completed.')
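
The manual kv.row_sparse_pull calls above are what keep the per-batch work proportional to
the feature rows that actually appear in each batch. A minimal standalone sketch of that
call, using a local kvstore and made-up shapes (not part of the example itself):

    import mxnet as mx

    kv = mx.kv.create('local')
    weight = mx.nd.ones((5, 2)).tostype('row_sparse')
    kv.init('w', weight)
    # pull only rows 0 and 2 of the weight into a row_sparse output
    out = mx.nd.sparse.zeros('row_sparse', (5, 2))
    kv.row_sparse_pull('w', out=out, row_ids=mx.nd.array([0, 2], dtype='int64'))
    print(out.asnumpy())   # rows 0 and 2 are populated, the rest stay zero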
diff --git a/example/sparse/linear_classification/README.md b/example/sparse/linear_classification/README.md
new file mode 100644
index 0000000000..7e2a7ad37f
--- /dev/null
+++ b/example/sparse/linear_classification/README.md
@@ -0,0 +1,17 @@
+Linear Classification Using Sparse Matrix Multiplication
+===========
+This example trains a linear model using MXNet's sparse functionality. It is for demonstration purposes only.
+
+The example utilizes the sparse data loader ([mx.io.LibSVMIter](https://mxnet.incubator.apache.org/versions/master/api/python/io.html#mxnet.io.LibSVMIter)),
+the sparse dot operator and [sparse gradient updaters](https://mxnet.incubator.apache.org/versions/master/api/python/ndarray/sparse.html#updater)
+to train a linear model on the
+[Avazu](https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary.html#avazu) click-through rate prediction dataset.
+
+The example also shows how to perform distributed training using the sparse functionality.
+
+- `python train.py`
+
+Notes on Distributed Training:
+
+- For distributed training, please use the `../../../tools/launch.py` script to launch a cluster.
+- For example, to run two workers and two servers on one machine, run `../../../tools/launch.py -n 2 --launcher=local python train.py --kvstore=dist_async`
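
The LibSVMIter mentioned above reads LibSVM-format text straight into CSR batches. A tiny
self-contained sketch (the file name and values below are made up; the feature indices are
written zero-based, which is what MXNet's LibSVMIter expects):

    import mxnet as mx

    with open('toy.libsvm', 'w') as f:
        f.write('1 0:0.5 3:1.2\n')
        f.write('0 1:2.0 4:0.3\n')

    data_iter = mx.io.LibSVMIter(data_libsvm='toy.libsvm', data_shape=(5,), batch_size=2)
    for batch in data_iter:
        csr = batch.data[0]                      # a CSRNDArray
        print(csr.stype, csr.indices.asnumpy())  # storage type and the column ids in this batch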
diff --git a/example/sparse/linear_classification/data.py b/example/sparse/linear_classification/data.py
new file mode 100644
index 0000000000..02984734fb
--- /dev/null
+++ b/example/sparse/linear_classification/data.py
@@ -0,0 +1,33 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import os, gzip
+import sys
+import mxnet as mx
+
+def get_avazu_data(data_dir, data_name, url):
+    if not os.path.isdir(data_dir):
+        os.mkdir(data_dir)
+    os.chdir(data_dir)
+    if (not os.path.exists(data_name)):
+        print("Dataset " + data_name + " not present. Downloading now ...")
+        import urllib
+        zippath = os.path.join(data_dir, data_name + ".bz2")
+        urllib.urlretrieve(url + data_name + ".bz2", zippath)
+        os.system("bzip2 -d %r" % data_name + ".bz2")
+        print("Dataset " + data_name + " is now present.")
+    os.chdir("..")
diff --git a/example/sparse/linear_model.py b/example/sparse/linear_classification/linear_model.py
similarity index 100%
rename from example/sparse/linear_model.py
rename to example/sparse/linear_classification/linear_model.py
diff --git a/example/sparse/linear_classification.py b/example/sparse/linear_classification/train.py
similarity index 95%
rename from example/sparse/linear_classification.py
rename to example/sparse/linear_classification/train.py
index 1d63c55b11..eb7871bbdb 100644
--- a/example/sparse/linear_classification.py
+++ b/example/sparse/linear_classification/train.py
@@ -17,7 +17,7 @@
 
 import mxnet as mx
 from mxnet.test_utils import *
-from get_data import get_libsvm_data
+from data import get_avazu_data
 from linear_model import *
 import argparse
 import os
@@ -67,8 +67,8 @@
     data_dir = os.path.join(os.getcwd(), 'data')
     train_data = os.path.join(data_dir, AVAZU['train'])
     val_data = os.path.join(data_dir, AVAZU['test'])
-    get_libsvm_data(data_dir, AVAZU['train'], AVAZU['url'])
-    get_libsvm_data(data_dir, AVAZU['test'], AVAZU['url'])
+    get_avazu_data(data_dir, AVAZU['train'], AVAZU['url'])
+    get_avazu_data(data_dir, AVAZU['test'], AVAZU['url'])
 
     # data iterator
     train_data = mx.io.LibSVMIter(data_libsvm=train_data, data_shape=(num_features,),
@@ -100,11 +100,10 @@
     speedometer = mx.callback.Speedometer(batch_size, 100)
 
     logging.info('Training started ...')
-    data_iter = iter(train_data)
     for epoch in range(num_epoch):
         nbatch = 0
         metric.reset()
-        for batch in data_iter:
+        for batch in train_data:
             nbatch += 1
             # for distributed training, we need to manually pull sparse weights from kvstore
             if kv:
@@ -129,5 +128,6 @@
         save_optimizer_states = 'dist' not in kv.type if kv else True
         mod.save_checkpoint("checkpoint", epoch, save_optimizer_states=save_optimizer_states)
         # reset the iterator for next pass of data
-        data_iter.reset()
+        train_data.reset()
+        eval_data.reset()
     logging.info('Training completed.')
diff --git a/example/sparse/weighted_softmax_ce.py b/example/sparse/linear_classification/weighted_softmax_ce.py
similarity index 100%
rename from example/sparse/weighted_softmax_ce.py
rename to example/sparse/linear_classification/weighted_softmax_ce.py
diff --git a/example/sparse/matrix_factorization/README.md b/example/sparse/matrix_factorization/README.md
new file mode 100644
index 0000000000..3ada5e8015
--- /dev/null
+++ b/example/sparse/matrix_factorization/README.md
@@ -0,0 +1,8 @@
+Matrix Factorization w/ Sparse Embedding
+===========
+The example demonstrates the basic usage of the SparseEmbedding operator in MXNet, adapted from @leopd's recommender examples.
+The operator is available on both CPU and GPU. This is for demonstration purposes only.
+
+- `python train.py`
+- To compare the training speed with (dense) Embedding, run `python train.py --use-dense`
+- To run the example on the GPU, run `python train.py --use-gpu`
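
The reason SparseEmbedding helps is that its weight and gradient use the row_sparse storage
type, so an update only touches the rows indexed in the batch. A small illustration of the
row_sparse format using only the NDArray API (the shapes and indices are made up):

    import mxnet as mx

    # a 1000 x 16 "embedding table" with only 3 non-zero rows materialized
    rsp = mx.nd.sparse.row_sparse_array(
        (mx.nd.ones((3, 16)), mx.nd.array([1, 40, 999], dtype='int64')),
        shape=(1000, 16))
    print(rsp.stype)              # 'row_sparse'
    print(rsp.indices.asnumpy())  # [  1  40 999]
    print(rsp.data.shape)         # (3, 16) -- only the stored rows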
diff --git a/example/sparse/get_data.py b/example/sparse/matrix_factorization/data.py
similarity index 68%
rename from example/sparse/get_data.py
rename to example/sparse/matrix_factorization/data.py
index 19c635fe33..fae2c237c8 100644
--- a/example/sparse/get_data.py
+++ b/example/sparse/matrix_factorization/data.py
@@ -18,38 +18,7 @@
 import os, gzip
 import sys
 import mxnet as mx
-
-class DummyIter(mx.io.DataIter):
-    "A dummy iterator that always return the same batch, used for speed testing"
-    def __init__(self, real_iter):
-        super(DummyIter, self).__init__()
-        self.real_iter = real_iter
-        self.provide_data = real_iter.provide_data
-        self.provide_label = real_iter.provide_label
-        self.batch_size = real_iter.batch_size
-
-        for batch in real_iter:
-            self.the_batch = batch
-            break
-
-    def __iter__(self):
-        return self
-
-    def next(self):
-        return self.the_batch
-
-def get_libsvm_data(data_dir, data_name, url):
-    if not os.path.isdir(data_dir):
-        os.mkdir(data_dir)
-    os.chdir(data_dir)
-    if (not os.path.exists(data_name)):
-        print("Dataset " + data_name + " not present. Downloading now ...")
-        import urllib
-        zippath = os.path.join(data_dir, data_name + ".bz2")
-        urllib.urlretrieve(url + data_name + ".bz2", zippath)
-        os.system("bzip2 -d %r" % data_name + ".bz2")
-        print("Dataset " + data_name + " is now present.")
-    os.chdir("..")
+from mxnet.test_utils import DummyIter
 
 def get_movielens_data(prefix):
     if not os.path.exists("%s.zip" % prefix):
diff --git a/example/sparse/matrix_fact_model.py b/example/sparse/matrix_factorization/model.py
similarity index 100%
rename from example/sparse/matrix_fact_model.py
rename to example/sparse/matrix_factorization/model.py
diff --git a/example/sparse/matrix_factorization.py b/example/sparse/matrix_factorization/train.py
similarity index 97%
rename from example/sparse/matrix_factorization.py
rename to example/sparse/matrix_factorization/train.py
index 3387706665..14c6ca188f 100644
--- a/example/sparse/matrix_factorization.py
+++ b/example/sparse/matrix_factorization/train.py
@@ -20,9 +20,8 @@
 import time
 import mxnet as mx
 import numpy as np
-from get_data import get_movielens_iter, get_movielens_data
-from matrix_fact_model import matrix_fact_net
-
+from data import get_movielens_iter, get_movielens_data
+from model import matrix_fact_net
 
 logging.basicConfig(level=logging.DEBUG)
 
diff --git a/example/sparse/readme.md b/example/sparse/readme.md
deleted file mode 100644
index e443bfa2d5..0000000000
--- a/example/sparse/readme.md
+++ /dev/null
@@ -1,21 +0,0 @@
-Example
-===========
-This folder contains examples using the sparse feature in MXNet. They are for demonstration purpose only.
-
-## Linear Classification Using Sparse Matrix Multiplication
-
-The example demonstrates the basic usage of the sparse feature in MXNet to speedup computation. It utilizes the sparse data loader, sparse operators and a sparse gradient updater to train a linear model on the [Avazu](https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary.html#avazu) click-through-prediction dataset.
-
-- `python linear_classification.py`
-
-Notes on Distributed Training:
-
-- For distributed training, please use the `../../tools/launch.py` script to launch a cluster.
-- For example, to run two workers and two servers with one machine, run `../../tools/launch.py -n 2 --launcher=local python linear_classification.py --kvstore=dist_async`
-
-## Matrix Factorization Using Sparse Embedding
-
-The example demonstrates the basic usage of the SparseEmbedding operator in MXNet, adapted based on @leopd's recommender examples.
-
-- `python matrix_factorization.py`
-- To compare the train speed with (dense) Embedding, run `python matrix_factorization.py --use-dense`
diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py
index 3e667364cd..0dfeec56c1 100644
--- a/python/mxnet/test_utils.py
+++ b/python/mxnet/test_utils.py
@@ -1538,3 +1538,34 @@ def discard_stderr():
     finally:
         os.dup2(old_stderr, stderr_fileno)
         bit_bucket.close()
+
+class DummyIter(mx.io.DataIter):
+    """A dummy iterator that always returns the same batch of data
+    (the first data batch of the real data iter). This is usually used for speed testing.
+
+    Parameters
+    ----------
+    real_iter: mx.io.DataIter
+        The real data iterator from which the first batch of data is taken
+    """
+    def __init__(self, real_iter):
+        super(DummyIter, self).__init__()
+        self.real_iter = real_iter
+        self.provide_data = real_iter.provide_data
+        self.provide_label = real_iter.provide_label
+        self.batch_size = real_iter.batch_size
+        self.the_batch = next(real_iter)
+
+    def __iter__(self):
+        return self
+
+    def next(self):
+        """Get a data batch from iterator. The first data batch of real iter is always returned.
+        StopIteration will never be raised.
+
+        Returns
+        -------
+        DataBatch
+            The data of the next batch.
+        """
+        return self.the_batch
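
A typical use of DummyIter is to wrap a real iterator when benchmarking, so that data
loading and conversion are taken out of the measurement. A small sketch with made-up
shapes (not part of this change):

    import numpy as np
    import mxnet as mx
    from mxnet.test_utils import DummyIter

    real_iter = mx.io.NDArrayIter(data=np.random.rand(1000, 50),
                                  label=np.random.randint(0, 2, size=(1000,)),
                                  batch_size=100)
    fast_iter = DummyIter(real_iter)
    # every iteration yields the same cached batch, so this loop measures compute only
    for i, batch in enumerate(fast_iter):
        if i == 10:
            break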


 
