Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2017/11/15 18:33:13 UTC

[GitHub] GSanchis opened a new issue #8669: module.Module and CSVIter

URL: https://github.com/apache/incubator-mxnet/issues/8669
 
 
   ## Description
   
   I'm trying to use `module.Module` to build a recommender system with MXNet under Python. Using NDArrayIter makes memory requirements skyrocket, but I have been unable to use CSVIter for the same purpose. I followed [this](https://www.oreilly.com/ideas/deep-matrix-factorization-using-apache-mxnet) tutorial to build the first recommender, but I can't get any further.
   
   ## Minimum reproducible example
   
   ```
   import mxnet as mx
   import numpy
   import time
   
   file="f1.csv"
   l,c,v = numpy.loadtxt(file, delimiter=',',dtype='int').T
   # Loading the whole file should not be needed, but I need l.max(), which I don't seem to be able to get from CSVIter
   
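   user = mx.symbol.Variable("user")
   # input_dim is the number of distinct ids, so it must be max id + 1 (ids are assumed to start at 0)
   user = mx.symbol.Embedding(data=user, input_dim=l.max() + 1, output_dim=10)
   movie = mx.symbol.Variable("movie")
   movie = mx.symbol.Embedding(data=movie, input_dim=c.max() + 1, output_dim=10)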
   y_true = mx.symbol.Variable("softmax_label")
   nn = mx.symbol.concat(user, movie)
   nn = mx.symbol.flatten(nn)
   nn = mx.symbol.FullyConnected(data=nn, num_hidden=64)
   nn = mx.symbol.Activation(data=nn, act_type='relu')
   nn = mx.symbol.FullyConnected(data=nn, num_hidden=1)
   y_pred = mx.symbol.LinearRegressionOutput(data=nn, label=y_true)
   
   tritems=int(len(l)*80/100)
   X_train=mx.io.NDArrayIter({'user': l[:tritems], 'movie': c[:tritems]}, label=v[:tritems], batch_size=10000)
   X_eval=mx.io.NDArrayIter({'user': l[tritems:], 'movie': c[tritems:]}, label=v[tritems:], batch_size=10000)
   X_all=mx.io.NDArrayIter({'user': l, 'movie': c}, label=v, batch_size=10000)
   model = mx.module.Module(context=mx.cpu(0), data_names=['user', 'movie'], symbol=y_pred)
   model.fit(X_train, num_epoch=5, optimizer='adam', optimizer_params=(('learning_rate', 0.001),), eval_metric='mse', eval_data=X_eval)
   ```
   
   This works great!
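
The only reason the example above loads the whole file with `numpy.loadtxt` is to obtain the largest user and movie ids for `input_dim`. As a minimal sketch, assuming the file is a plain comma-separated file of integers with the user id in the first column and the movie id in the second, those maxima could be gathered in a single streaming pass with the standard library. The helper name `column_maxima` is hypothetical and is not part of MXNet.

```python
import csv
import io


def column_maxima(lines, num_columns=2):
    """Return the largest integer seen in each of the first num_columns columns.

    `lines` is any iterable of text lines (an open file works), so only one
    row is held in memory at a time.
    """
    maxima = [None] * num_columns
    for row in csv.reader(lines):
        if not row:
            continue  # skip blank lines
        for i in range(num_columns):
            value = int(row[i])
            if maxima[i] is None or value > maxima[i]:
                maxima[i] = value
    return maxima


# Tiny in-memory stand-in for f1.csv: user id, movie id, rating.
sample = io.StringIO("0,3,5\n2,1,4\n1,7,3\n")
max_user, max_movie = column_maxima(sample)
```

With a real file this would be called as `with open("f1.csv", newline="") as fh: column_maxima(fh)`. Whether the result should be used directly or incremented by one for `input_dim` depends on whether the ids start at 0.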
   
   Then I tried replacing the mx.io.NDArrayIter with a CSVIter:
   ```
   f1="tmp1"    # contains first two columns of f1.csv
   f2="tmp2"    # contains the third column of f1.csv
   CSVIter = mx.io.CSVIter(data_csv=f1, data_shape=(2,), label_csv=f2, label_shape=(1,), batch_size=1000)
   model = mx.module.Module(context=mx.cpu(0), data_names=['user', 'movie'], symbol=y_pred)
   model.fit(CSVIter, num_epoch=5, optimizer='adam', optimizer_params=(('learning_rate', 0.001),))
   ```
   This fails with the following error:
   ```
   Traceback (most recent call last):
     File "<stdin>", line 1, in <module>
     File "/home/german/anaconda3/lib/python3.6/site-packages/mxnet/module/base_module.py", line 460, in fit
       for_training=True, force_rebind=force_rebind)
     File "/home/german/anaconda3/lib/python3.6/site-packages/mxnet/module/module.py", line 400, in bind
       self.data_names, self.label_names, data_shapes, label_shapes)
     File "/home/german/anaconda3/lib/python3.6/site-packages/mxnet/module/base_module.py", line 71, in _parse_data_desc
       _check_names_match(data_names, data_shapes, 'data', True)
     File "/home/german/anaconda3/lib/python3.6/site-packages/mxnet/module/base_module.py", line 63, in _check_names_match
       raise ValueError(msg)
   ValueError: Data provided by data_shapes don't match names specified by data_names ([DataDesc[data,(1000, 2),<class 'numpy.float32'>,NCHW]] vs. ['user', 'movie'])
   ```
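
Reading the message, the CSVIter exposes a single input called `data` with shape `(1000, 2)`, while the module was told to expect two inputs called `user` and `movie`. The sketch below reproduces that kind of name comparison in plain Python. It is a simplified stand-in written for illustration, not MXNet's actual `_check_names_match`, and the function name and message text are assumptions.

```python
def check_names_match(expected_names, provided_descs):
    """Raise ValueError when the iterator's input names differ from the module's.

    `provided_descs` is a list of (name, shape) pairs, standing in for the
    data descriptions an iterator reports.
    """
    provided_names = [name for name, _shape in provided_descs]
    if sorted(expected_names) != sorted(provided_names):
        raise ValueError(
            "provided names %s do not match expected names %s"
            % (provided_names, expected_names)
        )


# What the NDArrayIter built from {'user': ..., 'movie': ...} reports.
ndarray_descs = [("user", (10000,)), ("movie", (10000,))]
# What the CSVIter reports according to the error message.
csv_descs = [("data", (1000, 2))]

check_names_match(["user", "movie"], ndarray_descs)  # passes silently

try:
    check_names_match(["user", "movie"], csv_descs)
    mismatch = None
except ValueError as exc:
    mismatch = str(exc)
```

If this reading is right, the symbol and the iterator have to agree on input names. One direction that might be worth exploring is declaring a single `data` variable in the symbol and slicing its two columns apart (for example with `mx.symbol.slice_axis`), but that has not been verified here.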
   
   ## What have you tried to solve it?
   
   I have tried reassigning the `CSVIter.provide_label` and `CSVIter.provide_data` fields, adding `name='data'` to every line that starts with `nn =`, and removing `data_names` from the `mx.module.Module` line, all with no luck. I have spent around 6 hours googling and reading the mx.io API docs, and I'm currently out of ideas.
   
   Any help is welcome!!
