You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/06/29 19:02:28 UTC

[GitHub] szha closed pull request #11446: Support int64 data type in CSVIter

szha closed pull request #11446: Support int64 data type in CSVIter
URL: https://github.com/apache/incubator-mxnet/pull/11446
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (GitHub would not otherwise show it after the merge):

diff --git a/3rdparty/dmlc-core b/3rdparty/dmlc-core
index dadcd97fdce..649be18a8c5 160000
--- a/3rdparty/dmlc-core
+++ b/3rdparty/dmlc-core
@@ -1 +1 @@
-Subproject commit dadcd97fdceb5f395e963b2a637f6ed377f59fc4
+Subproject commit 649be18a8c55c48517861d67158a45dec54992ee
diff --git a/src/io/image_iter_common.h b/src/io/image_iter_common.h
index 56822888a44..8580ff8f9f9 100644
--- a/src/io/image_iter_common.h
+++ b/src/io/image_iter_common.h
@@ -348,6 +348,7 @@ struct PrefetcherParam : public dmlc::Parameter<PrefetcherParam> {
       .add_enum("float32", mshadow::kFloat32)
       .add_enum("float64", mshadow::kFloat64)
       .add_enum("float16", mshadow::kFloat16)
+      .add_enum("int64", mshadow::kInt64)
       .add_enum("int32", mshadow::kInt32)
       .add_enum("uint8", mshadow::kUint8)
       .set_default(dmlc::optional<int>())
diff --git a/src/io/iter_csv.cc b/src/io/iter_csv.cc
index ca3f042f45a..5fd149535be 100644
--- a/src/io/iter_csv.cc
+++ b/src/io/iter_csv.cc
@@ -174,15 +174,21 @@ class CSVIter: public IIterator<DataInst> {
     for (const auto& arg : kwargs) {
       if (arg.first == "dtype") {
         dtype_has_value = true;
-        if (arg.second == "int32" || arg.second == "float32") {
-          target_dtype = (arg.second == "int32") ? mshadow::kInt32 : mshadow::kFloat32;
+        if (arg.second == "int32") {
+          target_dtype = mshadow::kInt32;
+        } else if (arg.second == "int64") {
+          target_dtype = mshadow::kInt64;
+        } else if (arg.second == "float32") {
+          target_dtype = mshadow::kFloat32;
         } else {
           CHECK(false) << arg.second << " is not supported for CSVIter";
         }
       }
     }
     if (dtype_has_value && target_dtype == mshadow::kInt32) {
-      iterator_.reset(reinterpret_cast<CSVIterBase*>(new CSVIterTyped<int>()));
+      iterator_.reset(reinterpret_cast<CSVIterBase*>(new CSVIterTyped<int32_t>()));
+    } else if (dtype_has_value && target_dtype == mshadow::kInt64) {
+      iterator_.reset(reinterpret_cast<CSVIterBase*>(new CSVIterTyped<int64_t>()));
     } else if (!dtype_has_value || target_dtype == mshadow::kFloat32) {
       iterator_.reset(reinterpret_cast<CSVIterBase*>(new CSVIterTyped<float>()));
     }
@@ -229,8 +235,8 @@ If ``data_csv = 'data/'`` is set, then all the files in this directory will be r
 ``reset()`` is expected to be called only after a complete pass of data.
 
 By default, the CSVIter parses all entries in the data file as float32 data type,
-if `dtype` argument is set to be 'int32' then CSVIter will parse all entries in the file
-as int32 data type.
+if `dtype` argument is set to be 'int32' or 'int64' then CSVIter will parse all entries in the file
+as int32 or int64 data type accordingly.
 
 Examples::
 


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services