You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/05/07 22:13:55 UTC

[GitHub] eric-haibin-lin closed pull request #10533: [MXNET-314] Support Integer Type parsing in CSVIter

eric-haibin-lin closed pull request #10533: [MXNET-314] Support Integer Type parsing in CSVIter
URL: https://github.com/apache/incubator-mxnet/pull/10533
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/3rdparty/dmlc-core b/3rdparty/dmlc-core
index e9446f5a53c..d26d9e7982b 160000
--- a/3rdparty/dmlc-core
+++ b/3rdparty/dmlc-core
@@ -1 +1 @@
-Subproject commit e9446f5a53cf5e61273deff7ce814093d2791766
+Subproject commit d26d9e7982b233d4aa105ae084fbecc500d254ff
diff --git a/src/io/iter_csv.cc b/src/io/iter_csv.cc
index a9e650b6387..ca3f042f45a 100644
--- a/src/io/iter_csv.cc
+++ b/src/io/iter_csv.cc
@@ -57,23 +57,54 @@ struct CSVIterParam : public dmlc::Parameter<CSVIterParam> {
   }
 };
 
-class CSVIter: public IIterator<DataInst> {
+class CSVIterBase: public IIterator<DataInst> {
  public:
-  CSVIter() {
+  CSVIterBase() {
     out_.data.resize(2);
   }
-  virtual ~CSVIter() {}
+  virtual ~CSVIterBase() {}
+  // Init/BeforeFirst/Next are implemented by the typed subclass CSVIterTyped<DType>
+  /*! \brief initialize the iterator from the given keyword arguments */
+  virtual void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) = 0;
+  /*! \brief reset the iterator */
+  virtual void BeforeFirst(void) = 0;
+  /*! \brief move to next item */
+  virtual bool Next(void) = 0;
+  /*! \brief get current data */
+  virtual const DataInst &Value(void) const {
+    return out_;
+  }
+
+ protected:
+  CSVIterParam param_;
+  // output instance returned by Value()
+  DataInst out_;
+
+  // internal instance counter
+  unsigned inst_counter_{0};
+  // whether the iterator has reached the end of the data
+  bool end_{false};
+
+  // presumably the current row index and row count of the label / data blocks
+  size_t label_ptr_{0}, label_size_{0};
+  size_t data_ptr_{0}, data_size_{0};
+};
 
+template <typename DType>
+class CSVIterTyped: public CSVIterBase {
+ public:
+  virtual ~CSVIterTyped() {}
   // intialize iterator loads data in
   virtual void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) {
     param_.InitAllowUnknown(kwargs);
-    data_parser_.reset(dmlc::Parser<uint32_t>::Create(param_.data_csv.c_str(), 0, 1, "csv"));
+    data_parser_.reset(dmlc::Parser<uint32_t, DType>::Create(param_.data_csv.c_str(), 0, 1, "csv"));
     if (param_.label_csv != "NULL") {
-      label_parser_.reset(dmlc::Parser<uint32_t>::Create(param_.label_csv.c_str(), 0, 1, "csv"));
+      label_parser_.reset(  // a label CSV was given: parse it like the data CSV
+        dmlc::Parser<uint32_t, DType>::Create(param_.label_csv.c_str(), 0, 1, "csv"));
     } else {
       dummy_label.set_pad(false);
       dummy_label.Resize(mshadow::Shape1(1));
-      dummy_label = 0.0f;
+      dummy_label = 0;  // plain 0 so the assignment suits both int and float DType
     }
   }
 
@@ -116,33 +147,63 @@ class CSVIter: public IIterator<DataInst> {
     return true;
   }
 
-  virtual const DataInst &Value(void) const {
-    return out_;
-  }
-
  private:
-  inline TBlob AsTBlob(const dmlc::Row<uint32_t>& row, const TShape& shape) {
+  inline TBlob AsTBlob(const dmlc::Row<uint32_t, DType>& row, const TShape& shape) {
     CHECK_EQ(row.length, shape.Size())
         << "The data size in CSV do not match size of shape: "
         << "specified shape=" << shape << ", the csv row-length=" << row.length;
-    const real_t* ptr = row.value;
-    return TBlob((real_t*)ptr, shape, cpu::kDevMask, 0);  // NOLINT(*)
+    // TBlob takes a mutable pointer; drop const with a named cast, not a C-style cast.
+    return TBlob(const_cast<DType*>(row.value), shape, cpu::kDevMask, 0);
+  }
+  // dummy label
+  mshadow::TensorContainer<cpu, 1, DType> dummy_label;
+  std::unique_ptr<dmlc::Parser<uint32_t, DType> > label_parser_;
+  std::unique_ptr<dmlc::Parser<uint32_t, DType> > data_parser_;
+};
+
+/*! \brief CSV iterator that delegates to CSVIterTyped<float/int>, chosen by `dtype` */
+class CSVIter: public IIterator<DataInst> {
+ public:
+  CSVIter() {}
+  virtual ~CSVIter() {}
+
+  // initialize the typed iterator selected by `dtype` (float32 by default)
+  virtual void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) {
+    param_.InitAllowUnknown(kwargs);
+    // `dtype` is read straight from kwargs; if it is repeated, the last value wins.
+    bool use_int32 = false;
+    for (const auto& arg : kwargs) {
+      if (arg.first != "dtype") continue;
+      CHECK(arg.second == "int32" || arg.second == "float32")
+          << arg.second << " is not supported for CSVIter";
+      use_int32 = (arg.second == "int32");
+    }
+    // CSVIterTyped<T> derives from CSVIterBase, so the pointer upcasts implicitly;
+    // reinterpret_cast is not guaranteed to yield a valid base-class pointer.
+    if (use_int32) {
+      iterator_.reset(new CSVIterTyped<int>());
+    } else {
+      iterator_.reset(new CSVIterTyped<float>());
+    }
+    iterator_->Init(kwargs);
+  }
+
+  // reset the wrapped typed iterator
+  virtual void BeforeFirst() {
+    iterator_->BeforeFirst();
+  }
+
+  virtual bool Next() {
+    return iterator_->Next();
   }
 
+  virtual const DataInst &Value(void) const {
+    return iterator_->Value();
+  }
+
+ private:
   CSVIterParam param_;
-  // output instance
-  DataInst out_;
-  // internal instance counter
-  unsigned inst_counter_{0};
-  // at end
-  bool end_{false};
-  // dummy label
-  mshadow::TensorContainer<cpu, 1, real_t> dummy_label;
-  // label parser
-  size_t label_ptr_{0}, label_size_{0};
-  size_t data_ptr_{0}, data_size_{0};
-  std::unique_ptr<dmlc::Parser<uint32_t> > label_parser_;
-  std::unique_ptr<dmlc::Parser<uint32_t> > data_parser_;
+  std::unique_ptr<CSVIterBase> iterator_;
 };
 
 
