You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2018/06/29 19:02:34 UTC

[incubator-mxnet] branch master updated: support int64 data type in CSVIter (#11446)

This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 3a62200  support int64 data type in CSVIter (#11446)
3a62200 is described below

commit 3a62200799d2d75039eb38e186d94361c17d060c
Author: Hao Jin <ha...@users.noreply.github.com>
AuthorDate: Fri Jun 29 15:02:25 2018 -0400

    support int64 data type in CSVIter (#11446)
---
 3rdparty/dmlc-core         |  2 +-
 src/io/image_iter_common.h |  1 +
 src/io/iter_csv.cc         | 16 +++++++++++-----
 3 files changed, 13 insertions(+), 6 deletions(-)

diff --git a/3rdparty/dmlc-core b/3rdparty/dmlc-core
index dadcd97..649be18 160000
--- a/3rdparty/dmlc-core
+++ b/3rdparty/dmlc-core
@@ -1 +1 @@
-Subproject commit dadcd97fdceb5f395e963b2a637f6ed377f59fc4
+Subproject commit 649be18a8c55c48517861d67158a45dec54992ee
diff --git a/src/io/image_iter_common.h b/src/io/image_iter_common.h
index 5682288..8580ff8 100644
--- a/src/io/image_iter_common.h
+++ b/src/io/image_iter_common.h
@@ -348,6 +348,7 @@ struct PrefetcherParam : public dmlc::Parameter<PrefetcherParam> {
       .add_enum("float32", mshadow::kFloat32)
       .add_enum("float64", mshadow::kFloat64)
       .add_enum("float16", mshadow::kFloat16)
+      .add_enum("int64", mshadow::kInt64)
       .add_enum("int32", mshadow::kInt32)
       .add_enum("uint8", mshadow::kUint8)
       .set_default(dmlc::optional<int>())
diff --git a/src/io/iter_csv.cc b/src/io/iter_csv.cc
index ca3f042..5fd1495 100644
--- a/src/io/iter_csv.cc
+++ b/src/io/iter_csv.cc
@@ -174,15 +174,21 @@ class CSVIter: public IIterator<DataInst> {
     for (const auto& arg : kwargs) {
       if (arg.first == "dtype") {
         dtype_has_value = true;
-        if (arg.second == "int32" || arg.second == "float32") {
-          target_dtype = (arg.second == "int32") ? mshadow::kInt32 : mshadow::kFloat32;
+        if (arg.second == "int32") {
+          target_dtype = mshadow::kInt32;
+        } else if (arg.second == "int64") {
+          target_dtype = mshadow::kInt64;
+        } else if (arg.second == "float32") {
+          target_dtype = mshadow::kFloat32;
         } else {
           CHECK(false) << arg.second << " is not supported for CSVIter";
         }
       }
     }
     if (dtype_has_value && target_dtype == mshadow::kInt32) {
-      iterator_.reset(reinterpret_cast<CSVIterBase*>(new CSVIterTyped<int>()));
+      iterator_.reset(reinterpret_cast<CSVIterBase*>(new CSVIterTyped<int32_t>()));
+    } else if (dtype_has_value && target_dtype == mshadow::kInt64) {
+      iterator_.reset(reinterpret_cast<CSVIterBase*>(new CSVIterTyped<int64_t>()));
     } else if (!dtype_has_value || target_dtype == mshadow::kFloat32) {
       iterator_.reset(reinterpret_cast<CSVIterBase*>(new CSVIterTyped<float>()));
     }
@@ -229,8 +235,8 @@ If ``data_csv = 'data/'`` is set, then all the files in this directory will be r
 ``reset()`` is expected to be called only after a complete pass of data.
 
 By default, the CSVIter parses all entries in the data file as float32 data type,
-if `dtype` argument is set to be 'int32' then CSVIter will parse all entries in the file
-as int32 data type.
+if the `dtype` argument is set to 'int32' or 'int64', CSVIter will parse all entries in the file
+as int32 or int64 data type, respectively.
 
 Examples::