Posted to issues@mxnet.apache.org by GitBox <gi...@apache.org> on 2021/02/24 06:50:15 UTC

[GitHub] [incubator-mxnet] MrChengmo opened a new issue #19949: DistributeTraining throw "dmlc::Error" when using nn.Embedding(sparse_grad=True)

MrChengmo opened a new issue #19949:
URL: https://github.com/apache/incubator-mxnet/issues/19949


   ## Description
   Hi~ I am trying to train a recommendation model on the Criteo dataset, using a sparse embedding followed by a DNN with three hidden layers.
   
   Here is my network:
   ```python
   import math
   
   import mxnet as mx
   from mxnet.gluon import nn
   
   
   class CtrDnn(nn.HybridBlock):
       def __init__(self, sparse_feature_number, sparse_feature_dim,
                    dense_feature_dim, num_field, layer_sizes, **kwargs):
           super(CtrDnn, self).__init__(**kwargs)
           self.sparse_feature_number = sparse_feature_number
           self.sparse_feature_dim = sparse_feature_dim
   
           sizes = [sparse_feature_dim * num_field +
                    dense_feature_dim] + layer_sizes
   
           self.embedding = nn.Embedding(
               sparse_feature_number, sparse_feature_dim, sparse_grad=True)
   
           self.dense1 = nn.Dense(in_units=sizes[0],
                                  units=sizes[1],
                                  activation='relu',
                                  weight_initializer=mx.init.Normal(1.0 / math.sqrt(sizes[1])))
   
           self.dense2 = nn.Dense(in_units=sizes[1],
                                  units=sizes[2],
                                  activation='relu',
                                  weight_initializer=mx.init.Normal(1.0 / math.sqrt(sizes[2])))
   
           self.dense3 = nn.Dense(in_units=sizes[2],
                                  units=sizes[3],
                                  activation='relu',
                                  weight_initializer=mx.init.Normal(1.0 / math.sqrt(sizes[3])))
   
           self.dense4 = nn.Dense(in_units=layer_sizes[-1],
                                  units=2,
                                  weight_initializer=mx.init.Normal(1.0 / math.sqrt(sizes[-1])))
   
       def hybrid_forward(self, F, sparse_inputs, dense_inputs):
           sparse_embs = []
           for s_input in sparse_inputs:
               emb = self.embedding(s_input)
               sparse_embs.append(emb)
   
           # Flatten each field's embedding to (batch_size, sparse_feature_dim)
           sparse_embs = [F.reshape(emb, (-1, self.sparse_feature_dim))
                          for emb in sparse_embs]
   
           # Concatenate all num_field embeddings with the dense features
           dnn_input = F.concat(*sparse_embs, dense_inputs, dim=1)
           layer1 = self.dense1(dnn_input)
           layer2 = self.dense2(layer1)
           layer3 = self.dense3(layer2)
           dnn_output = self.dense4(layer3)
   
           return dnn_output
   ```
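As a quick sanity check of the shapes in `CtrDnn`, the sketch below recomputes the `sizes` list from the hyperparameters used in `train.py` (26 sparse fields of width 10 plus 13 dense features) and compares the dense-layer parameter count with the size of the embedding table. The helper names `layer_sizes_for` and `dense_param_count` are hypothetical; they only mirror the arithmetic in `__init__`.

```python
def layer_sizes_for(sparse_feature_dim, num_field, dense_feature_dim, layer_sizes):
    # Same arithmetic as `sizes` in CtrDnn.__init__: the concatenated field
    # embeddings plus the raw dense features form the DNN input width
    input_dim = sparse_feature_dim * num_field + dense_feature_dim
    return [input_dim] + list(layer_sizes)


def dense_param_count(sizes, num_classes=2):
    # Weights + biases for dense1..dense3 and the 2-way output layer dense4
    widths = list(sizes) + [num_classes]
    return sum(w_in * w_out + w_out for w_in, w_out in zip(widths, widths[1:]))


sizes = layer_sizes_for(10, 26, 13, [400, 400, 400])
embedding_params = 1000001 * 10  # sparse_feature_number * sparse_feature_dim

print(sizes)                     # [273, 400, 400, 400]
print(dense_param_count(sizes))  # 431202
print(embedding_params)          # 10000010
```

The embedding table is roughly 23 times larger than all dense layers combined, which is the usual motivation for `sparse_grad=True` on the embedding.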
   
   It works well on a single machine, but when I try distributed training with `kv.create('dist_async')`, it throws an error:
   
   ### Error Message
   
   ```bash
   [06:21:15] src/van.cc:310: Bind to role=worker, ip=192.168.1.2, port=35008, is_recovery=0
   [06:21:15] src/van.cc:257: W[9] is connected to others
   2021-02-24 06:21:15,768 - INFO - File list: ['./train_data/part-0']
   2021-02-24 06:21:15,775 - INFO - File: ./train_data/part-0 has 20000 examples
   2021-02-24 06:21:15,775 - INFO - Total example: 20000
   2021-02-24 06:21:16,346 - INFO - Load Data in memory finish, using time: 0.5777103900909424
   2021-02-24 06:21:16,347 - INFO - Epoch 0 training begin
   [06:21:16] src/operator/tensor/./.././../common/utils.h:473:
   Storage fallback detected:
   Copy from row_sparse storage type on cpu to default storage type on cpu.
   A temporary ndarray with default storage type will be generated in order to perform the copy. This does not affect the correctness of the programme. You can set environment variable MXNET_STORAGE_FALLBACK_LOG_VERBOSE to 0 to suppress this warning.
   terminate called after throwing an instance of 'dmlc::Error'
     what():  [06:21:16] /home/centos/mxnet/3rdparty/ps-lite/include/ps/kv_app.h:697: Check failed: lens->size() == keys.size() (7626 vs. 2) :
   ```
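For context on the `row_sparse` storage mentioned in the fallback warning: with `sparse_grad=True`, the embedding gradient only has non-zero rows for the ids that actually appeared in the batch, so it can be stored as (row indices, row values). The pure-Python sketch below models that layout and the densifying copy the warning describes. It is an illustration under those assumptions, not MXNet's implementation, and both function names are hypothetical.

```python
def embedding_row_sparse_grad(lookup_ids, out_grads):
    """Accumulate per-lookup output gradients into a row_sparse-like gradient.

    lookup_ids: row ids looked up in the embedding table
    out_grads: one gradient vector per lookup, in the same order
    Returns (sorted unique row indices, accumulated row vectors).
    """
    rows = {}
    for idx, g in zip(lookup_ids, out_grads):
        acc = rows.setdefault(idx, [0.0] * len(g))
        for j, v in enumerate(g):
            acc[j] += v  # repeated ids sum their gradients
    indices = sorted(rows)
    return indices, [rows[i] for i in indices]


def to_dense(indices, data, num_rows, dim):
    """Model of the row_sparse -> default storage copy: zero-fill, then scatter rows."""
    dense = [[0.0] * dim for _ in range(num_rows)]
    for i, row in zip(indices, data):
        dense[i] = list(row)
    return dense


indices, data = embedding_row_sparse_grad([3, 1, 3], [[1.0, 2.0], [0.5, 0.5], [1.0, 1.0]])
print(indices, data)  # [1, 3] [[0.5, 0.5], [2.0, 3.0]]
print(to_dense(indices, data, num_rows=5, dim=2))
```

For the real table (1000001 rows of width 10), a dense copy holds 10,000,010 values, while one batch of 1000 examples over 26 fields touches at most 26,000 distinct rows.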
   
   ## To Reproduce
   
   Here is my `train.py`:
   
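   ```python
   import logging
   import os
   import time
   
   import mxnet as mx
   import numpy as np
   from mxnet import autograd, gluon, kv, nd
   from sklearn import metrics
   
   # CtrDnn is the network above; Reader (the Criteo data loader) is in the linked repo
   logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO)
   logger = logging.getLogger(__name__)
   
   
   class Train(object):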
   
       def run(self):
           # hyper parameters
           epochs = 1
           batch_size = 1000
           sparse_feature_number = 1000001
           sparse_feature_dim = 10
           dense_feature_dim = 13
           num_field = 26
           layer_sizes = [400, 400, 400]
           train_data_path = "./train_data"
           print_step = 5
           distributed_train = True
           cpu_num = int(os.getenv("CPU_NUM", 1))
   
           # create network
           ctx = mx.cpu()
           net = CtrDnn(sparse_feature_number, sparse_feature_dim,
                        dense_feature_dim, num_field, layer_sizes)
           net.initialize(ctx=ctx)
           # net.hybridize()
   
           self.loss = gluon.loss.SoftmaxCrossEntropyLoss()
   
           if distributed_train:
               self.store = kv.create('dist_async')
           else:
               self.store = kv.create('local')
   
           # Load the training data
           reader_start_time = time.time()
   
           file_list = self.get_file_list(train_data_path, distributed_train)
           reader = Reader()
           dataset = reader.load_criteo_dataset(file_list)
           train_data = gluon.data.DataLoader(
               dataset, batch_size, num_workers=cpu_num, last_batch="discard")
           reader_end_time = time.time()
           logger.info("Load Data in memory finish, using time: {}".format(
               reader_end_time - reader_start_time))
   
           if distributed_train:
               trainer = gluon.Trainer(net.collect_params(), 'adam', {
                   'learning_rate': 0.0001, 'lazy_update': True}, kvstore=self.store, update_on_kvstore=True)
           else:
               trainer = gluon.Trainer(net.collect_params(), 'adam', {
                                       'learning_rate': 0.0001}, kvstore=self.store)
   
           for epoch in range(epochs):
               logger.info("Epoch {} training begin".format(epoch))
               epoch_start_time = time.time()
   
               batch_id = 1
               train_run_cost = 0.0
               total_examples = 0
               self.global_score = None
               self.global_label = None
   
               for batch in train_data:
                   train_start = time.time()
                   loss_value = self.train_batch(
                       batch, ctx, net, trainer)
   
                   train_run_cost += (time.time() - train_start)
                   total_examples += batch_size
   
                   batch_id += 1
                   if batch_id % print_step == 0:
                       metric_start = time.time()
                       fpr, tpr, _ = metrics.roc_curve(
                           self.global_label.asnumpy().ravel(), self.global_score.asnumpy().ravel())
                       auc_value = metrics.auc(fpr, tpr)
                       train_run_cost += (time.time() - metric_start)
   
                       metrics_string = "auc: {}, loss: {}".format(
                           auc_value, loss_value)
                       profiler_string = ""
                       profiler_string += "using_time: {} sec ".format(
                           train_run_cost)
                       profiler_string += "avg_batch_cost: {} sec, ".format(
                           format((train_run_cost) / print_step, '.5f'))
                       profiler_string += " ips: {} example/sec ".format(
                           format(total_examples / (train_run_cost), '.5f'))
                       logger.info("Epoch: {}, Batch: {}, {} {}".format(
                           epoch, batch_id, metrics_string, profiler_string))
                       train_run_cost = 0.0
                       total_examples = 0
   
               epoch_end_time = time.time()
               logger.info(
                   "Epoch: {}, using time {} second,".format(
                       epoch, epoch_end_time - epoch_start_time))
   
       def calc_auc(self, label, output):
           output_exp = output.exp()
           partition = output_exp.sum(axis=1, keepdims=True)
           score = output_exp / partition
           score = nd.slice_axis(score, axis=1, begin=1, end=2)
   
           if self.global_score is None:
               # for first time
               self.global_score = score
               self.global_label = label
           else:
               self.global_score = nd.concat(self.global_score, score, dim=0)
               self.global_label = nd.concat(self.global_label, label, dim=0)
   
       def forward_backward(self, network, label, sparse_input, dense_input):
           # Ask autograd to remember the forward pass
           with autograd.record():
               output = network(sparse_input, dense_input)
               losses = self.loss(output, label)
               self.calc_auc(label, output)
   
           losses.backward()
   
           return np.mean(losses.as_np_ndarray())
   
       def train_batch(self, batch_list, context, network, gluon_trainer):
           label = batch_list[0]
           # label = gluon.utils.split_and_load(label, context)
   
           sparse_input = batch_list[1:-1]
   
           dense_input = batch_list[-1]
   
           # Run the forward and backward pass
           loss = self.forward_backward(network, label, sparse_input, dense_input)
   
           # Update the parameters
           this_batch_size = batch_list[0].shape[0]
           gluon_trainer.step(this_batch_size)
   
           return loss
   
       def get_example_num(self, file_list):
           count = 0
           for f in file_list:
               last_count = count
               with open(f, 'r') as fh:
                   count += sum(1 for _ in fh)
               logger.info("File: %s has %s examples" % (f, count - last_count))
           logger.info("Total example: %s" % count)
           return count
   
       def get_file_list(self, data_path, split_file_list=False):
           assert os.path.exists(data_path)
           file_list = [data_path + "/%s" % x for x in os.listdir(data_path)]
           file_list.sort()
           if split_file_list:
               file_list = self.get_file_shard(file_list)
           logger.info("File list: {}".format(file_list))
           self.get_example_num(file_list)
           return file_list
   
       def get_file_shard(self, files):
           if not isinstance(files, list):
               raise TypeError("files should be a list of file need to be read.")
   
           trainer_id = self.store.rank
           trainers = self.store.num_workers
   
           remainder = len(files) % trainers
           blocksize = len(files) // trainers
   
           blocks = [blocksize] * trainers
           for i in range(remainder):
               blocks[i] += 1
   
           trainer_files = [[]] * trainers
           begin = 0
           for i in range(trainers):
               trainer_files[i] = files[begin:begin + blocks[i]]
               begin += blocks[i]
   
           return trainer_files[trainer_id]
   ```
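`calc_auc` computes the softmax by hand with `output.exp()`, which can overflow for large logits. A numerically stable variant subtracts the row maximum before exponentiating. The pure-Python sketch below shows the positive-class probability (column 1) that ends up in `metrics.roc_curve`; in the MXNet code itself, `nd.softmax(output, axis=1)` should give the same scores. The helper name `positive_class_prob` is hypothetical.

```python
import math


def positive_class_prob(logits):
    """Stable softmax over one row of 2-class logits; returns P(class 1)."""
    m = max(logits)  # shift by the max so exp() never sees huge positive arguments
    exps = [math.exp(z - m) for z in logits]
    return exps[1] / sum(exps)


print(positive_class_prob([0.0, 0.0]))        # 0.5
print(positive_class_prob([1000.0, 1000.0]))  # 0.5, although math.exp(1000.0) alone overflows
```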
   
   ### Steps to reproduce
   
   I uploaded the complete code here: https://github.com/MrChengmo/MxnetPS-Example/tree/main/dnn
   
   1. Train on a single machine: `python -u train.py`
   2. Simulate distributed training on a single machine: `bash local_cluster.sh`
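`local_cluster.sh` itself lives in the linked repository and is not reproduced here. As a rough sketch of what such a single-machine simulation usually involves: MXNet's parameter-server mode starts one scheduler, some servers and some workers, each configured through the ps-lite environment variables `DMLC_ROLE`, `DMLC_PS_ROOT_URI`, `DMLC_PS_ROOT_PORT`, `DMLC_NUM_SERVER` and `DMLC_NUM_WORKER`. As far as I know, every process can run the same training script, and MXNet decides what to do from `DMLC_ROLE`. The helper below only builds those per-process environments; the function name, address and port are hypothetical placeholders.

```python
import os


def role_envs(num_servers, num_workers, root_uri="127.0.0.1", root_port=9000):
    """Build one environment dict per process for a local ps-lite cluster.

    Hypothetical helper: root_uri and root_port defaults are placeholders.
    """
    common = {
        "DMLC_PS_ROOT_URI": root_uri,
        "DMLC_PS_ROOT_PORT": str(root_port),
        "DMLC_NUM_SERVER": str(num_servers),
        "DMLC_NUM_WORKER": str(num_workers),
    }
    roles = ["scheduler"] + ["server"] * num_servers + ["worker"] * num_workers
    envs = []
    for role in roles:
        env = dict(os.environ)  # inherit PATH, PYTHONPATH, etc.
        env.update(common)
        env["DMLC_ROLE"] = role
        envs.append(env)
    return envs


envs = role_envs(num_servers=1, num_workers=2)
print([e["DMLC_ROLE"] for e in envs])  # ['scheduler', 'server', 'worker', 'worker']
# Each env could then be passed to subprocess.Popen(["python", "-u", "train.py"], env=env)
```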
   
   ## What have you tried to solve it?
   
   1. It works well when I set `distributed_train=False`.
   2. It does not work when I use `net.hybridize()`.
   
   ## Environment
   
   I run my code in the `ufoym/deepo` Docker image (CPU build) with MXNet 1.7.0:
   ```bash
   docker pull ufoym/deepo:cpu
   ```


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: issues-unsubscribe@mxnet.apache.org
For additional commands, e-mail: issues-help@mxnet.apache.org


[GitHub] [incubator-mxnet] github-actions[bot] commented on issue #19949: DistributeTraining throw "dmlc::Error" when using nn.Embedding(sparse_grad=True)

Posted by GitBox <gi...@apache.org>.
github-actions[bot] commented on issue #19949:
URL: https://github.com/apache/incubator-mxnet/issues/19949#issuecomment-784836330


   Welcome to Apache MXNet (incubating)! We are on a mission to democratize AI, and we are glad that you are contributing to it by opening this issue.
   Please make sure to include all the relevant context, and one of the @apache/mxnet-committers will be here shortly.
   If you are interested in contributing to our project, let us know! Also, be sure to check out our guide on [contributing to MXNet](https://mxnet.apache.org/community/contribute) and our [development guides wiki](https://cwiki.apache.org/confluence/display/MXNET/Developments).




[GitHub] [incubator-mxnet] MrChengmo edited a comment on issue #19949: DistributeTraining throw "dmlc::Error" when using nn.Embedding(sparse_grad=True)

Posted by GitBox <gi...@apache.org>.
MrChengmo edited a comment on issue #19949:
URL: https://github.com/apache/incubator-mxnet/issues/19949#issuecomment-785559171


   > Do you observe the Storage fallback also on single machine? Or does it only occur in the distributed setting?
   
   - With `kv_store="local"` and `nn.Embedding(..., sparse_grad=True)`, training works well on a single machine (`distributed_train=False`), run with `python -u train.py`.
   - With `kv_store="dist_async"`, `nn.Embedding(..., sparse_grad=True)` and `distributed_train=True`, running `bash local_cluster.sh` to simulate distributed training throws the error.




[GitHub] [incubator-mxnet] leezu commented on issue #19949: DistributeTraining throw "dmlc::Error" when using nn.Embedding(sparse_grad=True)

Posted by GitBox <gi...@apache.org>.
leezu commented on issue #19949:
URL: https://github.com/apache/incubator-mxnet/issues/19949#issuecomment-785248093


   Do you observe the Storage fallback also on single machine? Or does it only occur in the distributed setting?