@@ -167,6 +228,10 @@ If ``data_csv = 'data/'`` is set, then all the files in this directory will be r
 
 ``reset()`` is expected to be called only after a complete pass of data.
 
+By default, CSVIter parses every entry in the data file as float32. If the `dtype`
+argument is set to 'int32', CSVIter parses every entry as int32 instead; 'float32'
+may also be passed explicitly, and any other value is rejected.
+
 Examples::
 
   // Contents of CSV file ``data/data.csv``.
@@ -220,6 +285,20 @@ Examples::
   [2.  3.  4.]
   [3.  4.  5.]]
 
+  // Creates a 'CSVIter' with `dtype`='int32'
+  CSVIter = mx.io.CSVIter(data_csv = 'data/data.csv', data_shape = (3,),
+  batch_size = 3, round_batch=False, dtype='int32')
+
+  // Contents of the two batches read from the above iterator in both passes (after calling
+  // the `reset` method before the second pass) are as follows:
+  [[1  2  3]
+  [2  3  4]
+  [3  4  5]]
+
+  [[4  5  6]
+  [2  3  4]
+  [3  4  5]]
+
 )code" ADD_FILELINE)
 .add_arguments(CSVIterParam::__FIELDS__())
 .add_arguments(BatchParam::__FIELDS__())
diff --git a/tests/python/unittest/test_io.py b/tests/python/unittest/test_io.py
index a54cb9233a7..7e6ef1af5ab 100644
--- a/tests/python/unittest/test_io.py
+++ b/tests/python/unittest/test_io.py
@@ -293,24 +293,30 @@ def test_DataBatch():
 
 
 def test_CSVIter():
-    def check_CSVIter_synthetic():
+    def check_CSVIter_synthetic(dtype='float32'):
         cwd = os.getcwd()
         data_path = os.path.join(cwd, 'data.t')
         label_path = os.path.join(cwd, 'label.t')
+        entry_str = '1'
+        if dtype is 'int32':
+            entry_str = '200000001'
         with open(data_path, 'w') as fout:
             for i in range(1000):
-                fout.write(','.join(['1' for _ in range(8*8)]) + '\n')
+                fout.write(','.join([entry_str for _ in range(8*8)]) + '\n')
         with open(label_path, 'w') as fout:
             for i in range(1000):
                 fout.write('0\n')
 
         data_train = mx.io.CSVIter(data_csv=data_path, data_shape=(8,8),
-                                   label_csv=label_path, batch_size=100)
-        expected = mx.nd.ones((100, 8, 8))
+                                   label_csv=label_path, batch_size=100, dtype=dtype)
+        expected = mx.nd.ones((100, 8, 8), dtype=dtype) * int(entry_str)
         for batch in iter(data_train):
-            assert_almost_equal(data_train.getdata().asnumpy(), expected.asnumpy())
+            data_batch = data_train.getdata()
+            assert_almost_equal(data_batch.asnumpy(), expected.asnumpy())
+            assert data_batch.asnumpy().dtype == expected.asnumpy().dtype
 
-    check_CSVIter_synthetic()
+    for dtype in ['int32', 'float32']:
+        check_CSVIter_synthetic(dtype=dtype)
 
 if __name__ == "__main__":
     test_NDArrayIter()
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 96dd0b2d63c..ba66d8b6e9d 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -2491,7 +2491,7 @@ def test_infer_type(dtype):
                                                 names=['a', 'b'])
             raise AssertionError(msg)
 
-    for dtype in ['float16', 'float32', 'float64']:
+    for dtype in ['float16', 'float32']:
         test_infer_type(dtype)
         unittest_correlation((1,3,10,10), kernel_size = 1,max_displacement = 4,stride1 = 1,stride2 = 1,pad_size = 4,is_multiply = False, dtype = dtype)
         unittest_correlation((5,1,15,15), kernel_size = 1,max_displacement = 5,stride1 = 1,stride2 = 1,pad_size = 5,is_multiply = False, dtype = dtype)


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services